mirror of https://github.com/vllm-project/vllm.git
synced 2025-10-31 06:14:38 +08:00

Compare commits

413 Commits
| SHA1 | Author | Date | |
|---|---|---|---|
| 7d092fc32c | |||
| 1a6c27f271 | |||
| 3c6fd286b4 | |||
| 536fd33003 | |||
| 619b9f5c7e | |||
| d1b689c445 | |||
| 9854dc9040 | |||
| ff5c60fad8 | |||
| 6f1229f91d | |||
| 1819fbda63 | |||
| 7f0367109e | |||
| fb14d53cf6 | |||
| b024a42e93 | |||
| cb97f2bfc5 | |||
| 359200f6ac | |||
| 220aee902a | |||
| 67d25eca05 | |||
| 363528de27 | |||
| 4ff61ababa | |||
| 0ec3779df7 | |||
| b616f6a53d | |||
| 2e25bb12a8 | |||
| 9965c47d0d | |||
| 059d4cdb49 | |||
| bdb84e26b0 | |||
| 3dd359147d | |||
| 657f2f301a | |||
| a1aafc827a | |||
| 139508a418 | |||
| d265414dbc | |||
| 48fb076cbc | |||
| c1909e7e8c | |||
| b95877509b | |||
| 706ff13224 | |||
| ccbfb1d1c9 | |||
| 9e5552aa13 | |||
| 0c600b9ab6 | |||
| e303dcf523 | |||
| ae9c4d416f | |||
| d853520b3e | |||
| ba51aea65e | |||
| 8452946c06 | |||
| 2e7cbf2d7d | |||
| 7da296be04 | |||
| b205e8467d | |||
| be0cfb2b68 | |||
| 1a03dd496b | |||
| 27b8017636 | |||
| 9ec1e3065a | |||
| 9dae7d46bf | |||
| 7058d7dd5d | |||
| a0389e0554 | |||
| 3be8d312a2 | |||
| 3abfe22154 | |||
| e81fbefe8a | |||
| 9290de5667 | |||
| 7f280d69c9 | |||
| 02cabff207 | |||
| 3d19d47d91 | |||
| 8acb4badee | |||
| 314af8617c | |||
| 0e96cc9b7e | |||
| ecad851cbd | |||
| ed70f3c64f | |||
| 650d5dbd04 | |||
| 9025a9a705 | |||
| c05596f1a3 | |||
| 787b13389e | |||
| 96453cfa83 | |||
| b1c1fe35a5 | |||
| 08d81f1014 | |||
| 6cc1e7d96d | |||
| 9909726d2a | |||
| 22e9d42040 | |||
| 86debab54c | |||
| be250bbc67 | |||
| 27949354fa | |||
| bd5038af07 | |||
| a2f14dc8f9 | |||
| 92ee7baaf9 | |||
| 7151f92241 | |||
| e28533a16f | |||
| 6d42ce8315 | |||
| ded1fb635b | |||
| 97d9524fe9 | |||
| d8cf819a9a | |||
| 551ef1631a | |||
| 2863befce3 | |||
| 2965c99c86 | |||
| 2062c0723d | |||
| 1c50e100a9 | |||
| 3ee56e26be | |||
| 8fe7fc8634 | |||
| e936e401de | |||
| f5dfa07531 | |||
| 022c58b80f | |||
| 19108ef311 | |||
| 5a52f389dd | |||
| 65b1cbb138 | |||
| 6c9837a761 | |||
| 6f2f53a82d | |||
| 7b1895e6ce | |||
| 4d36693687 | |||
| daec9dea6e | |||
| daceac57c7 | |||
| 8615d9776f | |||
| 7b460c25f9 | |||
| f719772281 | |||
| d45417b804 | |||
| a29e62ea34 | |||
| e53be6f00a | |||
| c329ceca6d | |||
| 3c545c0c3b | |||
| e8c3bd2cd1 | |||
| c6c983053d | |||
| aafabaa0d5 | |||
| 94a55c7681 | |||
| aa0dc77ef5 | |||
| 4ab3ac285e | |||
| d1c956dc0f | |||
| dec197e3e5 | |||
| 6e244ae091 | |||
| cd4cfee689 | |||
| e110930680 | |||
| 8b64c895c0 | |||
| 0740e29b66 | |||
| 44d2e6af63 | |||
| 2d7779f888 | |||
| a57d57fa72 | |||
| 71799fd005 | |||
| e9fd658a73 | |||
| 07b8fae219 | |||
| 562308816c | |||
| 04e1642e32 | |||
| b69781f107 | |||
| 0bceac9810 | |||
| 34878a0b48 | |||
| 6393b03986 | |||
| 0907d507bf | |||
| c894c5dc1f | |||
| 1f5d178e9c | |||
| 27c065df50 | |||
| 84c260caeb | |||
| 167aca45cb | |||
| 0567c8249f | |||
| d188913d99 | |||
| 1d7c29f5fe | |||
| 65397e40f5 | |||
| 9502c38138 | |||
| 2582683566 | |||
| 754b00edb3 | |||
| 296ce95d8e | |||
| 2d7620c3eb | |||
| 55c65ab495 | |||
| 2cc2069970 | |||
| 9f0608fc16 | |||
| 4e0db57fff | |||
| c40692bf9a | |||
| 4734704b30 | |||
| 8b8c209e35 | |||
| 23a04e0895 | |||
| 02c97d9a92 | |||
| e795d723ed | |||
| 8359f4c8d8 | |||
| bf5181583f | |||
| c53fec1fcb | |||
| 0f9e7354f5 | |||
| ba7ba35cda | |||
| 015fab8c2f | |||
| f59fc60fb3 | |||
| 879f69bed3 | |||
| 7108934142 | |||
| 3443aaf8dd | |||
| 2273ec322c | |||
| a6c4b87fbc | |||
| 1afa9948f5 | |||
| 0d06b533a0 | |||
| c01d1c5aba | |||
| ead369845d | |||
| c6e3bba8e6 | |||
| 91f7d9d0b6 | |||
| 8619e7158c | |||
| c635c5f744 | |||
| a045b7e89a | |||
| 981eeca41a | |||
| 26d34eb67e | |||
| 53da4cd397 | |||
| 9a3b88328f | |||
| 3014c920da | |||
| 0eed516951 | |||
| ee5ad8d2c5 | |||
| a738dbb2a1 | |||
| 33d5e29be9 | |||
| 4671ac6e2a | |||
| dd2ccf8dde | |||
| a3bc76e4b5 | |||
| e6327c9b3e | |||
| d0132f025d | |||
| 61f4fc5dc6 | |||
| 68aaeb3749 | |||
| c3649e4fee | |||
| 53243e5c42 | |||
| a6e6604d32 | |||
| b82e0f82cb | |||
| 5111642a6f | |||
| 1bcd15edc7 | |||
| 2ebff5b77c | |||
| f17aec0d63 | |||
| 493c275352 | |||
| f39ab2d4bd | |||
| 4a0f7888a3 | |||
| c4cf260677 | |||
| 33d51f599e | |||
| e91386cde1 | |||
| 2c11a29f0b | |||
| c76a506bd6 | |||
| ec0db6f51c | |||
| c305a2109d | |||
| 202c5df935 | |||
| 2bb246b8f7 | |||
| 4c409cabc2 | |||
| 3b1e4c6a23 | |||
| 2c5302fadd | |||
| caa680fd2e | |||
| c3bf9bad11 | |||
| 6f170f11dd | |||
| 8ca81bb069 | |||
| e773a9e1c2 | |||
| 71baf85ae1 | |||
| 79f2f1c2a1 | |||
| 2e3e3c86dc | |||
| 7e8977fcd4 | |||
| f1e840e842 | |||
| 7771d1de88 | |||
| 71d1219545 | |||
| e384f2f108 | |||
| 089a306f19 | |||
| 5e666f72cd | |||
| e3a3e4db46 | |||
| e41bf15cd0 | |||
| 5aa4a015ce | |||
| b6bad3d186 | |||
| ee9a1531aa | |||
| 10d82f9ac5 | |||
| ea10dd9d9e | |||
| ead2110297 | |||
| 01220ce89a | |||
| 6f68c49220 | |||
| 4719460644 | |||
| 466166dcfd | |||
| 1d0ae26c85 | |||
| 6021999573 | |||
| c7b370c603 | |||
| aa20d10a91 | |||
| 2de12be428 | |||
| 83ca9ae47b | |||
| e2148dc5ea | |||
| b1098b4072 | |||
| 799397ee4f | |||
| 4959915089 | |||
| 8d1e89d946 | |||
| 36239f79dd | |||
| dfada85eee | |||
| ed33349738 | |||
| d49adea1f9 | |||
| 14fdd21d39 | |||
| 04fefe7c9a | |||
| 3b523e38d9 | |||
| 16c16301c8 | |||
| 9206d0ff01 | |||
| a89209b78d | |||
| ffacb222cb | |||
| 12575cfa7a | |||
| 8b6e1d639c | |||
| 735a9de71f | |||
| 257ab95439 | |||
| cca91a7a10 | |||
| f04d604567 | |||
| 19a53b2783 | |||
| eccdc8318c | |||
| 5f52a84685 | |||
| d4629dc43f | |||
| 6e9cc73f67 | |||
| c53711bd63 | |||
| dac8cc49f4 | |||
| a44b1c951d | |||
| b447624ee3 | |||
| cda92307c1 | |||
| bf57ccc5c2 | |||
| ffb2cd6b54 | |||
| ca94d7fa00 | |||
| 5a1c2e15d8 | |||
| 4c8f64faa7 | |||
| 93aee29fdb | |||
| 154d063b9f | |||
| ccd7c05089 | |||
| c48c6c4008 | |||
| aed8468642 | |||
| 5c76b9cdaf | |||
| ddfed314f9 | |||
| 5b3ad5ecf2 | |||
| ede5c4ebdf | |||
| 07334959d8 | |||
| 119f683949 | |||
| 0860087aff | |||
| 6bc7b57315 | |||
| 90f9c2eb5c | |||
| 387bdf0ab9 | |||
| 5e5baa91aa | |||
| 836d4ce140 | |||
| c3fec47bb7 | |||
| 1173804dca | |||
| 4d5424029b | |||
| 3e7506975c | |||
| ee35e96ac3 | |||
| dec66d253b | |||
| 8d120701fd | |||
| f40f763f12 | |||
| 26bc46ef89 | |||
| a77aea59fd | |||
| b692e9cd07 | |||
| 367871a469 | |||
| 92183b41f3 | |||
| c6703d1e0d | |||
| a5e7242d5f | |||
| 91b2c17a55 | |||
| 055915e6ce | |||
| 3d330c4c09 | |||
| 0b73736a0d | |||
| ee1531bc38 | |||
| e13945f9dd | |||
| 08500011d3 | |||
| 861a0a0a39 | |||
| bc956b38d0 | |||
| 294fc1e2c9 | |||
| 2db9044ab6 | |||
| 6fa718a460 | |||
| 06be858828 | |||
| d1e34cc9ac | |||
| bd517eb9fe | |||
| d65668b4e8 | |||
| aafbbd981f | |||
| 0f0874515a | |||
| 3597b06a4f | |||
| 1015296b79 | |||
| ce9dc02c93 | |||
| a24cb91600 | |||
| 7e8d97dd3f | |||
| d70bc7c029 | |||
| ce688ad46e | |||
| cefdb9962d | |||
| ace5cdaff0 | |||
| 6458721108 | |||
| bb4a0decef | |||
| c707cfc12e | |||
| 7b3c9ff91d | |||
| c68698b326 | |||
| e3b12667d4 | |||
| e6aab5de29 | |||
| c57bb199b3 | |||
| dba68f9159 | |||
| a3319f4f04 | |||
| 9d880f594d | |||
| 017ef648e9 | |||
| 4b25ab14e2 | |||
| f98548b9da | |||
| 96846bb360 | |||
| b6efafd9e4 | |||
| 1129e2b1ab | |||
| c742438f8b | |||
| 73e2e0118f | |||
| c9280e6346 | |||
| af09b3f0a0 | |||
| 4f6c42fa0a | |||
| dff680001d | |||
| 2e090bd5df | |||
| 1b0b065eb5 | |||
| d5bdf899e4 | |||
| 7e3e74c97c | |||
| 3f6341bf7f | |||
| e5d35d62f5 | |||
| 2f1c19b245 | |||
| 42f52cc95b | |||
| 97a9465bbc | |||
| c7ea0b56cd | |||
| 29fa5cac1c | |||
| b2d9be6f7d | |||
| 04a55612dd | |||
| 89b0f84e17 | |||
| 497a91e9f7 | |||
| 943ffa5703 | |||
| 5c8d34a42c | |||
| 3c8694eabe | |||
| 7484e1fce2 | |||
| a2142f0196 | |||
| 871d6b7c74 | |||
| 29a38f0352 | |||
| a5115f4ff5 | |||
| 68b4a26149 | |||
| b8e809a057 | |||
| 5039ec2336 | |||
| 7c644ab6d5 | |||
| 2d40665fe8 | |||
| 96ada386b7 | |||
| 1e473b3010 | |||
| 2b1e2111b0 | |||
| a45b979d9f | |||
| 3952731e8f | |||
| 77f0d465d0 | |||
| 22c3c0aa4a | |||
| 33f8dba7c6 | |||
| 5241ca50d6 | |||
| da9b523ce1 | 
| @ -11,7 +11,7 @@ See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performanc | ||||
|  | ||||
| ## Performance benchmark quick overview | ||||
|  | ||||
| **Benchmarking Coverage**: latency, throughput and fixed-QPS serving on A100 (the support for FP8 benchmark on H100 is coming!), with different models. | ||||
| **Benchmarking Coverage**: latency, throughput and fixed-QPS serving on A100 (the support for FP8 benchmark on H100 is coming!) and Intel® Xeon® Processors, with different models. | ||||
|  | ||||
| **Benchmarking Duration**: about 1hr. | ||||
|  | ||||
| @ -31,13 +31,27 @@ Performance benchmark will be triggered when: | ||||
| - A PR being merged into vllm. | ||||
| - Every commit for those PRs with `perf-benchmarks` label AND `ready` label. | ||||
|  | ||||
| To manually trigger the benchmark: | ||||
|  | ||||
| ```bash | ||||
| bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh | ||||
| ``` | ||||
|  | ||||
| Runtime environment variables (see the example below): | ||||
| - `ON_CPU`: set to '1' to run on Intel® Xeon® Processors. Default value is 0. | ||||
| - `SERVING_JSON`: JSON file to use for the serving tests. Default value is empty string (use the default file). | ||||
| - `LATENCY_JSON`: JSON file to use for the latency tests. Default value is empty string (use the default file). | ||||
| - `THROUGHPUT_JSON`: JSON file to use for the throughput tests. Default value is empty string (use the default file). | ||||
| - `REMOTE_HOST`: IP of the remote vLLM service to benchmark. Default value is empty string. | ||||
| - `REMOTE_PORT`: Port of the remote vLLM service to benchmark. Default value is empty string. | ||||
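For example, a CPU run that picks up the `*-cpu.json` test files could look like the sketch below; the custom file name in the second command is hypothetical.

```bash
# Minimal sketch: run the full suite on an Intel® Xeon® host.
# ON_CPU=1 makes the script default to the tests/*-cpu.json files.
export HF_TOKEN=<your HF token>
ON_CPU=1 bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh

# Override only the serving suite with a custom file (file name is hypothetical):
ON_CPU=1 SERVING_JSON=serving-tests-cpu-smoke.json \
  bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh
```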
|  | ||||
| Nightly benchmark will be triggered when: | ||||
| - Every commit for those PRs with `perf-benchmarks` label and `nightly-benchmarks` label. | ||||
|  | ||||
| ## Performance benchmark details | ||||
|  | ||||
| See [performance-benchmarks-descriptions.md](performance-benchmarks-descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases. | ||||
|  | ||||
| > NOTE: For Intel® Xeon® Processors, use `tests/latency-tests-cpu.json`, `tests/throughput-tests-cpu.json`, `tests/serving-tests-cpu.json` instead. | ||||
| ### Latency test | ||||
|  | ||||
| Here is an example of one test inside `latency-tests.json`: | ||||
| @ -119,6 +133,30 @@ If you do not see the table, please wait till the benchmark finish running. | ||||
| The json version of the table (together with the json version of the benchmark) will be also attached to the markdown file. | ||||
| The raw benchmarking results (in the format of json files) are in the `Artifacts` tab of the benchmarking. | ||||
|  | ||||
| The `compare-json-results.py` script compares benchmark results JSON files converted using `convert-results-json-to-markdown.py`. | ||||
| When run, the benchmark script generates results under the `benchmark/results` folder, along with `benchmark_results.md` and `benchmark_results.json`. | ||||
| `compare-json-results.py` compares two `benchmark_results.json` files and reports the performance ratio, e.g. for Output Tput, Median TTFT and Median TPOT. | ||||
|  | ||||
| Here is an example using the script to compare result_a and result_b without detailed test names: | ||||
| `python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json --ignore_test_name` | ||||
|  | ||||
| |    | results_a/benchmark_results.json | results_b/benchmark_results.json | perf_ratio        | | ||||
| |----|----------------------------------------|----------------------------------------|----------| | ||||
| | 0  | 142.633982                             | 156.526018                             | 1.097396 | | ||||
| | 1  | 241.620334                             | 294.018783                             | 1.216863 | | ||||
| | 2  | 218.298905                             | 262.664916                             | 1.203235 | | ||||
| | 3  | 242.743860                             | 299.816190                             | 1.235113 | | ||||
|  | ||||
| Here is an example using the script to compare result_a and result_b with detailed test names: | ||||
| `python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json` | ||||
| |   | results_a/benchmark_results.json_name | results_a/benchmark_results.json | results_b/benchmark_results.json_name | results_b/benchmark_results.json | perf_ratio        | | ||||
| |---|---------------------------------------------|----------------------------------------|---------------------------------------------|----------------------------------------|----------| | ||||
| | 0 | serving_llama8B_tp1_sharegpt_qps_1          | 142.633982                             | serving_llama8B_tp1_sharegpt_qps_1          | 156.526018                             | 1.097396 | | ||||
| | 1 | serving_llama8B_tp1_sharegpt_qps_16         | 241.620334                             | serving_llama8B_tp1_sharegpt_qps_16         | 294.018783                             | 1.216863 | | ||||
| | 2 | serving_llama8B_tp1_sharegpt_qps_4          | 218.298905                             | serving_llama8B_tp1_sharegpt_qps_4          | 262.664916                             | 1.203235 | | ||||
| | 3 | serving_llama8B_tp1_sharegpt_qps_inf        | 242.743860                             | serving_llama8B_tp1_sharegpt_qps_inf        | 299.816190                             | 1.235113 | | ||||
| | 4 | serving_llama8B_tp2_random_1024_128_qps_1   | 96.613390                              | serving_llama8B_tp4_random_1024_128_qps_1   | 108.404853                             | 1.122048 | | ||||
|  | ||||
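A sketch of the end-to-end flow that produces the two inputs for the comparison above; the results path follows the description earlier in this section, and the `results_a`/`results_b` folder names are illustrative.

```bash
# Run 1 (baseline), then keep its converted results aside
bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh
cp -r benchmark/results results_a

# Run 2 (candidate), same again
bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh
cp -r benchmark/results results_b

# Compare the summaries; perf_ratio > 1 means results_b is higher on that metric
python3 compare-json-results.py \
  -f results_a/benchmark_results.json -f results_b/benchmark_results.json --ignore_test_name
```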
| ## Nightly test details | ||||
|  | ||||
| See [nightly-descriptions.md](nightly-descriptions.md) for the detailed description on test workload, models and docker containers of benchmarking other llm engines. | ||||
|  | ||||
| @ -16,7 +16,7 @@ Please download the visualization scripts in the post | ||||
|   - Download `nightly-benchmarks.zip`. | ||||
|   - In the same folder, run the following code: | ||||
|  | ||||
|   ```console | ||||
|   ```bash | ||||
|   export HF_TOKEN=<your HF token> | ||||
|   apt update | ||||
|   apt install -y git | ||||
|  | ||||
| @ -4,7 +4,8 @@ | ||||
| - Input length: 32 tokens. | ||||
| - Output length: 128 tokens. | ||||
| - Batch size: fixed (8). | ||||
| - Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. | ||||
| - GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. | ||||
| - CPU Models: llama-3.1 8B. | ||||
| - Evaluation metrics: end-to-end latency (mean, median, p99). | ||||
|  | ||||
| {latency_tests_markdown_table} | ||||
| @ -14,7 +15,8 @@ | ||||
| - Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed). | ||||
| - Output length: the corresponding output length of these 200 prompts. | ||||
| - Batch size: dynamically determined by vllm to achieve maximum throughput. | ||||
| - Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. | ||||
| - GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. | ||||
| - CPU Models: llama-3.1 8B. | ||||
| - Evaluation metrics: throughput. | ||||
|  | ||||
| {throughput_tests_markdown_table} | ||||
| @ -25,12 +27,18 @@ | ||||
| - Output length: the corresponding output length of these 200 prompts. | ||||
| - Batch size: dynamically determined by vllm and the arrival pattern of the requests. | ||||
| - **Average QPS (query per second)**: 1, 4, 16 and inf. QPS = inf means all requests come at once. For other QPS values, the arrival time of each query is determined using a random Poisson process (with fixed random seed). | ||||
| - Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. | ||||
| - We also added a speculative decoding test for llama-3 70B, under QPS 2 | ||||
| - GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. | ||||
| - We also added a speculative decoding test for llama-3 70B on GPU, under QPS 2 | ||||
| - CPU Models: llama-3.1 8B. | ||||
| - Evaluation metrics: throughput, TTFT (time to the first token, with mean, median and p99), ITL (inter-token latency, with mean, median and p99). | ||||
| - For CPU, we added random dataset tests to benchmark fixed input/output length with 100 prompts. | ||||
|  | ||||
| {serving_tests_markdown_table} | ||||
|  | ||||
| ## Platform Information | ||||
|  | ||||
| {platform_markdown_table} | ||||
|  | ||||
| ## json version of the benchmarking tables | ||||
|  | ||||
| This section contains the data of the markdown tables above in JSON format. | ||||
|  | ||||
| @ -0,0 +1,66 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
| import argparse | ||||
|  | ||||
| import pandas as pd | ||||
|  | ||||
|  | ||||
| def compare_data_columns( | ||||
|     files, name_column, data_column, drop_column, ignore_test_name=False | ||||
| ): | ||||
|     print("\ncompare_data_column: " + data_column) | ||||
|     frames = [] | ||||
|     compare_frames = [] | ||||
|     for file in files: | ||||
|         data_df = pd.read_json(file) | ||||
|         serving_df = data_df.dropna(subset=[drop_column], ignore_index=True) | ||||
|         if ignore_test_name is False: | ||||
|             serving_df = serving_df.rename(columns={name_column: file + "_name"}) | ||||
|             frames.append(serving_df[file + "_name"]) | ||||
|         serving_df = serving_df.rename(columns={data_column: file}) | ||||
|         frames.append(serving_df[file]) | ||||
|         compare_frames.append(serving_df[file]) | ||||
|         if len(compare_frames) >= 2: | ||||
|             # Compare numbers among two files | ||||
|             ratio_df = compare_frames[1] / compare_frames[0] | ||||
|             frames.append(ratio_df) | ||||
|             compare_frames.pop(1) | ||||
|  | ||||
|     concat_df = pd.concat(frames, axis=1) | ||||
|     return concat_df | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser() | ||||
|     parser.add_argument( | ||||
|         "-f", "--file", action="append", type=str, help="input file name" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--ignore_test_name", action="store_true", help="ignore_test_name or not" | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|     files = args.file | ||||
|     print("comparing : " + ", ".join(files)) | ||||
|  | ||||
|     drop_column = "P99" | ||||
|     name_column = "Test name" | ||||
|     data_cols_to_compare = ["Output Tput (tok/s)", "Median TTFT (ms)", "Median"] | ||||
|     html_msgs_for_data_cols = [ | ||||
|         "Compare Output Tokens \n", | ||||
|         "Median TTFT \n", | ||||
|         "Median TPOT \n", | ||||
|     ] | ||||
|     ignore_test_name = args.ignore_test_name | ||||
|     with open("perf_comparison.html", "w") as text_file: | ||||
|         for i in range(len(data_cols_to_compare)): | ||||
|             output_df = compare_data_columns( | ||||
|                 files, | ||||
|                 name_column, | ||||
|                 data_cols_to_compare[i], | ||||
|                 drop_column, | ||||
|                 ignore_test_name=ignore_test_name, | ||||
|             ) | ||||
|             print(output_df) | ||||
|             html = output_df.to_html() | ||||
|             text_file.write(html_msgs_for_data_cols[i]) | ||||
|             text_file.write(html) | ||||
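For reference, the script above writes its three comparison tables (Output Tput, Median TTFT, Median TPOT) into `perf_comparison.html` in the working directory; a minimal invocation might be:

```bash
# Compare two converted result files and inspect the generated HTML report
python3 compare-json-results.py \
  -f results_a/benchmark_results.json \
  -f results_b/benchmark_results.json
ls -l perf_comparison.html
```

Passing `--ignore_test_name` drops the per-file test-name columns, which is useful when the two runs use different test names (for example tp2 vs. tp4 configurations).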
| @ -3,9 +3,11 @@ | ||||
|  | ||||
| import json | ||||
| import os | ||||
| from importlib import util | ||||
| from pathlib import Path | ||||
|  | ||||
| import pandas as pd | ||||
| import psutil | ||||
| from tabulate import tabulate | ||||
|  | ||||
| results_folder = Path("results/") | ||||
| @ -29,11 +31,11 @@ throughput_results = [] | ||||
| throughput_results_column_mapping = { | ||||
|     "test_name": "Test name", | ||||
|     "gpu_type": "GPU", | ||||
|     # "num_requests": "# of req.", | ||||
|     # "total_num_tokens": "Total # of tokens", | ||||
|     # "elapsed_time": "Elapsed time (s)", | ||||
|     "num_requests": "# of req.", | ||||
|     "total_num_tokens": "Total # of tokens", | ||||
|     "elapsed_time": "Elapsed time (s)", | ||||
|     "requests_per_second": "Tput (req/s)", | ||||
|     # "tokens_per_second": "Tput (tok/s)", | ||||
|     "tokens_per_second": "Tput (tok/s)", | ||||
| } | ||||
|  | ||||
| # serving results and the keys that will be printed into markdown | ||||
| @ -41,16 +43,18 @@ serving_results = [] | ||||
| serving_column_mapping = { | ||||
|     "test_name": "Test name", | ||||
|     "gpu_type": "GPU", | ||||
|     # "completed": "# of req.", | ||||
|     "completed": "# of req.", | ||||
|     "request_throughput": "Tput (req/s)", | ||||
|     # "input_throughput": "Input Tput (tok/s)", | ||||
|     # "output_throughput": "Output Tput (tok/s)", | ||||
|     "total_token_throughput": "Total Token Tput (tok/s)", | ||||
|     "output_throughput": "Output Tput (tok/s)", | ||||
|     "total_input_tokens": "Total input tokens", | ||||
|     "total_output_tokens": "Total output tokens", | ||||
|     "mean_ttft_ms": "Mean TTFT (ms)", | ||||
|     "median_ttft_ms": "Median TTFT (ms)", | ||||
|     "p99_ttft_ms": "P99 TTFT (ms)", | ||||
|     # "mean_tpot_ms": "Mean TPOT (ms)", | ||||
|     # "median_tpot_ms": "Median", | ||||
|     # "p99_tpot_ms": "P99", | ||||
|     "mean_tpot_ms": "Mean TPOT (ms)", | ||||
|     "median_tpot_ms": "Median", | ||||
|     "p99_tpot_ms": "P99", | ||||
|     "mean_itl_ms": "Mean ITL (ms)", | ||||
|     "median_itl_ms": "Median ITL (ms)", | ||||
|     "p99_itl_ms": "P99 ITL (ms)", | ||||
| @ -75,6 +79,20 @@ def results_to_json(latency, throughput, serving): | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def get_size_with_unit(bytes, suffix="B"): | ||||
|     """ | ||||
|     Scale bytes to its proper format | ||||
|     e.g: | ||||
|         1253656 => '1.20MB' | ||||
|         1253656678 => '1.17GB' | ||||
|     """ | ||||
|     factor = 1024 | ||||
|     for unit in ["", "K", "M", "G", "T", "P"]: | ||||
|         if bytes < factor: | ||||
|             return f"{bytes:.2f}{unit}{suffix}" | ||||
|         bytes /= factor | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     # collect results | ||||
|     for test_file in results_folder.glob("*.json"): | ||||
| @ -155,6 +173,27 @@ if __name__ == "__main__": | ||||
|     serving_results = pd.DataFrame.from_dict(serving_results) | ||||
|     throughput_results = pd.DataFrame.from_dict(throughput_results) | ||||
|  | ||||
|     svmem = psutil.virtual_memory() | ||||
|     platform_data = { | ||||
|         "Physical cores": [psutil.cpu_count(logical=False)], | ||||
|         "Total cores": [psutil.cpu_count(logical=True)], | ||||
|         "Total Memory": [get_size_with_unit(svmem.total)], | ||||
|     } | ||||
|  | ||||
|     if util.find_spec("numa") is not None: | ||||
|         from numa import info | ||||
|  | ||||
|         platform_data["Total NUMA nodes"] = [info.get_num_configured_nodes()] | ||||
|  | ||||
|     if util.find_spec("cpuinfo") is not None: | ||||
|         from cpuinfo import get_cpu_info | ||||
|  | ||||
|         platform_data["CPU Brand"] = [get_cpu_info()["brand_raw"]] | ||||
|  | ||||
|     platform_results = pd.DataFrame.from_dict( | ||||
|         platform_data, orient="index", columns=["Platform Info"] | ||||
|     ) | ||||
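The platform table degrades gracefully when the optional modules are missing: only `psutil` is imported unconditionally, while the NUMA and CPU-brand rows appear only if the `numa` and `cpuinfo` modules can be found. The pip package names below are assumptions, not something stated in this change, so verify them for your environment.

```bash
pip install psutil        # required: core counts and total memory rows
pip install py-cpuinfo    # assumed to provide the `cpuinfo` module (CPU Brand row)
pip install py-libnuma    # assumed to provide the `numa` module (NUMA node count row)
```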
|  | ||||
|     raw_results_json = results_to_json( | ||||
|         latency_results, throughput_results, serving_results | ||||
|     ) | ||||
| @ -200,6 +239,9 @@ if __name__ == "__main__": | ||||
|     throughput_md_table = tabulate( | ||||
|         throughput_results, headers="keys", tablefmt="pipe", showindex=False | ||||
|     ) | ||||
|     platform_md_table = tabulate( | ||||
|         platform_results, headers="keys", tablefmt="pipe", showindex=True | ||||
|     ) | ||||
|  | ||||
|     # document the result | ||||
|     with open(results_folder / "benchmark_results.md", "w") as f: | ||||
| @ -211,6 +253,7 @@ if __name__ == "__main__": | ||||
|             latency_tests_markdown_table=latency_md_table, | ||||
|             throughput_tests_markdown_table=throughput_md_table, | ||||
|             serving_tests_markdown_table=serving_md_table, | ||||
|             platform_markdown_table=platform_md_table, | ||||
|             benchmarking_results_in_json_string=processed_results_json, | ||||
|         ) | ||||
|         f.write(results) | ||||
|  | ||||
| @ -31,6 +31,20 @@ check_gpus() { | ||||
|   echo "GPU type is $gpu_type" | ||||
| } | ||||
|  | ||||
| check_cpus() { | ||||
|   # check the number of NUMA nodes; for CPU runs the device type is reported as "cpu". | ||||
|   declare -g numa_count=$(python3 -c  "from numa import info;numa_size = info.get_num_configured_nodes(); print(numa_size)") | ||||
|   if [[ $numa_count -gt 0 ]]; then | ||||
|     echo "NUMA found." | ||||
|     echo $numa_count | ||||
|   else | ||||
|     echo "Need at least 1 NUMA node to run benchmarking." | ||||
|     exit 1 | ||||
|   fi | ||||
|   declare -g gpu_type="cpu" | ||||
|   echo "GPU type is $gpu_type" | ||||
| } | ||||
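As a quick pre-flight check outside the harness, the same NUMA query can be run by hand; test cases whose `tensor_parallel_size` exceeds this count are skipped. The `numactl` alternative is an assumption for hosts without the Python binding.

```bash
# Mirrors the check_cpus() query: prints the number of configured NUMA nodes
python3 -c "from numa import info; print(info.get_num_configured_nodes())"

# Rough alternative without the Python binding
numactl --hardware | grep '^available:'
```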
|  | ||||
| check_hf_token() { | ||||
|   # check if HF_TOKEN is available and valid | ||||
|   if [[ -z "$HF_TOKEN" ]]; then | ||||
| @ -69,6 +83,22 @@ json2args() { | ||||
|   echo "$args" | ||||
| } | ||||
|  | ||||
| json2envs() { | ||||
|   # transforms the JSON string to environment variables. | ||||
|   # example: | ||||
|   # input: { "VLLM_CPU_KVCACHE_SPACE": 5 } | ||||
|   # output: VLLM_CPU_KVCACHE_SPACE=5 | ||||
|   local json_string=$1 | ||||
|   local args=$( | ||||
|     echo "$json_string" | jq -r ' | ||||
|       to_entries | | ||||
|       map((.key ) + "=" + (.value | tostring)) | | ||||
|       join(" ") | ||||
|     ' | ||||
|   ) | ||||
|   echo "$args" | ||||
| } | ||||
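To see what this produces for a typical CPU test entry, the same jq program can be run on its own:

```bash
# Same jq program as json2envs; output is a space-separated list of KEY=VALUE pairs
echo '{ "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, "VLLM_CPU_KVCACHE_SPACE": 40 }' |
  jq -r 'to_entries | map((.key) + "=" + (.value | tostring)) | join(" ")'
# -> VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 VLLM_CPU_KVCACHE_SPACE=40
```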
|  | ||||
| wait_for_server() { | ||||
|   # wait for vllm server to start | ||||
|   # return 1 if vllm server crashes | ||||
| @ -158,15 +188,24 @@ run_latency_tests() { | ||||
|     # get arguments | ||||
|     latency_params=$(echo "$params" | jq -r '.parameters') | ||||
|     latency_args=$(json2args "$latency_params") | ||||
|     latency_environment_variables=$(echo "$params" | jq -r '.environment_variables') | ||||
|     latency_envs=$(json2envs "$latency_environment_variables") | ||||
|  | ||||
|     # check if there is enough GPU to run the test | ||||
|     tp=$(echo "$latency_params" | jq -r '.tensor_parallel_size') | ||||
|     if [[ $gpu_count -lt $tp ]]; then | ||||
|       echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." | ||||
|       continue | ||||
|     if [ "$ON_CPU" == "1" ];then | ||||
|       if [[ $numa_count -lt $tp ]]; then | ||||
|         echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name." | ||||
|         continue | ||||
|       fi | ||||
|     else | ||||
|       if [[ $gpu_count -lt $tp ]]; then | ||||
|         echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." | ||||
|         continue | ||||
|       fi | ||||
|     fi | ||||
|  | ||||
|     latency_command="python3 benchmark_latency.py \ | ||||
|     latency_command=" $latency_envs python3 benchmark_latency.py \ | ||||
|       --output-json $RESULTS_FOLDER/${test_name}.json \ | ||||
|       $latency_args" | ||||
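For the `latency_llama8B_tp1` case in `latency-tests-cpu.json`, the assembled command would look roughly like the sketch below, assuming `json2args` turns snake_case keys into `--kebab-case` flags as it does for the existing GPU configs.

```bash
# Sketch of the expanded latency_command for latency_llama8B_tp1 on CPU
VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 VLLM_CPU_KVCACHE_SPACE=40 python3 benchmark_latency.py \
  --output-json "$RESULTS_FOLDER/latency_llama8B_tp1.json" \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --tensor-parallel-size 1 --load-format dummy \
  --num-iters-warmup 5 --num-iters 15
```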
|  | ||||
| @ -216,15 +255,24 @@ run_throughput_tests() { | ||||
|     # get arguments | ||||
|     throughput_params=$(echo "$params" | jq -r '.parameters') | ||||
|     throughput_args=$(json2args "$throughput_params") | ||||
|     throughput_environment_variables=$(echo "$params" | jq -r '.environment_variables') | ||||
|     throughput_envs=$(json2envs "$throughput_environment_variables") | ||||
|  | ||||
|     # check if there is enough GPU to run the test | ||||
|     tp=$(echo "$throughput_params" | jq -r '.tensor_parallel_size') | ||||
|     if [[ $gpu_count -lt $tp ]]; then | ||||
|       echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." | ||||
|       continue | ||||
|     if [ "$ON_CPU" == "1" ];then | ||||
|       if [[ $numa_count -lt $tp ]]; then | ||||
|         echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name." | ||||
|         continue | ||||
|       fi | ||||
|     else | ||||
|       if [[ $gpu_count -lt $tp ]]; then | ||||
|         echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." | ||||
|         continue | ||||
|       fi | ||||
|     fi | ||||
|  | ||||
|     throughput_command="python3 benchmark_throughput.py \ | ||||
|     throughput_command=" $throughput_envs python3 benchmark_throughput.py \ | ||||
|       --output-json $RESULTS_FOLDER/${test_name}.json \ | ||||
|       $throughput_args" | ||||
|  | ||||
| @ -272,18 +320,27 @@ run_serving_tests() { | ||||
|  | ||||
|     # get client and server arguments | ||||
|     server_params=$(echo "$params" | jq -r '.server_parameters') | ||||
|     server_envs=$(echo "$params" | jq -r '.server_environment_variables') | ||||
|     client_params=$(echo "$params" | jq -r '.client_parameters') | ||||
|     server_args=$(json2args "$server_params") | ||||
|     server_envs=$(json2envs "$server_envs") | ||||
|     client_args=$(json2args "$client_params") | ||||
|     qps_list=$(echo "$params" | jq -r '.qps_list') | ||||
|     qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') | ||||
|     echo "Running over qps list $qps_list" | ||||
|  | ||||
|     # check if there is enough GPU to run the test | ||||
|     # check if there are enough resources to run the test | ||||
|     tp=$(echo "$server_params" | jq -r '.tensor_parallel_size') | ||||
|     if [[ $gpu_count -lt $tp ]]; then | ||||
|       echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." | ||||
|       continue | ||||
|     if [ "$ON_CPU" == "1" ];then | ||||
|       if [[ $numa_count -lt $tp ]]; then | ||||
|         echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name." | ||||
|         continue | ||||
|       fi | ||||
|     else | ||||
|       if [[ $gpu_count -lt $tp ]]; then | ||||
|         echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." | ||||
|         continue | ||||
|       fi | ||||
|     fi | ||||
|  | ||||
|     # check if server model and client model is aligned | ||||
| @ -294,23 +351,33 @@ run_serving_tests() { | ||||
|       continue | ||||
|     fi | ||||
|  | ||||
|     server_command="python3 \ | ||||
|     server_command="$server_envs python3 \ | ||||
|       -m vllm.entrypoints.openai.api_server \ | ||||
|       $server_args" | ||||
|  | ||||
|     # run the server | ||||
|     echo "Running test case $test_name" | ||||
|     echo "Server command: $server_command" | ||||
|     bash -c "$server_command" & | ||||
|     server_pid=$! | ||||
|  | ||||
|     # wait until the server is alive | ||||
|     if wait_for_server; then | ||||
|       echo "" | ||||
|       echo "vllm server is up and running." | ||||
|     # support remote vllm server | ||||
|     client_remote_args="" | ||||
|     if [[ -z "${REMOTE_HOST}" ]]; then | ||||
|       bash -c "$server_command" & | ||||
|       server_pid=$! | ||||
|       # wait until the server is alive | ||||
|       if wait_for_server; then | ||||
|         echo "" | ||||
|         echo "vLLM server is up and running." | ||||
|       else | ||||
|         echo "" | ||||
|         echo "vLLM failed to start within the timeout period." | ||||
|       fi | ||||
|     else | ||||
|       echo "" | ||||
|       echo "vllm failed to start within the timeout period." | ||||
|       server_command="Using Remote Server $REMOTE_HOST $REMOTE_PORT" | ||||
|       if [[ ${REMOTE_PORT} ]]; then | ||||
|         client_remote_args=" --host=$REMOTE_HOST --port=$REMOTE_PORT " | ||||
|       else | ||||
|         client_remote_args=" --host=$REMOTE_HOST " | ||||
|       fi | ||||
|     fi | ||||
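The remote branch above lets the serving clients target a vLLM server that is already running, for example on another machine; a sketch with placeholder host and port:

```bash
# Skip launching a local server and benchmark an existing endpoint instead
ON_CPU=1 REMOTE_HOST=10.0.0.5 REMOTE_PORT=8000 \
  bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh
# If REMOTE_PORT is left unset, only --host is passed and the client uses its default port.
```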
|  | ||||
|     # iterate over different QPS | ||||
| @ -332,7 +399,7 @@ run_serving_tests() { | ||||
|         --result-filename ${new_test_name}.json \ | ||||
|         --request-rate $qps \ | ||||
|         --metadata "tensor_parallel_size=$tp" \ | ||||
|         $client_args" | ||||
|         $client_args $client_remote_args " | ||||
|  | ||||
|       echo "Running test case $test_name with qps $qps" | ||||
|       echo "Client command: $client_command" | ||||
| @ -360,7 +427,14 @@ run_serving_tests() { | ||||
| } | ||||
|  | ||||
| main() { | ||||
|   check_gpus | ||||
|   local ARCH | ||||
|   ARCH='' | ||||
|   if [ "$ON_CPU" == "1" ];then | ||||
|      check_cpus | ||||
|      ARCH='-cpu' | ||||
|   else | ||||
|      check_gpus | ||||
|   fi | ||||
|   check_hf_token | ||||
|  | ||||
|   # Set to v1 to run v1 benchmark | ||||
| @ -386,9 +460,9 @@ main() { | ||||
|   QUICK_BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/ | ||||
|  | ||||
|   # benchmarking | ||||
|   run_serving_tests $QUICK_BENCHMARK_ROOT/tests/serving-tests.json | ||||
|   run_latency_tests $QUICK_BENCHMARK_ROOT/tests/latency-tests.json | ||||
|   run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/throughput-tests.json | ||||
|   run_serving_tests $QUICK_BENCHMARK_ROOT/tests/"${SERVING_JSON:-serving-tests$ARCH.json}" | ||||
|   run_latency_tests $QUICK_BENCHMARK_ROOT/tests/"${LATENCY_JSON:-latency-tests$ARCH.json}" | ||||
|   run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/"${THROUGHPUT_JSON:-throughput-tests$ARCH.json}" | ||||
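The parameter expansions above mean an explicit `*_JSON` value always wins, and otherwise the `-cpu` suffix is appended when `ON_CPU=1`; a tiny illustration of the resolution:

```bash
# How the serving test file name is resolved (ARCH is '-cpu' when ON_CPU=1, '' otherwise)
ARCH='-cpu'
echo "${SERVING_JSON:-serving-tests$ARCH.json}"   # -> serving-tests-cpu.json
SERVING_JSON=my-serving-tests.json
echo "${SERVING_JSON:-serving-tests$ARCH.json}"   # -> my-serving-tests.json
```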
|  | ||||
|   # postprocess benchmarking results | ||||
|   pip install tabulate pandas | ||||
|  | ||||
							
								
								
									
.buildkite/nightly-benchmarks/tests/latency-tests-cpu.json (new file, +30 lines)
							| @ -0,0 +1,30 @@ | ||||
| [ | ||||
|     { | ||||
|         "test_name": "latency_llama8B_tp1", | ||||
|         "environment_variables": { | ||||
| 	    "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, | ||||
| 	    "VLLM_CPU_KVCACHE_SPACE": 40 | ||||
|         }, | ||||
|         "parameters": { | ||||
|             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||
|             "tensor_parallel_size": 1, | ||||
|             "load_format": "dummy", | ||||
|             "num_iters_warmup": 5, | ||||
|             "num_iters": 15 | ||||
|         } | ||||
|     }, | ||||
|     { | ||||
|         "test_name": "latency_llama8B_tp4", | ||||
|         "environment_variables": { | ||||
| 	    "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, | ||||
| 	    "VLLM_CPU_KVCACHE_SPACE": 40 | ||||
|         }, | ||||
|         "parameters": { | ||||
|             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||
|             "tensor_parallel_size": 4, | ||||
|             "load_format": "dummy", | ||||
|             "num_iters_warmup": 5, | ||||
|             "num_iters": 15 | ||||
|         } | ||||
|     } | ||||
| ] | ||||
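To sanity-check a test file like the one above before a run, jq can list the cases together with the resources they need:

```bash
# List the CPU latency test cases with their tensor-parallel sizes and env overrides
jq -r '.[] | "\(.test_name): tp=\(.parameters.tensor_parallel_size), env=\(.environment_variables | keys | join(","))"' \
  .buildkite/nightly-benchmarks/tests/latency-tests-cpu.json
```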
							
								
								
									
.buildkite/nightly-benchmarks/tests/serving-tests-cpu.json (new file, +158 lines)
							| @ -0,0 +1,158 @@ | ||||
| [ | ||||
|     { | ||||
|         "test_name": "serving_llama8B_tp1_sharegpt", | ||||
|         "qps_list": [1, 4, 16, "inf"], | ||||
|         "server_environment_variables": { | ||||
|             "VLLM_RPC_TIMEOUT": 100000, | ||||
| 	    "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, | ||||
| 	    "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, | ||||
| 	    "VLLM_CPU_KVCACHE_SPACE": 40 | ||||
|         }, | ||||
|         "server_parameters": { | ||||
|             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||
|             "tensor_parallel_size": 1, | ||||
| 	    "dtype": "bfloat16", | ||||
| 	    "distributed_executor_backend": "mp", | ||||
| 	    "block_size": 128, | ||||
| 	    "trust_remote_code": "", | ||||
|             "disable_log_stats": "", | ||||
|             "disable_log_requests": "", | ||||
| 	    "enforce_eager": "", | ||||
|             "load_format": "dummy" | ||||
|         }, | ||||
|         "client_parameters": { | ||||
|             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||
|             "backend": "vllm", | ||||
|             "dataset_name": "sharegpt", | ||||
|             "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", | ||||
| 	    "max_concurrency": 60, | ||||
|             "num_prompts": 200 | ||||
|         } | ||||
|     }, | ||||
|     { | ||||
|         "test_name": "serving_llama8B_tp2_sharegpt", | ||||
|         "qps_list": [1, 4, 16, "inf"], | ||||
|         "server_environment_variables": { | ||||
|             "VLLM_RPC_TIMEOUT": 100000, | ||||
| 	    "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, | ||||
| 	    "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, | ||||
| 	    "VLLM_CPU_KVCACHE_SPACE": 40 | ||||
|         }, | ||||
|         "server_parameters": { | ||||
|             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||
|             "tensor_parallel_size": 2, | ||||
| 	    "dtype": "bfloat16", | ||||
| 	    "distributed_executor_backend": "mp", | ||||
| 	    "block_size": 128, | ||||
| 	    "trust_remote_code": "", | ||||
|             "disable_log_stats": "", | ||||
|             "disable_log_requests": "", | ||||
| 	    "enforce_eager": "", | ||||
|             "load_format": "dummy" | ||||
|         }, | ||||
|         "client_parameters": { | ||||
|             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||
|             "backend": "vllm", | ||||
|             "dataset_name": "sharegpt", | ||||
|             "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", | ||||
| 	    "max_concurrency": 60, | ||||
|             "num_prompts": 200 | ||||
|         } | ||||
|     }, | ||||
|     { | ||||
|         "test_name": "serving_llama8B_tp4_sharegpt", | ||||
|         "qps_list": [1, 4, 16, "inf"], | ||||
|         "server_environment_variables": { | ||||
|             "VLLM_RPC_TIMEOUT": 100000, | ||||
| 	    "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, | ||||
| 	    "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, | ||||
| 	    "VLLM_CPU_KVCACHE_SPACE": 40 | ||||
|         }, | ||||
|         "server_parameters": { | ||||
|             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||
|             "tensor_parallel_size": 4, | ||||
| 	    "dtype": "bfloat16", | ||||
| 	    "distributed_executor_backend": "mp", | ||||
| 	    "block_size": 128, | ||||
| 	    "trust_remote_code": "", | ||||
|             "disable_log_stats": "", | ||||
|             "disable_log_requests": "", | ||||
| 	    "enforce_eager": "", | ||||
|             "load_format": "dummy" | ||||
|         }, | ||||
|         "client_parameters": { | ||||
|             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||
|             "backend": "vllm", | ||||
|             "dataset_name": "sharegpt", | ||||
|             "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", | ||||
| 	    "max_concurrency": 60, | ||||
|             "num_prompts": 200 | ||||
|         } | ||||
|     }, | ||||
|     { | ||||
|         "test_name": "serving_llama8B_tp4_random_1024_128", | ||||
|         "qps_list": [1, 4, 16, "inf"], | ||||
|         "server_environment_variables": { | ||||
|             "VLLM_RPC_TIMEOUT": 100000, | ||||
| 	    "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, | ||||
| 	    "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, | ||||
| 	    "VLLM_CPU_KVCACHE_SPACE": 40 | ||||
|         }, | ||||
|         "server_parameters": { | ||||
|             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||
|             "tensor_parallel_size": 4, | ||||
| 	    "dtype": "bfloat16", | ||||
| 	    "distributed_executor_backend": "mp", | ||||
| 	    "block_size": 128, | ||||
| 	    "trust_remote_code": "", | ||||
| 	    "enable_chunked_prefill": "", | ||||
|             "disable_log_stats": "", | ||||
|             "disable_log_requests": "", | ||||
| 	    "enforce_eager": "", | ||||
|             "load_format": "dummy" | ||||
|         }, | ||||
|         "client_parameters": { | ||||
|             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||
|             "backend": "vllm", | ||||
|             "dataset_name": "random", | ||||
| 	    "random-input-len": 1024, | ||||
| 	    "random-output-len": 128, | ||||
| 	    "ignore-eos": "", | ||||
| 	    "max_concurrency": 100, | ||||
|             "num_prompts": 100 | ||||
|         } | ||||
|     }, | ||||
|     { | ||||
|         "test_name": "serving_llama8B_pp6_random_1024_128", | ||||
|         "qps_list": [1, 4, 16, "inf"], | ||||
|         "server_environment_variables": { | ||||
|             "VLLM_RPC_TIMEOUT": 100000, | ||||
| 	    "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, | ||||
| 	    "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, | ||||
| 	    "VLLM_CPU_KVCACHE_SPACE": 40 | ||||
|         }, | ||||
|         "server_parameters": { | ||||
|             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||
|             "pipeline_parallel_size": 6, | ||||
| 	    "dtype": "bfloat16", | ||||
| 	    "distributed_executor_backend": "mp", | ||||
| 	    "block_size": 128, | ||||
| 	    "trust_remote_code": "", | ||||
| 	    "enable_chunked_prefill": "", | ||||
|             "disable_log_stats": "", | ||||
|             "disable_log_requests": "", | ||||
| 	    "enforce_eager": "", | ||||
|             "load_format": "dummy" | ||||
|         }, | ||||
|         "client_parameters": { | ||||
|             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||
|             "backend": "vllm", | ||||
|             "dataset_name": "random", | ||||
| 	    "random-input-len": 1024, | ||||
| 	    "random-output-len": 128, | ||||
| 	    "ignore-eos": "", | ||||
| 	    "max_concurrency": 100, | ||||
|             "num_prompts": 100 | ||||
|         } | ||||
|     } | ||||
| ] | ||||
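A similar jq one-liner summarizes the serving cases above, which helps check that the tensor- and pipeline-parallel sizes fit the NUMA layout of the target host:

```bash
# Summarize each CPU serving test: QPS sweep plus tensor/pipeline parallel sizes
jq -r '.[] | "\(.test_name): qps=\(.qps_list | tostring), tp=\(.server_parameters.tensor_parallel_size // 1), pp=\(.server_parameters.pipeline_parallel_size // 1)"' \
  .buildkite/nightly-benchmarks/tests/serving-tests-cpu.json
```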
.buildkite/nightly-benchmarks/tests/throughput-tests-cpu.json (new file, +32 lines)
| @ -0,0 +1,32 @@ | ||||
| [ | ||||
|     { | ||||
|         "test_name": "throughput_llama8B_tp1", | ||||
|         "environment_variables": { | ||||
| 	    "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, | ||||
| 	    "VLLM_CPU_KVCACHE_SPACE": 40 | ||||
|         }, | ||||
|         "parameters": { | ||||
|             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||
|             "tensor_parallel_size": 1, | ||||
|             "load_format": "dummy", | ||||
|             "dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json", | ||||
|             "num_prompts": 200, | ||||
|             "backend": "vllm" | ||||
|         } | ||||
|     }, | ||||
|     { | ||||
|         "test_name": "throughput_llama8B_tp4", | ||||
|         "environment_variables": { | ||||
| 	    "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, | ||||
| 	    "VLLM_CPU_KVCACHE_SPACE": 40 | ||||
|         }, | ||||
|         "parameters": { | ||||
|             "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", | ||||
|             "tensor_parallel_size": 4, | ||||
|             "load_format": "dummy", | ||||
|             "dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json", | ||||
|             "num_prompts": 200, | ||||
|             "backend": "vllm" | ||||
|         } | ||||
|     } | ||||
| ] | ||||
| @ -101,7 +101,8 @@ steps: | ||||
|       queue: cpu_queue_postmerge | ||||
|     commands: | ||||
|       - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" | ||||
|       - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ." | ||||
|       - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ." | ||||
|       - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest" | ||||
|       - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)" | ||||
|     env: | ||||
|       DOCKER_BUILDKIT: "1" | ||||
| @ -117,6 +118,7 @@ steps: | ||||
|     commands: | ||||
|       - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" | ||||
|       - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest --progress plain -f docker/Dockerfile.neuron ." | ||||
|       - "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest" | ||||
|       - "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version)" | ||||
|     env: | ||||
|       DOCKER_BUILDKIT: "1" | ||||
|  | ||||
| @ -24,13 +24,22 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE | ||||
| numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu . | ||||
|  | ||||
| # Run the image, setting --shm-size=4g for tensor parallel. | ||||
| docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE" | ||||
| docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2 | ||||
| docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE" | ||||
| docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2 | ||||
|  | ||||
| function cpu_tests() { | ||||
|   set -e | ||||
|   export NUMA_NODE=$2 | ||||
|  | ||||
|   # list packages | ||||
|   docker exec cpu-test-"$NUMA_NODE"-avx2 bash -c " | ||||
|     set -e | ||||
|     pip list" | ||||
|  | ||||
|   docker exec cpu-test-"$NUMA_NODE" bash -c " | ||||
|     set -e | ||||
|     pip list" | ||||
|  | ||||
|   # offline inference | ||||
|   docker exec cpu-test-"$NUMA_NODE"-avx2 bash -c " | ||||
|     set -e | ||||
| @ -42,6 +51,7 @@ function cpu_tests() { | ||||
|     pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model | ||||
|     pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model | ||||
|     pytest -v -s tests/models/language/generation -m cpu_model | ||||
|     VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model | ||||
|     pytest -v -s tests/models/language/pooling -m cpu_model | ||||
|     pytest -v -s tests/models/multimodal/generation \ | ||||
|                 --ignore=tests/models/multimodal/generation/test_mllama.py \ | ||||
| @ -72,7 +82,7 @@ function cpu_tests() { | ||||
|     set -e | ||||
|     python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &  | ||||
|     timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1 | ||||
|     python3 benchmarks/benchmark_serving.py \ | ||||
|     VLLM_CPU_CI_ENV=0 python3 benchmarks/benchmark_serving.py \ | ||||
|       --backend vllm \ | ||||
|       --dataset-name random \ | ||||
|       --model facebook/opt-125m \ | ||||
| @ -89,4 +99,4 @@ function cpu_tests() { | ||||
|  | ||||
| # All of CPU tests are expected to be finished less than 40 mins. | ||||
| export -f cpu_tests | ||||
| timeout 1h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" | ||||
| timeout 1.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" | ||||
|  | ||||
| @ -2,10 +2,34 @@ | ||||
|  | ||||
| # This script builds the HPU docker image and runs the offline inference inside the container. | ||||
| # It serves a sanity check for compilation and basic model usage. | ||||
| set -ex | ||||
| set -exuo pipefail | ||||
|  | ||||
| # Try building the docker image | ||||
| docker build -t hpu-test-env -f docker/Dockerfile.hpu . | ||||
| cat <<EOF | docker build -t hpu-plugin-v1-test-env -f - . | ||||
| FROM 1.22-413-pt2.7.1:latest | ||||
|  | ||||
| COPY ./ /workspace/vllm | ||||
|  | ||||
| WORKDIR /workspace/vllm | ||||
|  | ||||
| RUN pip install -v -r requirements/hpu.txt | ||||
| RUN pip install git+https://github.com/vllm-project/vllm-gaudi.git | ||||
|  | ||||
| ENV no_proxy=localhost,127.0.0.1 | ||||
| ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=true | ||||
|  | ||||
| RUN VLLM_TARGET_DEVICE=hpu python3 setup.py install | ||||
|  | ||||
| # install development dependencies (for testing) | ||||
| RUN python3 -m pip install -e tests/vllm_test_utils | ||||
|  | ||||
| WORKDIR /workspace/ | ||||
|  | ||||
| RUN git clone https://github.com/vllm-project/vllm-gaudi.git | ||||
|  | ||||
| RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks | ||||
|  | ||||
| EOF | ||||
|  | ||||
| # Setup cleanup | ||||
| # certain versions of HPU software stack have a bug that can | ||||
| @ -14,13 +38,21 @@ docker build -t hpu-test-env -f docker/Dockerfile.hpu . | ||||
| # functions, while other platforms only need one remove_docker_container | ||||
| # function. | ||||
| EXITCODE=1 | ||||
| remove_docker_containers() { docker rm -f hpu-test || true; docker rm -f hpu-test-tp2 || true; } | ||||
| remove_docker_containers_and_exit() { remove_docker_containers; exit $EXITCODE; } | ||||
| trap remove_docker_containers_and_exit EXIT | ||||
| remove_docker_containers() { docker rm -f hpu-plugin-v1-test || true; } | ||||
| trap 'remove_docker_containers; exit $EXITCODE;' EXIT | ||||
| remove_docker_containers | ||||
|  | ||||
| # Run the image and launch offline inference | ||||
| docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m | ||||
| docker run --runtime=habana --name=hpu-test-tp2 --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --tensor-parallel-size 2 | ||||
| echo "Running HPU plugin v1 test" | ||||
| docker run --rm --runtime=habana --name=hpu-plugin-v1-test --network=host \ | ||||
|   -e HABANA_VISIBLE_DEVICES=all \ | ||||
|   hpu-plugin-v1-test-env \ | ||||
|   /bin/bash "/workspace/vllm-gaudi/tests/upstream_tests/ci_tests.sh" | ||||
|  | ||||
| EXITCODE=$? | ||||
| if [ $EXITCODE -eq 0 ]; then | ||||
|   echo "Test with basic model passed" | ||||
| else | ||||
|   echo "Test with basic model FAILED with exit code: $EXITCODE" >&2 | ||||
| fi | ||||
|  | ||||
| # The trap will handle the container removal and final exit. | ||||
| @ -54,10 +54,11 @@ docker run --rm -it --device=/dev/neuron0 --network bridge \ | ||||
|        --name "${container_name}" \ | ||||
|        ${image_name} \ | ||||
|        /bin/bash -c " | ||||
|             set -e; # Exit on first error | ||||
|             python3 /workspace/vllm/examples/offline_inference/neuron.py; | ||||
|             python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys; | ||||
|             for f in /workspace/vllm/tests/neuron/2_core/*.py; do | ||||
|                 echo 'Running test file: '$f; | ||||
|                 echo \"Running test file: \$f\"; | ||||
|                 python3 -m pytest \$f -v --capture=tee-sys; | ||||
|             done | ||||
|        " | ||||
| @ -159,6 +159,8 @@ run_and_track_test 14 "test_tpu_qkv_linear.py" \ | ||||
|     "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py" | ||||
| run_and_track_test 15 "test_spmd_model_weight_loading.py" \ | ||||
|     "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py" | ||||
| run_and_track_test 16 "test_kv_cache_update_kernel.py" \ | ||||
|     "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py" | ||||
|  | ||||
| # After all tests have been attempted, exit with the overall status. | ||||
| if [ "$overall_script_exit_code" -ne 0 ]; then | ||||
|  | ||||
| @ -28,4 +28,5 @@ docker run \ | ||||
|     sh -c ' | ||||
|     VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m | ||||
|     VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2 | ||||
|     VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager | ||||
| ' | ||||
|  | ||||
| @ -4,8 +4,8 @@ CONTAINER_NAME=vllm-tpu | ||||
|  | ||||
| # vllm config | ||||
| MODEL=meta-llama/Llama-3.1-8B-Instruct | ||||
| MAX_NUM_SEQS=512 | ||||
| MAX_NUM_BATCHED_TOKENS=512 | ||||
| MAX_NUM_SEQS=256 | ||||
| MAX_NUM_BATCHED_TOKENS=1024 | ||||
| TENSOR_PARALLEL_SIZE=1 | ||||
| MAX_MODEL_LEN=2048 | ||||
| DOWNLOAD_DIR=/mnt/disks/persist | ||||
|  | ||||
| @ -68,7 +68,7 @@ docker run \ | ||||
|  | ||||
| echo "run script..." | ||||
| echo | ||||
| docker exec "$CONTAINER_NAME" /bin/bash -c ".buildkite/scripts/hardware_ci/run_bm.sh" | ||||
| docker exec "$CONTAINER_NAME" /bin/bash -c ".buildkite/scripts/tpu/run_bm.sh" | ||||
|  | ||||
| echo "copy result back..." | ||||
| VLLM_LOG="$LOG_ROOT/$TEST_NAME"_vllm_log.txt | ||||
|  | ||||
							
								
								
									
.buildkite/scripts/tpu/quantized_v6e_1.env (new file, +14 lines)
							| @ -0,0 +1,14 @@ | ||||
| # Environment config | ||||
| TEST_NAME=llama8bw8a8 | ||||
| CONTAINER_NAME=vllm-tpu | ||||
|  | ||||
| # vllm config | ||||
| MODEL=RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8 | ||||
| MAX_NUM_SEQS=128 | ||||
| MAX_NUM_BATCHED_TOKENS=1024 | ||||
| TENSOR_PARALLEL_SIZE=1 | ||||
| MAX_MODEL_LEN=2048 | ||||
| DOWNLOAD_DIR=/mnt/disks/persist | ||||
| EXPECTED_THROUGHPUT=10.0 | ||||
| INPUT_LEN=1800 | ||||
| OUTPUT_LEN=128 | ||||
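Config files like this one are plain KEY=VALUE shell fragments, so they can also be loaded by hand when reproducing a run; whether the CI scripts source them exactly this way is not shown in this excerpt.

```bash
# Load the benchmark configuration into the environment and inspect a few values
set -a                                    # export every variable assigned while sourcing
source .buildkite/scripts/tpu/quantized_v6e_1.env
set +a
echo "model=$MODEL max_num_seqs=$MAX_NUM_SEQS expected_throughput=$EXPECTED_THROUGHPUT"
```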
| @ -41,6 +41,16 @@ steps: | ||||
|   # TODO: add `--strict` once warnings in docstrings are fixed | ||||
|   - mkdocs build | ||||
|  | ||||
| - label: Pytorch Nightly Dependency Override Check # 2min | ||||
|   # if this test fails, it means the nightly torch version is not compatible with some | ||||
|   # of the dependencies. Please check the error message and add the package to the whitelist | ||||
|   # in /vllm/tools/generate_nightly_torch_test.py | ||||
|   soft_fail: true | ||||
|   source_file_dependencies: | ||||
|   - requirements/nightly_torch_test.txt | ||||
|   commands: | ||||
|   - bash standalone_tests/pytorch_nightly_dependency.sh | ||||
|  | ||||
| - label: Async Engine, Inputs, Utils, Worker Test # 24min | ||||
|   mirror_hardwares: [amdexperimental] | ||||
|   source_file_dependencies: | ||||
| @ -89,7 +99,7 @@ steps: | ||||
|   - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py | ||||
|  | ||||
| - label: Chunked Prefill Test | ||||
|   mirror_hardwares: [amdexperimental] | ||||
|   mirror_hardwares: [amdexperimental, amdproduction] | ||||
|   source_file_dependencies: | ||||
|   - vllm/ | ||||
|   - tests/basic_correctness/test_chunked_prefill | ||||
| @ -145,6 +155,7 @@ steps: | ||||
|   - examples/offline_inference/rlhf_colocate.py | ||||
|   - tests/examples/offline_inference/data_parallel.py | ||||
|   - tests/v1/test_async_llm_dp.py | ||||
|   - tests/v1/test_external_lb_dp.py | ||||
|   - tests/v1/engine/test_engine_core_client.py | ||||
|   commands: | ||||
|   # test with tp=2 and external_dp=2 | ||||
| @ -153,8 +164,9 @@ steps: | ||||
|   # test with tp=2 and pp=2 | ||||
|   - PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py | ||||
|   # test with internal dp | ||||
|   - python3 ../examples/offline_inference/data_parallel.py | ||||
|   - python3 ../examples/offline_inference/data_parallel.py --enforce-eager | ||||
|   - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py | ||||
|   - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py | ||||
|   - pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp | ||||
|   - pytest -v -s distributed/test_utils.py | ||||
|   - pytest -v -s compile/test_basic_correctness.py | ||||
| @ -168,6 +180,23 @@ steps: | ||||
|   - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py | ||||
|   - popd | ||||
|  | ||||
| - label: EPLB Algorithm Test | ||||
|   working_dir: "/vllm-workspace/tests" | ||||
|   source_file_dependencies: | ||||
|   - vllm/distributed/eplb | ||||
|   - tests/distributed/test_eplb_algo.py | ||||
|   commands: | ||||
|   - pytest -v -s distributed/test_eplb_algo.py | ||||
|  | ||||
| - label: EPLB Execution Test # 5min | ||||
|   working_dir: "/vllm-workspace/tests" | ||||
|   num_gpus: 4 | ||||
|   source_file_dependencies: | ||||
|   - vllm/distributed/eplb | ||||
|   - tests/distributed/test_eplb_execute.py | ||||
|   commands: | ||||
|   - pytest -v -s distributed/test_eplb_execute.py | ||||
|  | ||||
| - label: Metrics, Tracing Test # 10min | ||||
|   mirror_hardwares: [amdexperimental, amdproduction] | ||||
|   num_gpus: 2 | ||||
| @ -177,13 +206,18 @@ steps: | ||||
|   - tests/tracing | ||||
|   commands: | ||||
|   - pytest -v -s metrics | ||||
|   - "pip install \ | ||||
|       'opentelemetry-sdk>=1.26.0' \ | ||||
|       'opentelemetry-api>=1.26.0' \ | ||||
|       'opentelemetry-exporter-otlp>=1.26.0' \ | ||||
|       'opentelemetry-semantic-conventions-ai>=0.4.1'" | ||||
|   - pytest -v -s tracing | ||||
|  | ||||
| ##### fast check tests  ##### | ||||
| #####  1 GPU test  ##### | ||||
|  | ||||
| - label: Regression Test # 5min | ||||
|   mirror_hardwares: [amdexperimental, amdproduction] | ||||
|   mirror_hardwares: [amdexperimental] | ||||
|   source_file_dependencies: | ||||
|   - vllm/ | ||||
|   - tests/test_regression | ||||
| @ -193,7 +227,7 @@ steps: | ||||
|   working_dir: "/vllm-workspace/tests" # optional | ||||
|  | ||||
| - label: Engine Test # 10min | ||||
|   mirror_hardwares: [amdexperimental, amdproduction] | ||||
|   mirror_hardwares: [amdexperimental] | ||||
|   source_file_dependencies: | ||||
|   - vllm/ | ||||
|   - tests/engine | ||||
| @ -266,6 +300,15 @@ steps: | ||||
|   commands: | ||||
|     - pytest -v -s prefix_caching | ||||
|  | ||||
|  | ||||
| - label: Platform Tests (CUDA) | ||||
|   mirror_hardwares: [amdexperimental] | ||||
|   source_file_dependencies: | ||||
|   - vllm/ | ||||
|   - tests/cuda | ||||
|   commands: | ||||
|     - pytest -v -s cuda/test_cuda_context.py | ||||
|  | ||||
| - label: Samplers Test # 36min | ||||
|   mirror_hardwares: [amdexperimental] | ||||
|   source_file_dependencies: | ||||
| @ -297,7 +340,7 @@ steps: | ||||
|   parallelism: 4 | ||||
|  | ||||
| - label: PyTorch Compilation Unit Tests | ||||
|   mirror_hardwares: [amdexperimental, amdproduction] | ||||
|   mirror_hardwares: [amdexperimental] | ||||
|   torch_nightly: true | ||||
|   source_file_dependencies: | ||||
|     - vllm/ | ||||
| @ -305,6 +348,7 @@ steps: | ||||
|   commands: | ||||
|     - pytest -v -s compile/test_pass_manager.py | ||||
|     - pytest -v -s compile/test_fusion.py | ||||
|     - pytest -v -s compile/test_fusion_attn.py | ||||
|     - pytest -v -s compile/test_silu_mul_quant_fusion.py | ||||
|     - pytest -v -s compile/test_sequence_parallelism.py | ||||
|     - pytest -v -s compile/test_async_tp.py | ||||
| @ -378,7 +422,7 @@ steps: | ||||
|     - pytest -v -s kernels/mamba | ||||
|  | ||||
| - label: Tensorizer Test # 11min | ||||
|   mirror_hardwares: [amdexperimental, amdproduction] | ||||
|   mirror_hardwares: [amdexperimental] | ||||
|   soft_fail: true | ||||
|   source_file_dependencies: | ||||
|   - vllm/model_executor/model_loader | ||||
| @ -470,7 +514,7 @@ steps: | ||||
| #####  models test  ##### | ||||
|  | ||||
| - label: Basic Models Test # 24min | ||||
|   mirror_hardwares: [amdexperimental, amdproduction] | ||||
|   mirror_hardwares: [amdexperimental] | ||||
|   torch_nightly: true | ||||
|   source_file_dependencies: | ||||
|   - vllm/ | ||||
| @ -494,6 +538,17 @@ steps: | ||||
|     - pip freeze | grep -E 'torch' | ||||
|     - pytest -v -s models/language -m core_model | ||||
|  | ||||
| - label: Language Models Test (Hybrid) # 35 min | ||||
|   mirror_hardwares: [amdexperimental] | ||||
|   torch_nightly: true | ||||
|   source_file_dependencies: | ||||
|   - vllm/ | ||||
|   - tests/models/language/generation | ||||
|   commands: | ||||
|     # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. | ||||
|     - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8' | ||||
|     - pytest -v -s models/language/generation -m hybrid_model | ||||
|  | ||||
| - label: Language Models Test (Extended Generation) # 1hr20min | ||||
|   mirror_hardwares: [amdexperimental] | ||||
|   optional: true | ||||
| @ -503,7 +558,7 @@ steps: | ||||
|   commands: | ||||
|     # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. | ||||
|     - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8' | ||||
|     - pytest -v -s models/language/generation -m 'not core_model' | ||||
|     - pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)' | ||||
|  | ||||
| - label: Language Models Test (Extended Pooling)  # 36min | ||||
|   mirror_hardwares: [amdexperimental] | ||||
| @ -548,7 +603,7 @@ steps: | ||||
|     - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model' | ||||
|  | ||||
| - label: Multi-Modal Models Test (Extended) 3 | ||||
|   mirror_hardwares: [amdexperimental, amdproduction] | ||||
|   mirror_hardwares: [amdexperimental] | ||||
|   optional: true | ||||
|   source_file_dependencies: | ||||
|   - vllm/ | ||||
| @ -600,13 +655,18 @@ steps: | ||||
|   - vllm/executor/ | ||||
|   - vllm/model_executor/models/ | ||||
|   - tests/distributed/ | ||||
|   - tests/examples/offline_inference/data_parallel.py | ||||
|   commands: | ||||
|   - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up) | ||||
|     - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' | ||||
|     - NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' | ||||
|     - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code | ||||
|     - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py | ||||
|     - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py | ||||
|   - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up) | ||||
|     - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' | ||||
|     - NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' | ||||
|     - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code | ||||
|  | ||||
| - label: Distributed Tests (2 GPUs) # 40min | ||||
|   mirror_hardwares: [amdexperimental] | ||||
| @ -624,10 +684,12 @@ steps: | ||||
|   - vllm/worker/model_runner.py | ||||
|   - entrypoints/llm/test_collective_rpc.py | ||||
|   - tests/v1/test_async_llm_dp.py | ||||
|   - tests/v1/test_external_lb_dp.py | ||||
|   - tests/v1/entrypoints/openai/test_multi_api_servers.py | ||||
|   - vllm/v1/engine/ | ||||
|   commands: | ||||
|   - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py | ||||
|   - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py | ||||
|   - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py | ||||
|   - pytest -v -s entrypoints/llm/test_collective_rpc.py | ||||
|   - pytest -v -s ./compile/test_basic_correctness.py | ||||
| @ -669,7 +731,7 @@ steps: | ||||
|   - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins | ||||
|  | ||||
| - label: Multi-step Tests (4 GPUs) # 36min | ||||
|   mirror_hardwares: [amdexperimental] | ||||
|   mirror_hardwares: [amdexperimental, amdproduction] | ||||
|   working_dir: "/vllm-workspace/tests" | ||||
|   num_gpus: 4 | ||||
|   source_file_dependencies: | ||||
| @ -730,7 +792,7 @@ steps: | ||||
|     - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt | ||||
|  | ||||
| - label: Weight Loading Multiple GPU Test - Large Models # optional | ||||
|   mirror_hardwares: [amdexperimental]  | ||||
|   mirror_hardwares: [amdexperimental] | ||||
|   working_dir: "/vllm-workspace/tests" | ||||
|   num_gpus: 2 | ||||
|   gpu: a100 | ||||
|  | ||||

4  .github/CODEOWNERS  (vendored)

							| @ -18,6 +18,10 @@ | ||||
| /vllm/entrypoints @aarnphm | ||||
| CMakeLists.txt @tlrmchlsmth | ||||
|  | ||||
| # Any change to the VllmConfig can have a large user-facing impact, | ||||
| # so spam a lot of people | ||||
| /vllm/config.py @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor | ||||
|  | ||||
| # vLLM V1 | ||||
| /vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat | ||||
| /vllm/v1/structured_output @mgoin @russellb @aarnphm | ||||
|  | ||||

85  .github/mergify.yml  (vendored)

							| @ -27,6 +27,22 @@ pull_request_rules: | ||||
|       add: | ||||
|         - ci/build | ||||
|  | ||||
| - name: label-deepseek | ||||
|   description: Automatically apply deepseek label | ||||
|   conditions: | ||||
|     - or: | ||||
|       - files~=^examples/.*deepseek.*\.py | ||||
|       - files~=^tests/.*deepseek.*\.py | ||||
|       - files~=^vllm/entrypoints/openai/tool_parsers/.*deepseek.*\.py | ||||
|       - files~=^vllm/model_executor/models/.*deepseek.*\.py | ||||
|       - files~=^vllm/reasoning/.*deepseek.*\.py | ||||
|       - files~=^vllm/transformers_utils/.*deepseek.*\.py | ||||
|       - title~=(?i)DeepSeek | ||||
|   actions: | ||||
|     label: | ||||
|       add: | ||||
|         - deepseek | ||||
|  | ||||
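
The `files~=` conditions above are regular expressions tested against the repo-relative paths of the files a pull request touches; a rule's label is applied when any of its conditions match. A minimal sketch of how one of the deepseek patterns behaves (the file paths below are hypothetical, for illustration only):

```python
import re

# Pattern taken from the label-deepseek rule above.
pattern = re.compile(r"^vllm/model_executor/models/.*deepseek.*\.py")

# Hypothetical changed-file paths.
paths = [
    "vllm/model_executor/models/deepseek_v2.py",  # matches -> deepseek label
    "vllm/model_executor/models/llama.py",        # no match
    "docs/models/deepseek.md",                    # no match (wrong prefix and suffix)
]
for path in paths:
    print(path, bool(pattern.search(path)))
```
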
| - name: label-frontend | ||||
|   description: Automatically apply frontend label | ||||
|   conditions: | ||||
| @ -45,6 +61,7 @@ pull_request_rules: | ||||
|       - files~=^vllm/entrypoints/openai/tool_parsers/llama.*\.py | ||||
|       - files~=^vllm/model_executor/models/.*llama.*\.py | ||||
|       - files~=^vllm/transformers_utils/configs/.*llama.*\.py | ||||
|       - title~=(?i)llama | ||||
|   actions: | ||||
|     label: | ||||
|       add: | ||||
| @ -57,14 +74,72 @@ pull_request_rules: | ||||
|       - files~=^vllm/multimodal/ | ||||
|       - files~=^tests/multimodal/ | ||||
|       - files~=^tests/models/multimodal/ | ||||
|       - files~=^tests/models/*/audio_language/ | ||||
|       - files~=^tests/models/*/vision_language/ | ||||
|       - files=tests/models/test_vision.py | ||||
|   actions: | ||||
|     label: | ||||
|       add: | ||||
|         - multi-modality | ||||
|  | ||||
| - name: label-new-model | ||||
|   description: Automatically apply new-model label | ||||
|   conditions: | ||||
|     - and: | ||||
|       - files~=^vllm/model_executor/models/ | ||||
|       - files=vllm/model_executor/models/registry.py | ||||
|       - files=tests/models/registry.py | ||||
|       - files=docs/models/supported_models.md | ||||
|   actions: | ||||
|     label: | ||||
|       add: | ||||
|         - new-model | ||||
|  | ||||
| - name: label-performance | ||||
|   description: Automatically apply performance label | ||||
|   conditions: | ||||
|     - or: | ||||
|       - files~=^benchmarks/ | ||||
|       - files~=^vllm/benchmarks/ | ||||
|       - files~=^tests/benchmarks/ | ||||
|       - files~=^\.buildkite/nightly-benchmarks/ | ||||
|   actions: | ||||
|     label: | ||||
|       add: | ||||
|         - performance | ||||
|  | ||||
| - name: label-qwen | ||||
|   description: Automatically apply qwen label | ||||
|   conditions: | ||||
|     - or: | ||||
|       - files~=^examples/.*qwen.*\.py | ||||
|       - files~=^tests/.*qwen.*\.py | ||||
|       - files~=^vllm/model_executor/models/.*qwen.*\.py | ||||
|       - files~=^vllm/reasoning/.*qwen.*\.py | ||||
|       - title~=(?i)Qwen | ||||
|   actions: | ||||
|     label: | ||||
|       add: | ||||
|         - qwen | ||||
|  | ||||
| - name: label-rocm | ||||
|   description: Automatically apply rocm label | ||||
|   conditions: | ||||
|     - or: | ||||
|       - files~=^csrc/rocm/ | ||||
|       - files~=^docker/Dockerfile.rocm | ||||
|       - files~=^requirements/rocm.*\.txt | ||||
|       - files~=^vllm/attention/backends/rocm.*\.py | ||||
|       - files~=^vllm/attention/ops/rocm.*\.py | ||||
|       - files~=^vllm/model_executor/layers/fused_moe/rocm.*\.py | ||||
|       - files~=^vllm/v1/attention/backends/mla/rocm.*\.py | ||||
|       - files~=^tests/kernels/.*_rocm.*\.py | ||||
|       - files=vllm/platforms/rocm.py | ||||
|       - title~=(?i)AMD | ||||
|       - title~=(?i)ROCm | ||||
|   actions: | ||||
|     label: | ||||
|       add: | ||||
|         - rocm | ||||
|  | ||||
| - name: label-structured-output | ||||
|   description: Automatically apply structured-output label | ||||
|   conditions: | ||||
| @ -92,8 +167,14 @@ pull_request_rules: | ||||
|   conditions: | ||||
|     - or: | ||||
|       - files~=^vllm/spec_decode/ | ||||
|       - files~=^vllm/v1/spec_decode/ | ||||
|       - files=vllm/model_executor/layers/spec_decode_base_sampler.py | ||||
|       - files~=^tests/spec_decode/ | ||||
|       - files~=^tests/v1/spec_decode/ | ||||
|       - files~=^examples/.*(spec_decode|mlpspeculator|eagle|speculation).*\.py | ||||
|       - files~=^vllm/model_executor/models/.*eagle.*\.py | ||||
|       - files=vllm/model_executor/models/mlp_speculator.py | ||||
|       - files~=^vllm/transformers_utils/configs/(eagle|medusa|mlp_speculator)\.py | ||||
|   actions: | ||||
|     label: | ||||
|       add: | ||||
|  | ||||

2  .gitignore  (vendored)

							| @ -200,5 +200,5 @@ benchmarks/**/*.json | ||||
| actionlint | ||||
| shellcheck*/ | ||||
|  | ||||
| # Ingore moe/marlin_moe gen code | ||||
| # Ignore moe/marlin_moe gen code | ||||
| csrc/moe/marlin_moe_wna16/kernel_* | ||||
|  | ||||
| @ -20,12 +20,10 @@ repos: | ||||
|     args: [--output-format, github, --fix] | ||||
|   - id: ruff-format | ||||
|     files: ^(.buildkite|benchmarks|examples)/.* | ||||
| - repo: https://github.com/codespell-project/codespell | ||||
|   rev: v2.4.1 | ||||
| - repo: https://github.com/crate-ci/typos | ||||
|   rev: v1.32.0 | ||||
|   hooks: | ||||
|   - id: codespell | ||||
|     additional_dependencies: ['tomli'] | ||||
|     args: ['--toml', 'pyproject.toml'] | ||||
|   - id: typos | ||||
| - repo: https://github.com/PyCQA/isort | ||||
|   rev: 6.0.1 | ||||
|   hooks: | ||||
| @ -55,6 +53,11 @@ repos: | ||||
|       files: ^requirements/test\.(in|txt)$ | ||||
| - repo: local | ||||
|   hooks: | ||||
|   - id: format-torch-nightly-test | ||||
|     name: reformat nightly_torch_test.txt to be in sync with test.in | ||||
|     language: python | ||||
|     entry: python tools/generate_nightly_torch_test.py | ||||
|     files: ^requirements/test\.(in|txt)$ | ||||
|   - id: mypy-local | ||||
|     name: Run mypy for local Python installation | ||||
|     entry: tools/mypy.sh 0 "local" | ||||
| @ -117,6 +120,11 @@ repos: | ||||
|     entry: python tools/check_spdx_header.py | ||||
|     language: python | ||||
|     types: [python] | ||||
|   - id: check-root-lazy-imports | ||||
|     name: Check root lazy imports | ||||
|     entry: python tools/check_init_lazy_imports.py | ||||
|     language: python | ||||
|     types: [python] | ||||
|   - id: check-filenames | ||||
|     name: Check for spaces in all filenames | ||||
|     entry: bash | ||||
| @ -145,6 +153,20 @@ repos: | ||||
|     types: [python] | ||||
|     pass_filenames: false | ||||
|     additional_dependencies: [regex] | ||||
|   - id: check-pickle-imports | ||||
|     name: Prevent new pickle/cloudpickle imports | ||||
|     entry: python tools/check_pickle_imports.py | ||||
|     language: python | ||||
|     types: [python] | ||||
|     pass_filenames: false | ||||
|     additional_dependencies: [pathspec, regex] | ||||
|   - id: validate-config | ||||
|     name: Validate configuration has default values and that each field has a docstring | ||||
|     entry: python tools/validate_config.py | ||||
|     language: python | ||||
|     types: [python] | ||||
|     pass_filenames: true | ||||
|     files: vllm/config.py|tests/test_config.py | ||||
|   # Keep `suggestion` last | ||||
|   - id: suggestion | ||||
|     name: Suggestion | ||||
|  | ||||
| @ -420,9 +420,39 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | ||||
|     endif() | ||||
|   endif() | ||||
|  | ||||
|   # The cutlass_scaled_mm kernels for Blackwell (c3x, i.e. CUTLASS 3.x) require | ||||
|  | ||||
|   # The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. CUTLASS 3.x) require | ||||
|   # CUDA 12.8 or later | ||||
|   cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;12.0a" "${CUDA_ARCHS}") | ||||
|   cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0;12.0a" "${CUDA_ARCHS}") | ||||
|   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS) | ||||
|     set(SRCS | ||||
|       "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu" | ||||
|       "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu" | ||||
|     ) | ||||
|     set_gencode_flags_for_srcs( | ||||
|       SRCS "${SRCS}" | ||||
|       CUDA_ARCHS "${SCALED_MM_ARCHS}") | ||||
|     list(APPEND VLLM_EXT_SRC "${SRCS}") | ||||
|     list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1") | ||||
|     # Let scaled_mm_c2x know it doesn't need to build these arches | ||||
|     list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") | ||||
|     message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}") | ||||
|   else() | ||||
|     if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS) | ||||
|       message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is " | ||||
|                      "not >= 12.8, we recommend upgrading to CUDA 12.8 or " | ||||
|                      "later if you intend on running FP8 quantized models on " | ||||
|                      "Blackwell.") | ||||
|     else() | ||||
|       message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found " | ||||
|                      "in CUDA target architectures") | ||||
|     endif() | ||||
|   endif() | ||||
|  | ||||
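
As the status messages above spell out, the SM120 `cutlass_scaled_mm` kernels are only built when the CUDA compiler is 12.8 or newer and a matching architecture is present in `CUDA_ARCHS`. A rough, illustrative way to check from Python whether a runtime environment lines up with that requirement (this inspects the visible GPU and the CUDA version PyTorch was built with, not the nvcc used for the build, so it is only a sanity check):

```python
import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"GPU compute capability: sm_{major}{minor}")
    print(f"PyTorch built against CUDA {torch.version.cuda}")
    if (major, minor) >= (12, 0):
        print("Blackwell-class (SM120) GPU detected; the SM120 scaled_mm "
              "kernels additionally need a CUDA 12.8+ toolchain at build time.")
else:
    print("No CUDA device visible.")
```
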
|  | ||||
|   # The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x) | ||||
|   # require CUDA 12.8 or later | ||||
|   cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a" "${CUDA_ARCHS}") | ||||
|   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS) | ||||
|     set(SRCS | ||||
|       "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu" | ||||
| @ -513,6 +543,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | ||||
|       CUDA_ARCHS "${FP4_ARCHS}") | ||||
|     list(APPEND VLLM_EXT_SRC "${SRCS}") | ||||
|     list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4=1") | ||||
|     list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") | ||||
|     message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") | ||||
|   else() | ||||
|     message(STATUS "Not building NVFP4 as no compatible archs were found.") | ||||
| @ -542,13 +573,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | ||||
|  | ||||
|   # CUTLASS MoE kernels | ||||
|  | ||||
|   # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works | ||||
|   # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works | ||||
|   # on Hopper). get_cutlass_(pplx_)moe_mm_data should only be compiled | ||||
|   # if it's possible to compile MoE kernels that use its output. | ||||
|   cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}") | ||||
|   cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}") | ||||
|   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) | ||||
|     set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu" | ||||
|              "csrc/quantization/cutlass_w8a8/moe/moe_data.cu") | ||||
|     set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu") | ||||
|     set_gencode_flags_for_srcs( | ||||
|       SRCS "${SRCS}" | ||||
|       CUDA_ARCHS "${SCALED_MM_ARCHS}") | ||||
| @ -562,7 +592,27 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | ||||
|                      "if you intend on running FP8 quantized MoE models on Hopper.") | ||||
|     else() | ||||
|       message(STATUS "Not building grouped_mm_c3x as no compatible archs found " | ||||
|                      "in CUDA target architectures") | ||||
|                      "in CUDA target architectures.") | ||||
|     endif() | ||||
|   endif() | ||||
|  | ||||
|   # moe_data.cu is used by all CUTLASS MoE kernels. | ||||
|   cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}") | ||||
|   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) | ||||
|     set(SRCS "csrc/quantization/cutlass_w8a8/moe/moe_data.cu") | ||||
|     set_gencode_flags_for_srcs( | ||||
|       SRCS "${SRCS}" | ||||
|       CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}") | ||||
|     list(APPEND VLLM_EXT_SRC "${SRCS}") | ||||
|     message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}") | ||||
|   else() | ||||
|     if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) | ||||
|       message(STATUS "Not building moe_data as CUDA Compiler version is " | ||||
|                      "not >= 12.3, we recommend upgrading to CUDA 12.3 or later " | ||||
|                      "if you intend on running FP8 quantized MoE models on Hopper or Blackwell.") | ||||
|     else() | ||||
|       message(STATUS "Not building moe_data as no compatible archs found " | ||||
|                      "in CUDA target architectures.") | ||||
|     endif() | ||||
|   endif() | ||||
|  | ||||
| @ -638,6 +688,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | ||||
| # if CUDA endif | ||||
| endif() | ||||
|  | ||||
| if (VLLM_GPU_LANG STREQUAL "HIP") | ||||
|   # Add QuickReduce kernels | ||||
|   list(APPEND VLLM_EXT_SRC | ||||
|     "csrc/custom_quickreduce.cu" | ||||
|   ) | ||||
| # if ROCM endif | ||||
| endif() | ||||
|  | ||||
| message(STATUS "Enabling C extension.") | ||||
| define_gpu_extension_target( | ||||
|   _C | ||||
|  | ||||
| @ -154,11 +154,13 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs | ||||
|  | ||||
| ## Contact Us | ||||
|  | ||||
| <!-- --8<-- [start:contact-us] --> | ||||
| - For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues) or [Discussions](https://github.com/vllm-project/vllm/discussions) | ||||
| - For discussing with fellow users, please use the [vLLM Forum](https://discuss.vllm.ai) | ||||
| - coordinating contributions and development, please use [Slack](https://slack.vllm.ai) | ||||
| - For coordinating contributions and development, please use [Slack](https://slack.vllm.ai) | ||||
| - For security disclosures, please use GitHub's [Security Advisories](https://github.com/vllm-project/vllm/security/advisories) feature | ||||
| - For collaborations and partnerships, please contact us at [vllm-questions@lists.berkeley.edu](mailto:vllm-questions@lists.berkeley.edu) | ||||
| <!-- --8<-- [end:contact-us] --> | ||||
|  | ||||
| ## Media Kit | ||||
|  | ||||
|  | ||||
| @ -4,7 +4,7 @@ This README guides you through running benchmark tests with the extensive | ||||
| datasets supported on vLLM. It’s a living document, updated as new features and datasets | ||||
| become available. | ||||
|  | ||||
| ## Dataset Overview | ||||
| **Dataset Overview** | ||||
|  | ||||
| <table style="width:100%; border-collapse: collapse;"> | ||||
|   <thead> | ||||
| @ -82,7 +82,10 @@ become available. | ||||
| **Note**: HuggingFace dataset's `dataset-name` should be set to `hf` | ||||
|  | ||||
| --- | ||||
| ## Example - Online Benchmark | ||||
| <details> | ||||
| <summary><b>🚀 Example - Online Benchmark</b></summary> | ||||
|  | ||||
| <br/> | ||||
|  | ||||
| First start serving your model | ||||
|  | ||||
| @ -130,7 +133,8 @@ P99 ITL (ms):                            8.39 | ||||
| ================================================== | ||||
| ``` | ||||
|  | ||||
| ### Custom Dataset | ||||
| **Custom Dataset** | ||||
|  | ||||
| If the dataset you want to benchmark is not yet supported in vLLM, you can still benchmark it using `CustomDataset`. Your data needs to be in `.jsonl` format with a "prompt" field per entry, e.g., data.jsonl | ||||
|  | ||||
| ``` | ||||
| @ -162,7 +166,7 @@ python3 benchmarks/benchmark_serving.py --port 9001 --save-result --save-detaile | ||||
|  | ||||
| You can skip applying chat template if your data already has it by using `--custom-skip-chat-template`. | ||||
|  | ||||
| ### VisionArena Benchmark for Vision Language Models | ||||
| **VisionArena Benchmark for Vision Language Models** | ||||
|  | ||||
| ```bash | ||||
| # need a model with vision capability here | ||||
| @ -180,7 +184,7 @@ python3 vllm/benchmarks/benchmark_serving.py \ | ||||
|   --num-prompts 1000 | ||||
| ``` | ||||
|  | ||||
| ### InstructCoder Benchmark with Speculative Decoding | ||||
| **InstructCoder Benchmark with Speculative Decoding** | ||||
|  | ||||
| ``` bash | ||||
| VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \ | ||||
| @ -197,7 +201,7 @@ python3 benchmarks/benchmark_serving.py \ | ||||
|     --num-prompts 2048 | ||||
| ``` | ||||
|  | ||||
| ### Other HuggingFaceDataset Examples | ||||
| **Other HuggingFaceDataset Examples** | ||||
|  | ||||
| ```bash | ||||
| vllm serve Qwen/Qwen2-VL-7B-Instruct --disable-log-requests | ||||
| @ -251,7 +255,7 @@ python3 vllm/benchmarks/benchmark_serving.py \ | ||||
|     --num-prompts 80 | ||||
| ``` | ||||
|  | ||||
| ### Running With Sampling Parameters | ||||
| **Running With Sampling Parameters** | ||||
|  | ||||
| When using OpenAI-compatible backends such as `vllm`, optional sampling | ||||
| parameters can be specified. Example client command: | ||||
| @ -269,8 +273,27 @@ python3 vllm/benchmarks/benchmark_serving.py \ | ||||
|   --num-prompts 10 | ||||
| ``` | ||||
|  | ||||
| --- | ||||
| ## Example - Offline Throughput Benchmark | ||||
| **Running With Ramp-Up Request Rate** | ||||
|  | ||||
| The benchmark tool also supports ramping up the request rate over the | ||||
| duration of the benchmark run. This can be useful for stress testing the | ||||
| server or finding the maximum throughput that it can handle, given some latency budget. | ||||
|  | ||||
| Two ramp-up strategies are supported: | ||||
| - `linear`: Increases the request rate linearly from a start value to an end value. | ||||
| - `exponential`: Increases the request rate exponentially. | ||||
|  | ||||
| The following arguments can be used to control the ramp-up: | ||||
| - `--ramp-up-strategy`: The ramp-up strategy to use (`linear` or `exponential`). | ||||
| - `--ramp-up-start-rps`: The request rate at the beginning of the benchmark. | ||||
| - `--ramp-up-end-rps`: The request rate at the end of the benchmark. | ||||
|  | ||||
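
As a rough mental model of the two strategies (a sketch of the concept only, not the benchmark tool's actual implementation), the target request rate at a fraction `t` of the run elapsed could be interpolated like this:

```python
def ramp_rps(t: float, start_rps: float, end_rps: float, strategy: str) -> float:
    """Illustrative ramp schedule; t is the fraction of the run elapsed (0..1)."""
    if strategy == "linear":
        return start_rps + (end_rps - start_rps) * t
    if strategy == "exponential":
        # Multiplicative growth from start to end; assumes both rates are positive.
        return start_rps * (end_rps / start_rps) ** t
    raise ValueError(f"unknown ramp-up strategy: {strategy}")


# Example: ramping from 1 to 20 req/s over a 120 s run.
for sec in (0, 30, 60, 90, 120):
    t = sec / 120
    print(sec,
          round(ramp_rps(t, 1, 20, "linear"), 1),
          round(ramp_rps(t, 1, 20, "exponential"), 1))
```
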
| </details> | ||||
|  | ||||
| <details> | ||||
| <summary><b>📈 Example - Offline Throughput Benchmark</b></summary> | ||||
|  | ||||
| <br/> | ||||
|  | ||||
| ```bash | ||||
| python3 vllm/benchmarks/benchmark_throughput.py \ | ||||
| @ -288,7 +311,7 @@ Total num prompt tokens:  5014 | ||||
| Total num output tokens:  1500 | ||||
| ``` | ||||
|  | ||||
| ### VisionArena Benchmark for Vision Language Models | ||||
| **VisionArena Benchmark for Vision Language Models** | ||||
|  | ||||
| ``` bash | ||||
| python3 vllm/benchmarks/benchmark_throughput.py \ | ||||
| @ -308,7 +331,7 @@ Total num prompt tokens:  14527 | ||||
| Total num output tokens:  1280 | ||||
| ``` | ||||
|  | ||||
| ### InstructCoder Benchmark with Speculative Decoding | ||||
| **InstructCoder Benchmark with Speculative Decoding** | ||||
|  | ||||
| ``` bash | ||||
| VLLM_WORKER_MULTIPROC_METHOD=spawn \ | ||||
| @ -332,7 +355,7 @@ Total num prompt tokens:  261136 | ||||
| Total num output tokens:  204800 | ||||
| ``` | ||||
|  | ||||
| ### Other HuggingFaceDataset Examples | ||||
| **Other HuggingFaceDataset Examples** | ||||
|  | ||||
| **`lmms-lab/LLaVA-OneVision-Data`** | ||||
|  | ||||
| @ -371,7 +394,7 @@ python3 benchmarks/benchmark_throughput.py \ | ||||
|   --num-prompts 10 | ||||
| ``` | ||||
|  | ||||
| ### Benchmark with LoRA Adapters | ||||
| **Benchmark with LoRA Adapters** | ||||
|  | ||||
| ``` bash | ||||
| # download dataset | ||||
| @ -387,3 +410,196 @@ python3 vllm/benchmarks/benchmark_throughput.py \ | ||||
|   --enable-lora \ | ||||
|   --lora-path yard1/llama-2-7b-sql-lora-test | ||||
|   ``` | ||||
|  | ||||
| </details> | ||||
|  | ||||
| <details> | ||||
| <summary><b>🛠️ Example - Structured Output Benchmark</b></summary> | ||||
|  | ||||
| <br/> | ||||
|  | ||||
| Benchmark the performance of structured output generation (JSON, grammar, regex). | ||||
|  | ||||
| **Server Setup** | ||||
|  | ||||
| ```bash | ||||
| vllm serve NousResearch/Hermes-3-Llama-3.1-8B --disable-log-requests | ||||
| ``` | ||||
|  | ||||
| **JSON Schema Benchmark** | ||||
|  | ||||
| ```bash | ||||
| python3 benchmarks/benchmark_serving_structured_output.py \ | ||||
|   --backend vllm \ | ||||
|   --model NousResearch/Hermes-3-Llama-3.1-8B \ | ||||
|   --dataset json \ | ||||
|   --structured-output-ratio 1.0 \ | ||||
|   --request-rate 10 \ | ||||
|   --num-prompts 1000 | ||||
| ``` | ||||
|  | ||||
| **Grammar-based Generation Benchmark** | ||||
|  | ||||
| ```bash | ||||
| python3 benchmarks/benchmark_serving_structured_output.py \ | ||||
|   --backend vllm \ | ||||
|   --model NousResearch/Hermes-3-Llama-3.1-8B \ | ||||
|   --dataset grammar \ | ||||
|   --structure-type grammar \ | ||||
|   --request-rate 10 \ | ||||
|   --num-prompts 1000 | ||||
| ``` | ||||
|  | ||||
| **Regex-based Generation Benchmark** | ||||
|  | ||||
| ```bash | ||||
| python3 benchmarks/benchmark_serving_structured_output.py \ | ||||
|   --backend vllm \ | ||||
|   --model NousResearch/Hermes-3-Llama-3.1-8B \ | ||||
|   --dataset regex \ | ||||
|   --request-rate 10 \ | ||||
|   --num-prompts 1000 | ||||
| ``` | ||||
|  | ||||
| **Choice-based Generation Benchmark** | ||||
|  | ||||
| ```bash | ||||
| python3 benchmarks/benchmark_serving_structured_output.py \ | ||||
|   --backend vllm \ | ||||
|   --model NousResearch/Hermes-3-Llama-3.1-8B \ | ||||
|   --dataset choice \ | ||||
|   --request-rate 10 \ | ||||
|   --num-prompts 1000 | ||||
| ``` | ||||
|  | ||||
| **XGrammar Benchmark Dataset** | ||||
|  | ||||
| ```bash | ||||
| python3 benchmarks/benchmark_serving_structured_output.py \ | ||||
|   --backend vllm \ | ||||
|   --model NousResearch/Hermes-3-Llama-3.1-8B \ | ||||
|   --dataset xgrammar_bench \ | ||||
|   --request-rate 10 \ | ||||
|   --num-prompts 1000 | ||||
| ``` | ||||
|  | ||||
| </details> | ||||
|  | ||||
| <details> | ||||
| <summary><b>📚 Example - Long Document QA Benchmark</b></summary> | ||||
|  | ||||
| <br/> | ||||
|  | ||||
| Benchmark the performance of long document question-answering with prefix caching. | ||||
|  | ||||
| **Basic Long Document QA Test** | ||||
|  | ||||
| ```bash | ||||
| python3 benchmarks/benchmark_long_document_qa_throughput.py \ | ||||
|   --model meta-llama/Llama-2-7b-chat-hf \ | ||||
|   --enable-prefix-caching \ | ||||
|   --num-documents 16 \ | ||||
|   --document-length 2000 \ | ||||
|   --output-len 50 \ | ||||
|   --repeat-count 5 | ||||
| ``` | ||||
|  | ||||
| **Different Repeat Modes** | ||||
|  | ||||
| ```bash | ||||
| # Random mode (default) - shuffle prompts randomly | ||||
| python3 benchmarks/benchmark_long_document_qa_throughput.py \ | ||||
|   --model meta-llama/Llama-2-7b-chat-hf \ | ||||
|   --enable-prefix-caching \ | ||||
|   --num-documents 8 \ | ||||
|   --document-length 3000 \ | ||||
|   --repeat-count 3 \ | ||||
|   --repeat-mode random | ||||
|  | ||||
| # Tile mode - repeat entire prompt list in sequence | ||||
| python3 benchmarks/benchmark_long_document_qa_throughput.py \ | ||||
|   --model meta-llama/Llama-2-7b-chat-hf \ | ||||
|   --enable-prefix-caching \ | ||||
|   --num-documents 8 \ | ||||
|   --document-length 3000 \ | ||||
|   --repeat-count 3 \ | ||||
|   --repeat-mode tile | ||||
|  | ||||
| # Interleave mode - repeat each prompt consecutively | ||||
| python3 benchmarks/benchmark_long_document_qa_throughput.py \ | ||||
|   --model meta-llama/Llama-2-7b-chat-hf \ | ||||
|   --enable-prefix-caching \ | ||||
|   --num-documents 8 \ | ||||
|   --document-length 3000 \ | ||||
|   --repeat-count 3 \ | ||||
|   --repeat-mode interleave | ||||
| ``` | ||||
|  | ||||
| </details> | ||||
|  | ||||
| <details> | ||||
| <summary><b>🗂️ Example - Prefix Caching Benchmark</b></summary> | ||||
|  | ||||
| <br/> | ||||
|  | ||||
| Benchmark the efficiency of automatic prefix caching. | ||||
|  | ||||
| **Fixed Prompt with Prefix Caching** | ||||
|  | ||||
| ```bash | ||||
| python3 benchmarks/benchmark_prefix_caching.py \ | ||||
|   --model meta-llama/Llama-2-7b-chat-hf \ | ||||
|   --enable-prefix-caching \ | ||||
|   --num-prompts 1 \ | ||||
|   --repeat-count 100 \ | ||||
|   --input-length-range 128:256 | ||||
| ``` | ||||
|  | ||||
| **ShareGPT Dataset with Prefix Caching** | ||||
|  | ||||
| ```bash | ||||
| # download dataset | ||||
| # wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json | ||||
|  | ||||
| python3 benchmarks/benchmark_prefix_caching.py \ | ||||
|   --model meta-llama/Llama-2-7b-chat-hf \ | ||||
|   --dataset-path /path/ShareGPT_V3_unfiltered_cleaned_split.json \ | ||||
|   --enable-prefix-caching \ | ||||
|   --num-prompts 20 \ | ||||
|   --repeat-count 5 \ | ||||
|   --input-length-range 128:256 | ||||
| ``` | ||||
|  | ||||
| </details> | ||||
|  | ||||
| <details> | ||||
| <summary><b>⚡ Example - Request Prioritization Benchmark</b></summary> | ||||
|  | ||||
| <br/> | ||||
|  | ||||
| Benchmark the performance of request prioritization in vLLM. | ||||
|  | ||||
| **Basic Prioritization Test** | ||||
|  | ||||
| ```bash | ||||
| python3 benchmarks/benchmark_prioritization.py \ | ||||
|   --model meta-llama/Llama-2-7b-chat-hf \ | ||||
|   --input-len 128 \ | ||||
|   --output-len 64 \ | ||||
|   --num-prompts 100 \ | ||||
|   --scheduling-policy priority | ||||
| ``` | ||||
|  | ||||
| **Multiple Sequences per Prompt** | ||||
|  | ||||
| ```bash | ||||
| python3 benchmarks/benchmark_prioritization.py \ | ||||
|   --model meta-llama/Llama-2-7b-chat-hf \ | ||||
|   --input-len 128 \ | ||||
|   --output-len 64 \ | ||||
|   --num-prompts 100 \ | ||||
|   --scheduling-policy priority \ | ||||
|   --n 2 | ||||
| ``` | ||||
|  | ||||
| </details> | ||||
|  | ||||
| @ -10,6 +10,7 @@ | ||||
| # 3. Set variables (ALL REQUIRED) | ||||
| #   BASE: your directory for vllm repo | ||||
| #   MODEL: the model served by vllm | ||||
| #   SYSTEM: the hardware, either TPU or GPU; for other systems, "get best profile" might not be supported. | ||||
| #   TP: ways of tensor parallelism | ||||
| #   DOWNLOAD_DIR: directory to download and load model weights. | ||||
| #   INPUT_LEN: request input len | ||||
| @ -34,6 +35,7 @@ | ||||
| TAG=$(date +"%Y_%m_%d_%H_%M") | ||||
| BASE="" | ||||
| MODEL="meta-llama/Llama-3.1-8B-Instruct" | ||||
| SYSTEM="TPU" | ||||
| TP=1 | ||||
| DOWNLOAD_DIR="" | ||||
| INPUT_LEN=4000 | ||||
| @ -45,12 +47,15 @@ NUM_BATCHED_TOKENS_LIST="512 1024 2048 4096" | ||||
|  | ||||
| LOG_FOLDER="$BASE/auto-benchmark/$TAG" | ||||
| RESULT="$LOG_FOLDER/result.txt" | ||||
| PROFILE_PATH="$LOG_FOLDER/profile" | ||||
|  | ||||
| echo "result file: $RESULT" | ||||
| echo "model: $MODEL" | ||||
|  | ||||
| rm -rf $LOG_FOLDER | ||||
| rm -rf $PROFILE_PATH | ||||
| mkdir -p $LOG_FOLDER | ||||
| mkdir -p $PROFILE_PATH | ||||
|  | ||||
| cd "$BASE/vllm" | ||||
|  | ||||
| @ -70,10 +75,11 @@ start_server() { | ||||
|     local max_num_seqs=$2 | ||||
|     local max_num_batched_tokens=$3 | ||||
|     local vllm_log=$4 | ||||
|     local profile_dir=$5 | ||||
|      | ||||
|     pkill -f vllm | ||||
|  | ||||
|     VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 vllm serve $MODEL \ | ||||
|     VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir vllm serve $MODEL \ | ||||
|         --disable-log-requests \ | ||||
|         --port 8004 \ | ||||
|         --gpu-memory-utilization $gpu_memory_utilization \ | ||||
| @ -105,19 +111,37 @@ start_server() { | ||||
|     fi | ||||
| } | ||||
|  | ||||
| update_best_profile() { | ||||
|     local profile_dir=$1 | ||||
|     local profile_index=$2 | ||||
|     sorted_paths=($(find "$profile_dir" -maxdepth 1 -not -path "$profile_dir" | sort)) | ||||
|     selected_profile_file= | ||||
|     if [[ "$SYSTEM" == "TPU" ]]; then | ||||
|         selected_profile_file="${sorted_paths[$profile_index]}/*.xplane.pb" | ||||
|     fi  | ||||
|     if [[ "$SYSTEM" == "GPU" ]]; then | ||||
|         selected_profile_file="${sorted_paths[$profile_index]}" | ||||
|     fi  | ||||
|     rm -f $PROFILE_PATH/* | ||||
|     cp $selected_profile_file $PROFILE_PATH | ||||
| } | ||||
|  | ||||
| run_benchmark() { | ||||
|     local max_num_seqs=$1 | ||||
|     local max_num_batched_tokens=$2 | ||||
|     local gpu_memory_utilization=$3 | ||||
|     echo "max_num_seq: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens" | ||||
|     local vllm_log="$LOG_FOLDER/vllm_log_${max_num_seqs}_${max_num_batched_tokens}.txt" | ||||
|     local profile_dir="$LOG_FOLDER/profile_${max_num_seqs}_${max_num_batched_tokens}" | ||||
|     echo "vllm_log: $vllm_log" | ||||
|     echo | ||||
|     rm -f $vllm_log | ||||
|     mkdir -p $profile_dir | ||||
|     pkill -f vllm | ||||
|     local profile_index=0 | ||||
|  | ||||
|     echo "starting server..." | ||||
|     start_server $gpu_memory_utilization $max_num_seqs $max_num_batched_tokens $vllm_log | ||||
|     start_server $gpu_memory_utilization $max_num_seqs $max_num_batched_tokens $vllm_log $profile_dir | ||||
|     result=$? | ||||
|     if [[ "$result" -eq 1 ]]; then | ||||
|         echo "server failed to start. gpu_memory_utilization:$gpu_memory_utilization, max_num_seqs:$max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens" | ||||
| @ -144,7 +168,8 @@ run_benchmark() { | ||||
|         --goodput e2el:$MAX_LATENCY_ALLOWED_MS \ | ||||
|         --num-prompts 1000 \ | ||||
|         --random-prefix-len $prefix_len \ | ||||
|         --port 8004 &> "$bm_log" | ||||
|         --port 8004 \ | ||||
|         --profile &> "$bm_log" | ||||
|     throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') | ||||
|     e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}') | ||||
|     goodput=$(grep "Request goodput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') | ||||
| @ -158,6 +183,7 @@ run_benchmark() { | ||||
|     # start from request-rate as int(throughput) + 1 | ||||
|         request_rate=$((${throughput%.*} + 1)) | ||||
|         while ((request_rate > 0)); do | ||||
|             profile_index=$((profile_index+1)) | ||||
|             # clear prefix cache | ||||
|             curl -X POST http://0.0.0.0:8004/reset_prefix_cache | ||||
|             sleep 5 | ||||
| @ -195,6 +221,12 @@ run_benchmark() { | ||||
|             best_max_num_seqs=$max_num_seqs | ||||
|             best_num_batched_tokens=$max_num_batched_tokens | ||||
|             best_goodput=$goodput | ||||
|             if [[ "$SYSTEM" == "TPU" ]]; then | ||||
|                 update_best_profile "$profile_dir/plugins/profile" $profile_index | ||||
|             fi | ||||
|             if [[ "$SYSTEM" == "GPU" ]]; then | ||||
|                 update_best_profile "$profile_dir" $profile_index | ||||
|             fi | ||||
|         fi | ||||
|     else | ||||
|         echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens does not meet latency requirement ${MAX_LATENCY_ALLOWED_MS}" | ||||
| @ -239,6 +271,6 @@ for num_seqs in "${num_seqs_list[@]}"; do | ||||
|     done | ||||
| done | ||||
| echo "finish permutations" | ||||
| echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput" | ||||
| echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput" >> "$RESULT" | ||||
| echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" | ||||
| echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" >> "$RESULT" | ||||
|  | ||||
|  | ||||
| @ -404,8 +404,14 @@ async def async_request_openai_chat_completions( | ||||
|                         chunk_bytes = chunk_bytes.strip() | ||||
|                         if not chunk_bytes: | ||||
|                             continue | ||||
|                         chunk_bytes = chunk_bytes.decode("utf-8") | ||||
|                         # NOTE: SSE comments (often used as pings) start with a colon. | ||||
|                         # These are not JSON data payload and should be skipped. | ||||
|                         if chunk_bytes.startswith(":"): | ||||
|                             continue | ||||
|  | ||||
|                         chunk = chunk_bytes.removeprefix("data: ") | ||||
|  | ||||
|                         chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") | ||||
|                         if chunk != "[DONE]": | ||||
|                             timestamp = time.perf_counter() | ||||
|                             data = json.loads(chunk) | ||||
|  | ||||
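
For context on the streaming change above: in a Server-Sent Events response, any line that begins with a colon is a comment (servers commonly send these as keep-alive pings), so it carries no JSON payload and must be skipped before the `data: ` prefix is stripped. A standalone sketch of that filtering, using a made-up stream for illustration:

```python
import json

# Hypothetical raw SSE lines as a streaming server might emit them.
stream_lines = [
    b": keep-alive ping",  # SSE comment -> skip
    b'data: {"choices": [{"delta": {"content": "Hi"}}]}',
    b"data: [DONE]",
]

for raw in stream_lines:
    line = raw.strip().decode("utf-8")
    if not line or line.startswith(":"):
        continue  # comments/pings are not JSON data payloads
    payload = line.removeprefix("data: ")
    if payload == "[DONE]":
        break
    data = json.loads(payload)
    print(data["choices"][0]["delta"]["content"])
```
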
| @ -349,11 +349,12 @@ class RandomDataset(BenchmarkDataset): | ||||
|             # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] | ||||
|             # To avoid uncontrolled change of the prompt length, | ||||
|             # the encoded sequence is truncated before being decode again. | ||||
|             total_input_len = prefix_len + int(input_lens[i]) | ||||
|             re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[ | ||||
|                 : input_lens[i] | ||||
|                 :total_input_len | ||||
|             ] | ||||
|             prompt = tokenizer.decode(re_encoded_sequence) | ||||
|             total_input_len = prefix_len + int(input_lens[i]) | ||||
|             total_input_len = len(re_encoded_sequence) | ||||
|             requests.append( | ||||
|                 SampleRequest( | ||||
|                     prompt=prompt, | ||||
|  | ||||
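
The hunk above truncates the re-encoded sequence to `prefix_len + input_len` and then records the length that was actually obtained, because decoding and re-encoding can merge tokens and silently change the prompt length. A minimal sketch of the same decode/re-encode/truncate trick, assuming any Hugging Face tokenizer is available (gpt2 is used purely as an example):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

prefix_len, input_len = 2, 6
token_ids = tokenizer.encode("The quick brown fox jumps over the lazy dog")
prompt = tokenizer.decode(token_ids)

# Decoding then re-encoding may not round-trip to the same ids, so re-encode
# and truncate to the intended total length before using the prompt.
total_input_len = prefix_len + input_len
re_encoded = tokenizer.encode(prompt, add_special_tokens=False)[:total_input_len]
prompt = tokenizer.decode(re_encoded)

# Record the length actually obtained, which may be shorter than requested.
total_input_len = len(re_encoded)
print(total_input_len, repr(prompt))
```
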
| @ -123,7 +123,7 @@ def main(args: argparse.Namespace): | ||||
|         save_to_pytorch_benchmark_format(args, results) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
| def create_argument_parser(): | ||||
|     parser = FlexibleArgumentParser( | ||||
|         description="Benchmark the latency of processing a single batch of " | ||||
|         "requests till completion." | ||||
| @ -171,6 +171,12 @@ if __name__ == "__main__": | ||||
|     # V1 enables prefix caching by default which skews the latency | ||||
|     # numbers. We need to disable prefix caching by default. | ||||
|     parser.set_defaults(enable_prefix_caching=False) | ||||
|  | ||||
|     return parser | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = create_argument_parser() | ||||
|     args = parser.parse_args() | ||||
|     if args.profile and not envs.VLLM_TORCH_PROFILER_DIR: | ||||
|         raise OSError( | ||||
|  | ||||
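
This hunk, like the prefix-caching benchmark hunk just below, pulls argument-parser construction out of the `__main__` block into a `create_argument_parser()` function, presumably so other tooling can build the parser without running the benchmark. A generic sketch of the pattern, using a hypothetical module for illustration:

```python
# example_benchmark.py (hypothetical module showing the refactor pattern)
import argparse


def create_argument_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Example benchmark")
    parser.add_argument("--num-prompts", type=int, default=10)
    parser.add_argument("--output-len", type=int, default=128)
    return parser


def main(args: argparse.Namespace) -> None:
    print(f"running with {args.num_prompts} prompts, output len {args.output_len}")


if __name__ == "__main__":
    main(create_argument_parser().parse_args())
```

Another script or a docs generator can then `from example_benchmark import create_argument_parser` to inspect the available flags without side effects.
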
| @ -142,7 +142,7 @@ def main(args): | ||||
|     ) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
| def create_argument_parser(): | ||||
|     parser = FlexibleArgumentParser( | ||||
|         description="Benchmark the performance with or " | ||||
|         "without automatic prefix caching." | ||||
| @ -192,5 +192,11 @@ if __name__ == "__main__": | ||||
|     ) | ||||
|  | ||||
|     parser = EngineArgs.add_cli_args(parser) | ||||
|  | ||||
|     return parser | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = create_argument_parser() | ||||
|     args = parser.parse_args() | ||||
|     main(args) | ||||
|  | ||||

362  benchmarks/benchmark_one_concurrent.py  (new file)

							| @ -0,0 +1,362 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| import argparse | ||||
| import asyncio | ||||
| import logging | ||||
| import random | ||||
| import time | ||||
| from dataclasses import dataclass | ||||
| from typing import Optional | ||||
|  | ||||
| import aiohttp  # async HTTP client, used to call the server's cache-reset endpoint | ||||
| import numpy as np | ||||
| from tqdm import tqdm | ||||
|  | ||||
| from backend_request_func import RequestFuncInput, RequestFuncOutput | ||||
| from benchmark_dataset import RandomDataset, SampleRequest | ||||
|  | ||||
| try: | ||||
|     from vllm.transformers_utils.tokenizer import get_tokenizer | ||||
| except ImportError: | ||||
|     from backend_request_func import get_tokenizer | ||||
|  | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class BenchmarkMetrics: | ||||
|     completed: int | ||||
|     total_input: int | ||||
|     total_output: int | ||||
|     mean_ttft_ms: float | ||||
|     median_ttft_ms: float | ||||
|     std_ttft_ms: float | ||||
|     percentiles_ttft_ms: list[tuple[float, float]] | ||||
|     mean_itl_ms: float | ||||
|     median_itl_ms: float | ||||
|     std_itl_ms: float | ||||
|     percentiles_itl_ms: list[tuple[float, float]] | ||||
|     mean_e2el_ms: float | ||||
|     median_e2el_ms: float | ||||
|     std_e2el_ms: float | ||||
|     percentiles_e2el_ms: list[tuple[float, float]] | ||||
|  | ||||
|  | ||||
| async def reset_cache(reset_url: str): | ||||
|     """Sends a POST request to reset the prefix cache.""" | ||||
|     logger.debug("Resetting prefix cache at %s", reset_url) | ||||
|     try: | ||||
|         async with ( | ||||
|             aiohttp.ClientSession() as session, | ||||
|             session.post(reset_url) as response, | ||||
|         ): | ||||
|             response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx) | ||||
|             logger.debug("Prefix cache reset successful: %s", response.status) | ||||
|     except aiohttp.ClientConnectorError as e: | ||||
|         logger.error("Failed to connect to cache reset endpoint %s: %s}", reset_url, e) | ||||
|     except aiohttp.ClientResponseError as e: | ||||
|         logger.error( | ||||
|             "Cache reset request failed with status %s: %s", e.status, e.message | ||||
|         ) | ||||
|     except Exception as e: | ||||
|         logger.error("An unexpected error occurred during cache reset: %s", e) | ||||
|  | ||||
|  | ||||
| async def sequential_benchmark( | ||||
|     backend: str, | ||||
|     api_url: str, | ||||
|     model_id: str, | ||||
|     tokenizer, | ||||
|     input_requests: list[SampleRequest], | ||||
|     request_func, | ||||
|     selected_percentiles: list[float], | ||||
|     cache_reset_url: Optional[str] = None, | ||||
| ): | ||||
|     """ | ||||
|     Benchmark that processes requests sequentially, waiting for each to complete | ||||
|     before starting the next one. Resets prefix cache between requests. | ||||
|     """ | ||||
|     outputs = [] | ||||
|  | ||||
|     pbar = tqdm(total=len(input_requests)) | ||||
|  | ||||
|     benchmark_start_time = time.perf_counter() | ||||
|  | ||||
|     # Process requests sequentially | ||||
|     for request in input_requests: | ||||
|         prompt, prompt_len, output_len = ( | ||||
|             request.prompt, | ||||
|             request.prompt_len, | ||||
|             request.expected_output_len, | ||||
|         ) | ||||
|  | ||||
|         logger.info("Sending request with len %s", request.prompt_len) | ||||
|         logger.debug('Request str: "%s"', request.prompt[:50]) | ||||
|         request_start_time = time.perf_counter() | ||||
|  | ||||
|         request_func_input = RequestFuncInput( | ||||
|             model=model_id, | ||||
|             prompt=prompt, | ||||
|             api_url=api_url, | ||||
|             prompt_len=prompt_len, | ||||
|             output_len=output_len, | ||||
|         ) | ||||
|  | ||||
|         output = await request_func(request_func_input=request_func_input) | ||||
|  | ||||
|         request_end_time = time.perf_counter() | ||||
|         # Add timing information | ||||
|         if output.success and not hasattr(output, "latency"): | ||||
|             output.latency = request_end_time - request_start_time | ||||
|         logger.info("Finished request with latency %.4f s", output.latency) | ||||
|  | ||||
|         outputs.append(output) | ||||
|         pbar.update(1) | ||||
|  | ||||
|     pbar.close() | ||||
|  | ||||
|     benchmark_duration = time.perf_counter() - benchmark_start_time | ||||
|  | ||||
|     # Calculate metrics | ||||
|     metrics = calculate_metrics( | ||||
|         input_requests=input_requests, | ||||
|         outputs=outputs, | ||||
|         dur_s=benchmark_duration, | ||||
|         tokenizer=tokenizer, | ||||
|         selected_percentiles=selected_percentiles, | ||||
|     ) | ||||
|  | ||||
|     print_results(metrics, benchmark_duration) | ||||
|  | ||||
|     result = { | ||||
|         "duration": benchmark_duration, | ||||
|         "completed": metrics.completed, | ||||
|         "total_input_tokens": metrics.total_input, | ||||
|         "total_output_tokens": metrics.total_output, | ||||
|         "input_lens": [request.prompt_len for request in input_requests], | ||||
|         "output_lens": [ | ||||
|             output.output_tokens if output.success else 0 for output in outputs | ||||
|         ], | ||||
|         "ttfts": [output.ttft for output in outputs if output.success], | ||||
|         "itls": [output.itl for output in outputs if output.success], | ||||
|         "generated_texts": [ | ||||
|             output.generated_text for output in outputs if output.success | ||||
|         ], | ||||
|         "errors": [output.error for output in outputs if not output.success], | ||||
|     } | ||||
|  | ||||
|     # Add summary statistics | ||||
|     for stat_name in ["ttft", "itl", "e2el"]: | ||||
|         for metric_name in ["mean", "median", "std"]: | ||||
|             result[f"{metric_name}_{stat_name}_ms"] = getattr( | ||||
|                 metrics, f"{metric_name}_{stat_name}_ms" | ||||
|             ) | ||||
|  | ||||
|         for p, value in getattr(metrics, f"percentiles_{stat_name}_ms"): | ||||
|             p_word = str(int(p)) if int(p) == p else str(p) | ||||
|             result[f"p{p_word}_{stat_name}_ms"] = value | ||||
|  | ||||
|     return result | ||||
|  | ||||
|  | ||||
| def calculate_metrics( | ||||
|     input_requests: list[SampleRequest], | ||||
|     outputs: list[RequestFuncOutput], | ||||
|     dur_s: float, | ||||
|     tokenizer, | ||||
|     selected_percentiles: list[float], | ||||
| ) -> BenchmarkMetrics: | ||||
|     """Calculate benchmark metrics from results.""" | ||||
|     total_input = 0 | ||||
|     completed = 0 | ||||
|     total_output = 0 | ||||
|     ttfts = [] | ||||
|     itls = [] | ||||
|     e2els = [] | ||||
|  | ||||
|     for i, output in enumerate(outputs): | ||||
|         if output.success: | ||||
|             output_len = output.output_tokens | ||||
|  | ||||
|             if not output_len: | ||||
|                 # Use tokenizer to count output tokens if not provided | ||||
|                 output_len = len( | ||||
|                     tokenizer(output.generated_text, add_special_tokens=False).input_ids | ||||
|                 ) | ||||
|  | ||||
|             total_output += output_len | ||||
|             total_input += input_requests[i].prompt_len | ||||
|  | ||||
|             if hasattr(output, "ttft") and output.ttft is not None: | ||||
|                 ttfts.append(output.ttft) | ||||
|  | ||||
|             if hasattr(output, "itl") and output.itl: | ||||
|                 # Ensure itl is a list of floats | ||||
|                 if isinstance(output.itl, list): | ||||
|                     itls.extend(output.itl) | ||||
|                 else: | ||||
|                     logger.warning( | ||||
|                         "Expected list for ITL but got %s. Appending as is.", | ||||
|                         type(output.itl), | ||||
|                     ) | ||||
|                     itls.append(output.itl) | ||||
|  | ||||
|             if hasattr(output, "latency") and output.latency is not None: | ||||
|                 e2els.append(output.latency) | ||||
|  | ||||
|             completed += 1 | ||||
|  | ||||
|     return BenchmarkMetrics( | ||||
|         completed=completed, | ||||
|         total_input=total_input, | ||||
|         total_output=total_output, | ||||
|         mean_ttft_ms=np.mean(ttfts or [0]) * 1000, | ||||
|         median_ttft_ms=np.median(ttfts or [0]) * 1000, | ||||
|         std_ttft_ms=np.std(ttfts or [0]) * 1000, | ||||
|         percentiles_ttft_ms=[ | ||||
|             (p, np.percentile(ttfts or [0], p) * 1000) for p in selected_percentiles | ||||
|         ], | ||||
|         mean_itl_ms=np.mean(itls or [0]) * 1000, | ||||
|         median_itl_ms=np.median(itls or [0]) * 1000, | ||||
|         std_itl_ms=np.std(itls or [0]) * 1000, | ||||
|         percentiles_itl_ms=[ | ||||
|             (p, np.percentile(itls or [0], p) * 1000) for p in selected_percentiles | ||||
|         ], | ||||
|         mean_e2el_ms=np.mean(e2els or [0]) * 1000, | ||||
|         median_e2el_ms=np.median(e2els or [0]) * 1000, | ||||
|         std_e2el_ms=np.std(e2els or [0]) * 1000, | ||||
|         percentiles_e2el_ms=[ | ||||
|             (p, np.percentile(e2els or [0], p) * 1000) for p in selected_percentiles | ||||
|         ], | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def print_results(metrics: BenchmarkMetrics, benchmark_duration: float): | ||||
|     """Print benchmark results in a formatted way.""" | ||||
|     print("{s:{c}^{n}}".format(s=" Sequential Benchmark Result ", n=60, c="=")) | ||||
|     print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) | ||||
|     print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) | ||||
|     print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) | ||||
|     print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) | ||||
|  | ||||
|     def print_metric_stats(metric_name, header): | ||||
|         print("{s:{c}^{n}}".format(s=header, n=60, c="-")) | ||||
|         print( | ||||
|             "{:<40} {:<10.2f}".format( | ||||
|                 f"Mean {metric_name} (ms):", | ||||
|                 getattr(metrics, f"mean_{metric_name.lower()}_ms"), | ||||
|             ) | ||||
|         ) | ||||
|         print( | ||||
|             "{:<40} {:<10.2f}".format( | ||||
|                 f"Median {metric_name} (ms):", | ||||
|                 getattr(metrics, f"median_{metric_name.lower()}_ms"), | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|         for p, value in getattr(metrics, f"percentiles_{metric_name.lower()}_ms"): | ||||
|             p_word = str(int(p)) if int(p) == p else str(p) | ||||
|             print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value)) | ||||
|  | ||||
|     print_metric_stats("TTFT", "Time to First Token") | ||||
|     print_metric_stats("ITL", "Inter-token Latency") | ||||
|     print_metric_stats("E2EL", "End-to-end Latency") | ||||
|     print("=" * 60) | ||||
|  | ||||
|  | ||||
| async def main_async(args): | ||||
|     # Import the async request functions for the selected backend | ||||
|     from backend_request_func import ASYNC_REQUEST_FUNCS | ||||
|  | ||||
|     backend = args.backend | ||||
|     model_id = args.model | ||||
|     tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model | ||||
|  | ||||
|     # Set up API URL | ||||
|     if args.base_url is not None: | ||||
|         api_url = f"{args.base_url}{args.endpoint}" | ||||
|     else: | ||||
|         api_url = f"http://{args.host}:{args.port}{args.endpoint}" | ||||
|  | ||||
|     # Set up Cache Reset URL | ||||
|     cache_reset_url = f"http://{args.host}:{args.port}/reset_prefix_cache" | ||||
|     logger.info("Prefix cache reset configured at: %s", cache_reset_url) | ||||
|  | ||||
|     # Get tokenizer | ||||
|     tokenizer = get_tokenizer(tokenizer_id, trust_remote_code=args.trust_remote_code) | ||||
|  | ||||
|     # Get request function | ||||
|     if backend in ASYNC_REQUEST_FUNCS: | ||||
|         request_func = ASYNC_REQUEST_FUNCS[backend] | ||||
|     else: | ||||
|         raise ValueError(f"Unknown backend: {backend}") | ||||
|  | ||||
|     input_requests = RandomDataset().sample( | ||||
|         tokenizer=tokenizer, | ||||
|         num_requests=args.num_requests, | ||||
|         prefix_len=0, | ||||
|         input_len=args.input_len, | ||||
|         output_len=args.output_len, | ||||
|         range_ratio=0.0, | ||||
|     ) | ||||
|  | ||||
|     # Run benchmark | ||||
|     result = await sequential_benchmark( | ||||
|         backend=backend, | ||||
|         api_url=api_url, | ||||
|         model_id=model_id, | ||||
|         tokenizer=tokenizer, | ||||
|         input_requests=input_requests, | ||||
|         request_func=request_func, | ||||
|         selected_percentiles=[50, 90, 95, 99], | ||||
|         cache_reset_url=cache_reset_url, | ||||
|     ) | ||||
|  | ||||
|     return result | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
|     print(args) | ||||
|     random.seed(args.seed) | ||||
|     np.random.seed(args.seed) | ||||
|  | ||||
|     asyncio.run(main_async(args)) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser(description="Sequential benchmark for LLM serving") | ||||
|     parser.add_argument( | ||||
|         "--backend", type=str, default="vllm", help="Backend to use for requests" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--base-url", | ||||
|         type=str, | ||||
|         default=None, | ||||
|         help="Server base URL (overrides --host and --port)", | ||||
|     ) | ||||
|     parser.add_argument("--host", type=str, default="127.0.0.1") | ||||
|     parser.add_argument("--port", type=int, default=8000) | ||||
|     parser.add_argument( | ||||
|         "--endpoint", type=str, default="/v1/completions", help="API endpoint" | ||||
|     ) | ||||
|     parser.add_argument("--model", type=str, required=True, help="Name of the model") | ||||
|     parser.add_argument( | ||||
|         "--tokenizer", type=str, help="Name of the tokenizer (defaults to model name)" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--num-requests", type=int, default=100, help="Number of requests to process" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--input-len", type=int, default=128, help="Input len for generated prompts" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--output-len", type=int, default=None, help="Override output len for requests" | ||||
|     ) | ||||
|     parser.add_argument("--seed", type=int, default=42) | ||||
|     parser.add_argument( | ||||
|         "--trust-remote-code", | ||||
|         action="store_true", | ||||
|         help="Trust remote code from HuggingFace", | ||||
|     ) | ||||
|  | ||||
|     args = parser.parse_args() | ||||
|     main(args) | ||||
| @ -218,7 +218,7 @@ def main(args): | ||||
|     ) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
| def create_argument_parser(): | ||||
|     parser = FlexibleArgumentParser( | ||||
|         description="Benchmark the performance with or without " | ||||
|         "automatic prefix caching." | ||||
| @ -268,5 +268,11 @@ if __name__ == "__main__": | ||||
|     ) | ||||
|  | ||||
|     parser = EngineArgs.add_cli_args(parser) | ||||
|  | ||||
|     return parser | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = create_argument_parser() | ||||
|     args = parser.parse_args() | ||||
|     main(args) | ||||
|  | ||||
| @ -161,7 +161,7 @@ def main(args: argparse.Namespace): | ||||
|             json.dump(results, f, indent=4) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
| def create_argument_parser(): | ||||
|     parser = FlexibleArgumentParser(description="Benchmark the throughput.") | ||||
|     parser.add_argument( | ||||
|         "--backend", type=str, choices=["vllm", "hf", "mii"], default="vllm" | ||||
| @ -204,6 +204,12 @@ if __name__ == "__main__": | ||||
|     ) | ||||
|  | ||||
|     parser = EngineArgs.add_cli_args(parser) | ||||
|  | ||||
|     return parser | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = create_argument_parser() | ||||
|     args = parser.parse_args() | ||||
|     if args.tokenizer is None: | ||||
|         args.tokenizer = args.model | ||||
|  | ||||
| @ -33,7 +33,7 @@ import warnings | ||||
| from collections.abc import AsyncGenerator, Iterable | ||||
| from dataclasses import dataclass | ||||
| from datetime import datetime | ||||
| from typing import Any, Optional | ||||
| from typing import Any, Literal, Optional | ||||
|  | ||||
| import numpy as np | ||||
| from tqdm.asyncio import tqdm | ||||
| @ -107,14 +107,42 @@ class BenchmarkMetrics: | ||||
|     percentiles_e2el_ms: list[tuple[float, float]] | ||||
|  | ||||
|  | ||||
| def _get_current_request_rate( | ||||
|     ramp_up_strategy: Optional[Literal["linear", "exponential"]], | ||||
|     ramp_up_start_rps: Optional[int], | ||||
|     ramp_up_end_rps: Optional[int], | ||||
|     request_index: int, | ||||
|     total_requests: int, | ||||
|     request_rate: float, | ||||
| ) -> float: | ||||
|     if ( | ||||
|         ramp_up_strategy | ||||
|         and ramp_up_start_rps is not None | ||||
|         and ramp_up_end_rps is not None | ||||
|     ): | ||||
|         progress = request_index / max(total_requests - 1, 1) | ||||
|         if ramp_up_strategy == "linear": | ||||
|             increase = (ramp_up_end_rps - ramp_up_start_rps) * progress | ||||
|             return ramp_up_start_rps + increase | ||||
|         elif ramp_up_strategy == "exponential": | ||||
|             ratio = ramp_up_end_rps / ramp_up_start_rps | ||||
|             return ramp_up_start_rps * (ratio**progress) | ||||
|         else: | ||||
|             raise ValueError(f"Unknown ramp-up strategy: {ramp_up_strategy}") | ||||
|     return request_rate | ||||
|  | ||||
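For intuition, a short illustration of what this helper returns under each strategy (the RPS values and request count below are assumed, purely illustrative):

    # Linear: 2 -> 10 RPS over 5 requests gives 2.0, 4.0, 6.0, 8.0, 10.0
    # Exponential: 2 * (10 / 2) ** progress gives 2.0, ~2.99, ~4.47, ~6.69, 10.0
    for i in range(5):
        linear = _get_current_request_rate("linear", 2, 10, i, 5, float("inf"))
        expo = _get_current_request_rate("exponential", 2, 10, i, 5, float("inf"))
        print(i, linear, expo)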
|  | ||||
| async def get_request( | ||||
|     input_requests: list[SampleRequest], | ||||
|     request_rate: float, | ||||
|     burstiness: float = 1.0, | ||||
| ) -> AsyncGenerator[SampleRequest, None]: | ||||
|     ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, | ||||
|     ramp_up_start_rps: Optional[int] = None, | ||||
|     ramp_up_end_rps: Optional[int] = None, | ||||
| ) -> AsyncGenerator[tuple[SampleRequest, float], None]: | ||||
|     """ | ||||
|     Asynchronously generates requests at a specified rate | ||||
|     with OPTIONAL burstiness. | ||||
|     with OPTIONAL burstiness and OPTIONAL ramp-up strategy. | ||||
|  | ||||
|     Args: | ||||
|         input_requests: | ||||
| @ -129,22 +157,44 @@ async def get_request( | ||||
|             A lower burstiness value (0 < burstiness < 1) results | ||||
|             in more bursty requests, while a higher burstiness value | ||||
|             (burstiness > 1) results in a more uniform arrival of requests. | ||||
|         ramp_up_strategy (optional): | ||||
|             The ramp-up strategy. Can be "linear" or "exponential". | ||||
|             If None, uses constant request rate (specified by request_rate). | ||||
|         ramp_up_start_rps (optional): | ||||
|             The starting request rate for ramp-up. | ||||
|         ramp_up_end_rps (optional): | ||||
|             The ending request rate for ramp-up. | ||||
|     """ | ||||
|     input_requests: Iterable[SampleRequest] = iter(input_requests) | ||||
|  | ||||
|     # Calculate scale parameter theta to maintain the desired request_rate. | ||||
|     assert burstiness > 0, ( | ||||
|         f"A positive burstiness factor is expected, but given {burstiness}." | ||||
|     ) | ||||
|     theta = 1.0 / (request_rate * burstiness) | ||||
|     # Convert to list to get length for ramp-up calculations | ||||
|     if isinstance(input_requests, Iterable) and not isinstance(input_requests, list): | ||||
|         input_requests = list(input_requests) | ||||
|  | ||||
|     total_requests = len(input_requests) | ||||
|     request_index = 0 | ||||
|  | ||||
|     for request in input_requests: | ||||
|         yield request | ||||
|         current_request_rate = _get_current_request_rate( | ||||
|             ramp_up_strategy, | ||||
|             ramp_up_start_rps, | ||||
|             ramp_up_end_rps, | ||||
|             request_index, | ||||
|             total_requests, | ||||
|             request_rate, | ||||
|         ) | ||||
|  | ||||
|         if request_rate == float("inf"): | ||||
|         yield request, current_request_rate | ||||
|  | ||||
|         request_index += 1 | ||||
|  | ||||
|         if current_request_rate == float("inf"): | ||||
|             # If the request rate is infinity, then we don't need to wait. | ||||
|             continue | ||||
|  | ||||
|         theta = 1.0 / (current_request_rate * burstiness) | ||||
|  | ||||
|         # Sample the request interval from the gamma distribution. | ||||
|         # If burstiness is 1, it follows exponential distribution. | ||||
|         interval = np.random.gamma(shape=burstiness, scale=theta) | ||||
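A quick sanity check on why this sampling preserves the target rate: a Gamma(shape=k, scale=theta) sample has mean k * theta, so with theta = 1 / (rate * burstiness) the mean interval is always 1 / rate, and burstiness only changes the variance (the rate and burstiness values below are assumed for illustration):

    import numpy as np

    rate, burstiness = 5.0, 0.5
    theta = 1.0 / (rate * burstiness)
    intervals = np.random.gamma(shape=burstiness, scale=theta, size=100_000)
    print(intervals.mean())  # ~= 1 / rate = 0.2 s between requests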
| @ -290,6 +340,9 @@ async def benchmark( | ||||
|     max_concurrency: Optional[int], | ||||
|     lora_modules: Optional[Iterable[str]], | ||||
|     extra_body: Optional[dict], | ||||
|     ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, | ||||
|     ramp_up_start_rps: Optional[int] = None, | ||||
|     ramp_up_end_rps: Optional[int] = None, | ||||
| ): | ||||
|     if backend in ASYNC_REQUEST_FUNCS: | ||||
|         request_func = ASYNC_REQUEST_FUNCS[backend] | ||||
| @ -353,7 +406,15 @@ async def benchmark( | ||||
|  | ||||
|     distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution" | ||||
|  | ||||
|     print(f"Traffic request rate: {request_rate}") | ||||
|     if ramp_up_strategy is not None: | ||||
|         print( | ||||
|             f"Traffic ramp-up strategy: {ramp_up_strategy}. Will increase " | ||||
|             f"RPS from {ramp_up_start_rps} to {ramp_up_end_rps} RPS over " | ||||
|             "the duration of the benchmark." | ||||
|         ) | ||||
|     else: | ||||
|         print(f"Traffic request rate: {request_rate} RPS.") | ||||
|  | ||||
|     print(f"Burstiness factor: {burstiness} ({distribution})") | ||||
|     print(f"Maximum request concurrency: {max_concurrency}") | ||||
|  | ||||
| @ -373,7 +434,34 @@ async def benchmark( | ||||
|  | ||||
|     benchmark_start_time = time.perf_counter() | ||||
|     tasks: list[asyncio.Task] = [] | ||||
|     async for request in get_request(input_requests, request_rate, burstiness): | ||||
|  | ||||
|     rps_change_events = [] | ||||
|     last_int_rps = -1 | ||||
|     if ramp_up_strategy is not None and ramp_up_start_rps is not None: | ||||
|         last_int_rps = ramp_up_start_rps | ||||
|         rps_change_events.append( | ||||
|             { | ||||
|                 "rps": last_int_rps, | ||||
|                 "timestamp": datetime.now().isoformat(), | ||||
|             } | ||||
|         ) | ||||
|  | ||||
|     async for request, current_request_rate in get_request( | ||||
|         input_requests, | ||||
|         request_rate, | ||||
|         burstiness, | ||||
|         ramp_up_strategy, | ||||
|         ramp_up_start_rps, | ||||
|         ramp_up_end_rps, | ||||
|     ): | ||||
|         if ramp_up_strategy is not None: | ||||
|             current_int_rps = int(current_request_rate) | ||||
|             if current_int_rps > last_int_rps: | ||||
|                 timestamp = datetime.now().isoformat() | ||||
|                 for rps_val in range(last_int_rps + 1, current_int_rps + 1): | ||||
|                     rps_change_events.append({"rps": rps_val, "timestamp": timestamp}) | ||||
|                 last_int_rps = current_int_rps | ||||
|  | ||||
|         prompt, prompt_len, output_len, mm_content = ( | ||||
|             request.prompt, | ||||
|             request.prompt_len, | ||||
| @ -397,11 +485,8 @@ async def benchmark( | ||||
|             ignore_eos=ignore_eos, | ||||
|             extra_body=extra_body, | ||||
|         ) | ||||
|         tasks.append( | ||||
|             asyncio.create_task( | ||||
|                 limited_request_func(request_func_input=request_func_input, pbar=pbar) | ||||
|             ) | ||||
|         ) | ||||
|         task = limited_request_func(request_func_input=request_func_input, pbar=pbar) | ||||
|         tasks.append(asyncio.create_task(task)) | ||||
|     outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) | ||||
|  | ||||
|     if profile: | ||||
| @ -466,7 +551,7 @@ async def benchmark( | ||||
|         "total_input_tokens": metrics.total_input, | ||||
|         "total_output_tokens": metrics.total_output, | ||||
|         "request_throughput": metrics.request_throughput, | ||||
|         "request_goodput:": metrics.request_goodput if goodput_config_dict else None, | ||||
|         "request_goodput": metrics.request_goodput if goodput_config_dict else None, | ||||
|         "output_throughput": metrics.output_throughput, | ||||
|         "total_token_throughput": metrics.total_token_throughput, | ||||
|         "input_lens": [output.prompt_len for output in outputs], | ||||
| @ -477,6 +562,9 @@ async def benchmark( | ||||
|         "errors": [output.error for output in outputs], | ||||
|     } | ||||
|  | ||||
|     if rps_change_events: | ||||
|         result["rps_change_events"] = rps_change_events | ||||
|  | ||||
|     def process_one_metric( | ||||
|         # E.g., "ttft" | ||||
|         metric_attribute_name: str, | ||||
| @ -610,6 +698,26 @@ def main(args: argparse.Namespace): | ||||
|     tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model | ||||
|     tokenizer_mode = args.tokenizer_mode | ||||
|  | ||||
|     # Validate ramp-up arguments | ||||
|     if args.ramp_up_strategy is not None: | ||||
|         if args.request_rate != float("inf"): | ||||
|             raise ValueError( | ||||
|                 "When using ramp-up, do not specify --request-rate. " | ||||
|                 "The request rate will be controlled by ramp-up parameters. " | ||||
|                 "Please remove the --request-rate argument." | ||||
|             ) | ||||
|         if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None: | ||||
|             raise ValueError( | ||||
|                 "When using --ramp-up-strategy, both --ramp-up-start-rps and " | ||||
|                 "--ramp-up-end-rps must be specified" | ||||
|             ) | ||||
|         if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0: | ||||
|             raise ValueError("Ramp-up start and end RPS must be non-negative") | ||||
|         if args.ramp_up_start_rps > args.ramp_up_end_rps: | ||||
|             raise ValueError("Ramp-up start RPS must be less than end RPS") | ||||
|         if args.ramp_up_strategy == "exponential" and args.ramp_up_start_rps == 0: | ||||
|             raise ValueError("For exponential ramp-up, the start RPS cannot be 0.") | ||||
|  | ||||
|     if args.base_url is not None: | ||||
|         api_url = f"{args.base_url}{args.endpoint}" | ||||
|         base_url = f"{args.base_url}" | ||||
| @ -802,6 +910,9 @@ def main(args: argparse.Namespace): | ||||
|             max_concurrency=args.max_concurrency, | ||||
|             lora_modules=args.lora_modules, | ||||
|             extra_body=sampling_params, | ||||
|             ramp_up_strategy=args.ramp_up_strategy, | ||||
|             ramp_up_start_rps=args.ramp_up_start_rps, | ||||
|             ramp_up_end_rps=args.ramp_up_end_rps, | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
| @ -834,6 +945,11 @@ def main(args: argparse.Namespace): | ||||
|         result_json["burstiness"] = args.burstiness | ||||
|         result_json["max_concurrency"] = args.max_concurrency | ||||
|  | ||||
|         if args.ramp_up_strategy is not None: | ||||
|             result_json["ramp_up_strategy"] = args.ramp_up_strategy | ||||
|             result_json["ramp_up_start_rps"] = args.ramp_up_start_rps | ||||
|             result_json["ramp_up_end_rps"] = args.ramp_up_end_rps | ||||
|  | ||||
|         # Merge with benchmark result | ||||
|         result_json = {**result_json, **benchmark_result} | ||||
|  | ||||
| @ -859,7 +975,10 @@ def main(args: argparse.Namespace): | ||||
|             if args.max_concurrency is not None | ||||
|             else "" | ||||
|         ) | ||||
|         file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"  # noqa | ||||
|         if args.ramp_up_strategy is not None: | ||||
|             file_name = f"{backend}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"  # noqa | ||||
|         else: | ||||
|             file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json"  # noqa | ||||
|         if args.result_filename: | ||||
|             file_name = args.result_filename | ||||
|         if args.result_dir: | ||||
| @ -875,7 +994,7 @@ def main(args: argparse.Namespace): | ||||
|         save_to_pytorch_benchmark_format(args, result_json, file_name) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
| def create_argument_parser(): | ||||
|     parser = FlexibleArgumentParser( | ||||
|         description="Benchmark the online serving throughput." | ||||
|     ) | ||||
| @ -1225,6 +1344,35 @@ if __name__ == "__main__": | ||||
|         "script chooses a LoRA module at random.", | ||||
|     ) | ||||
|  | ||||
|     args = parser.parse_args() | ||||
|     parser.add_argument( | ||||
|         "--ramp-up-strategy", | ||||
|         type=str, | ||||
|         default=None, | ||||
|         choices=["linear", "exponential"], | ||||
|         help="The ramp-up strategy. This would be used to " | ||||
|         "ramp up the request rate from initial RPS to final " | ||||
|         "RPS rate (specified by --ramp-up-start-rps and --ramp-up-end-rps). " | ||||
|         "over the duration of the benchmark.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--ramp-up-start-rps", | ||||
|         type=int, | ||||
|         default=None, | ||||
|         help="The starting request rate for ramp-up (RPS). " | ||||
|         "Needs to be specified when --ramp-up-strategy is used.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--ramp-up-end-rps", | ||||
|         type=int, | ||||
|         default=None, | ||||
|         help="The ending request rate for ramp-up (RPS). " | ||||
|         "Needs to be specified when --ramp-up-strategy is used.", | ||||
|     ) | ||||
|  | ||||
|     return parser | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = create_argument_parser() | ||||
|     args = parser.parse_args() | ||||
|     main(args) | ||||
|  | ||||
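As a usage note, the three flags added above are meant to be passed together and without --request-rate, e.g. --ramp-up-strategy linear --ramp-up-start-rps 1 --ramp-up-end-rps 20, which ramps the offered load from 1 to 20 RPS over the run; the validation in main() rejects any other combination.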
| @ -850,7 +850,7 @@ def main(args: argparse.Namespace): | ||||
|             json.dump(results, outfile, indent=4) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
| def create_argument_parser(): | ||||
|     parser = FlexibleArgumentParser( | ||||
|         description="Benchmark the online serving throughput." | ||||
|     ) | ||||
| @ -1034,5 +1034,10 @@ if __name__ == "__main__": | ||||
|         help="Ratio of Structured Outputs requests", | ||||
|     ) | ||||
|  | ||||
|     return parser | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = create_argument_parser() | ||||
|     args = parser.parse_args() | ||||
|     main(args) | ||||
|  | ||||
| @ -97,7 +97,7 @@ def run_vllm( | ||||
|         assert lora_requests is None, "BeamSearch API does not support LoRA" | ||||
|         prompts = [request.prompt for request in requests] | ||||
|         # output_len should be the same for all requests. | ||||
|         output_len = requests[0][2] | ||||
|         output_len = requests[0].expected_output_len | ||||
|         for request in requests: | ||||
|             assert request.expected_output_len == output_len | ||||
|         start = time.perf_counter() | ||||
| @ -595,7 +595,7 @@ def validate_args(args): | ||||
|         ) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
| def create_argument_parser(): | ||||
|     parser = FlexibleArgumentParser(description="Benchmark the throughput.") | ||||
|     parser.add_argument( | ||||
|         "--backend", | ||||
| @ -717,6 +717,12 @@ if __name__ == "__main__": | ||||
|     ) | ||||
|  | ||||
|     parser = AsyncEngineArgs.add_cli_args(parser) | ||||
|  | ||||
|     return parser | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = create_argument_parser() | ||||
|     args = parser.parse_args() | ||||
|     if args.tokenizer is None: | ||||
|         args.tokenizer = args.model | ||||
|  | ||||
| @ -19,7 +19,7 @@ from vllm import _custom_ops as ops | ||||
| from vllm.model_executor.layers.quantization.utils.fp8_utils import ( | ||||
|     w8a8_block_fp8_matmul, | ||||
| ) | ||||
| from vllm.utils import FlexibleArgumentParser | ||||
| from vllm.utils import FlexibleArgumentParser, cdiv | ||||
|  | ||||
| DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) | ||||
| DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] | ||||
| @ -117,14 +117,9 @@ def bench_fp8( | ||||
|     scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) | ||||
|     scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) | ||||
|  | ||||
|     def ceil_div(x: int, y: int) -> int: | ||||
|         return (x + y - 1) // y | ||||
|  | ||||
|     block_scale_a = torch.rand( | ||||
|         (m, ceil_div(k, 128)), device="cuda", dtype=torch.float32 | ||||
|     ) | ||||
|     block_scale_a = torch.rand((m, cdiv(k, 128)), device="cuda", dtype=torch.float32) | ||||
|     block_scale_b = torch.rand( | ||||
|         ceil_div(k, 128), ceil_div(n, 128), device="cuda", dtype=torch.float32 | ||||
|         cdiv(k, 128), cdiv(n, 128), device="cuda", dtype=torch.float32 | ||||
|     ) | ||||
|     block_scale_a_M_major = block_scale_a.t().contiguous().t() | ||||
|     block_scale_b_K_major = block_scale_b.t().contiguous().t() | ||||
|  | ||||
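cdiv here is ceiling division from vllm.utils, so the block-scale shapes above round up to cover a partial trailing 128-wide block; a tiny check of the behavior being relied on:

    from vllm.utils import cdiv

    assert cdiv(256, 128) == 2  # exact multiple of the block size
    assert cdiv(300, 128) == 3  # partial tail block still gets a scale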
| @ -1,5 +1,4 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
| import argparse | ||||
| import copy | ||||
| import itertools | ||||
| @ -11,6 +10,80 @@ from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm | ||||
| from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant | ||||
| from vllm.triton_utils import triton | ||||
|  | ||||
| PROVIDER_CFGS = { | ||||
|     "torch-bf16": dict(enabled=True), | ||||
|     "fp8-tensor-w-token-a": dict( | ||||
|         w="tensor", a="token", no_a_quant=False, enabled=False | ||||
|     ), | ||||
|     "fp8-tensor-w-tensor-a": dict( | ||||
|         w="tensor", a="tensor", no_a_quant=False, enabled=True | ||||
|     ), | ||||
|     "fp8-channel-w-token-a": dict( | ||||
|         w="channel", a="token", no_a_quant=False, enabled=True | ||||
|     ), | ||||
|     "fp8-channel-w-tensor-a": dict( | ||||
|         w="channel", a="tensor", no_a_quant=False, enabled=False | ||||
|     ), | ||||
|     "fp8-tensor-w-token-a-noquant": dict( | ||||
|         w="tensor", a="token", no_a_quant=True, enabled=False | ||||
|     ), | ||||
|     "fp8-tensor-w-tensor-a-noquant": dict( | ||||
|         w="tensor", a="tensor", no_a_quant=True, enabled=True | ||||
|     ), | ||||
|     "fp8-channel-w-token-a-noquant": dict( | ||||
|         w="channel", a="token", no_a_quant=True, enabled=True | ||||
|     ), | ||||
|     "fp8-channel-w-tensor-a-noquant": dict( | ||||
|         w="channel", a="tensor", no_a_quant=True, enabled=False | ||||
|     ), | ||||
| } | ||||
|  | ||||
| _enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]] | ||||
|  | ||||
|  | ||||
| def _quant_weight_fp8(b: torch.Tensor, w_type: str, device: str): | ||||
|     if w_type == "tensor": | ||||
|         scale_b = torch.ones(1, device=device, dtype=torch.float32) | ||||
|         b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) | ||||
|     else: | ||||
|         b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, use_per_token_if_dynamic=True) | ||||
|     return b_fp8.t(), scale_b_fp8 | ||||
|  | ||||
|  | ||||
| def build_fp8_runner(cfg, a, b, dtype, device): | ||||
|     b_fp8, scale_b_fp8 = _quant_weight_fp8(b, cfg["w"], device) | ||||
|  | ||||
|     scale_a_const = ( | ||||
|         torch.ones(1, device=device, dtype=torch.float32) | ||||
|         if cfg["a"] == "tensor" | ||||
|         else None | ||||
|     ) | ||||
|  | ||||
|     if cfg["no_a_quant"]: | ||||
|         if cfg["a"] == "tensor": | ||||
|             a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const) | ||||
|         else: | ||||
|             a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True) | ||||
|  | ||||
|         def run(): | ||||
|             return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) | ||||
|  | ||||
|         return run | ||||
|  | ||||
|     if cfg["a"] == "tensor": | ||||
|  | ||||
|         def run(): | ||||
|             a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const) | ||||
|             return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) | ||||
|  | ||||
|     else: | ||||
|  | ||||
|         def run(): | ||||
|             a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True) | ||||
|             return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) | ||||
|  | ||||
|     return run | ||||
|  | ||||
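A minimal sketch of exercising one of these runners directly, outside the triton benchmark harness (the shapes and provider key below are assumed; any config with enabled=True appears in _enabled and hence in the plot):

    import torch

    M, N, K = 16, 4096, 4096  # illustrative problem size
    a = torch.randn((M, K), device="cuda", dtype=torch.bfloat16)
    b = torch.randn((N, K), device="cuda", dtype=torch.bfloat16)

    cfg = PROVIDER_CFGS["fp8-channel-w-token-a"]
    run = build_fp8_runner(cfg, a, b, torch.bfloat16, "cuda")
    out = run()  # bf16 (M, N) result of the FP8 GEMM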
|  | ||||
| @triton.testing.perf_report( | ||||
|     triton.testing.Benchmark( | ||||
| @ -18,28 +91,8 @@ from vllm.triton_utils import triton | ||||
|         x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], | ||||
|         x_log=False, | ||||
|         line_arg="provider", | ||||
|         line_vals=[ | ||||
|             "torch-bf16", | ||||
|             # "fp8-tensor-w-token-a", | ||||
|             "fp8-tensor-w-tensor-a", | ||||
|             "fp8-channel-w-token-a", | ||||
|             # "fp8-channel-w-tensor-a", | ||||
|             # "fp8-tensor-w-token-a-noquant", | ||||
|             "fp8-tensor-w-tensor-a-noquant", | ||||
|             "fp8-channel-w-token-a-noquant", | ||||
|             # "fp8-channel-w-tensor-a-noquant", | ||||
|         ], | ||||
|         line_names=[ | ||||
|             "torch-bf16", | ||||
|             # "fp8-tensor-w-token-a", | ||||
|             "fp8-tensor-w-tensor-a", | ||||
|             "fp8-channel-w-token-a", | ||||
|             # "fp8-channel-w-tensor-a", | ||||
|             # "fp8-tensor-w-token-a-noquant", | ||||
|             "fp8-tensor-w-tensor-a-noquant", | ||||
|             "fp8-channel-w-token-a-noquant", | ||||
|             # "fp8-channel-w-tensor-a-noquant", | ||||
|         ], | ||||
|         line_vals=_enabled, | ||||
|         line_names=_enabled, | ||||
|         ylabel="TFLOP/s (larger is better)", | ||||
|         plot_name="BF16 vs FP8 GEMMs", | ||||
|         args={}, | ||||
| @ -50,144 +103,34 @@ def benchmark(batch_size, provider, N, K): | ||||
|     device = "cuda" | ||||
|     dtype = torch.bfloat16 | ||||
|  | ||||
|     # Create input tensors | ||||
|     a = torch.randn((M, K), device=device, dtype=dtype) | ||||
|     b = torch.randn((N, K), device=device, dtype=dtype) | ||||
|  | ||||
|     quantiles = [0.5, 0.2, 0.8] | ||||
|  | ||||
|     if "torch-bf16" in provider: | ||||
|     if provider == "torch-bf16": | ||||
|         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( | ||||
|             lambda: torch.nn.functional.linear(a, b), quantiles=quantiles | ||||
|         ) | ||||
|  | ||||
|     elif "fp8" in provider: | ||||
|         # Weights are always quantized ahead of time | ||||
|         if "noquant" in provider: | ||||
|             # For no quantization, we just measure the GEMM | ||||
|             if "tensor-w-token-a" in provider: | ||||
|                 # Dynamic per-token quant for A, per-tensor quant for B | ||||
|                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b) | ||||
|                 assert scale_b_fp8.numel() == 1 | ||||
|                 a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant( | ||||
|                     a, use_per_token_if_dynamic=True | ||||
|                 ) | ||||
|  | ||||
|                 def run_quant(): | ||||
|                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) | ||||
|  | ||||
|             elif "tensor-w-tensor-a" in provider: | ||||
|                 # Static per-tensor quantization with fixed scales | ||||
|                 # for both A and B | ||||
|                 scale_a = torch.tensor([1.0], device=device, dtype=torch.float32) | ||||
|                 scale_b = torch.tensor([1.0], device=device, dtype=torch.float32) | ||||
|                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) | ||||
|                 assert scale_b_fp8.numel() == 1 | ||||
|                 a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) | ||||
|  | ||||
|                 def run_quant(): | ||||
|                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) | ||||
|  | ||||
|             elif "channel-w-token-a" in provider: | ||||
|                 # Static per-channel quantization for weights, per-token | ||||
|                 # quant for A | ||||
|                 scale_b = torch.tensor((N,), device=device, dtype=torch.float32) | ||||
|                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) | ||||
|                 scale_b_fp8 = scale_b_fp8.expand(N).contiguous() | ||||
|                 assert scale_b_fp8.numel() == N | ||||
|                 a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant( | ||||
|                     a, use_per_token_if_dynamic=True | ||||
|                 ) | ||||
|  | ||||
|                 def run_quant(): | ||||
|                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) | ||||
|  | ||||
|             elif "channel-w-tensor-a" in provider: | ||||
|                 # Static per-channel quantization for weights, per-tensor | ||||
|                 # quant for A | ||||
|                 scale_a = torch.tensor([1.0], device=device, dtype=torch.float32) | ||||
|                 scale_b = torch.tensor((N,), device=device, dtype=torch.float32) | ||||
|                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) | ||||
|                 scale_b_fp8 = scale_b_fp8.expand(N).contiguous() | ||||
|                 assert scale_b_fp8.numel() == N | ||||
|                 a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) | ||||
|  | ||||
|                 def run_quant(): | ||||
|                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) | ||||
|  | ||||
|         else: | ||||
|             # In these cases, we quantize the activations during the GEMM call | ||||
|             if "tensor-w-token-a" in provider: | ||||
|                 # Dynamic per-token quant for A, per-tensor quant for B | ||||
|                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b) | ||||
|                 assert scale_b_fp8.numel() == 1 | ||||
|  | ||||
|                 def run_quant(): | ||||
|                     a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant( | ||||
|                         a, use_per_token_if_dynamic=True | ||||
|                     ) | ||||
|                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) | ||||
|  | ||||
|             elif "tensor-w-tensor-a" in provider: | ||||
|                 # Static per-tensor quantization with fixed scales | ||||
|                 # for both A and B | ||||
|                 scale_a = torch.tensor([1.0], device=device, dtype=torch.float32) | ||||
|                 scale_b = torch.tensor([1.0], device=device, dtype=torch.float32) | ||||
|                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) | ||||
|                 assert scale_b_fp8.numel() == 1 | ||||
|  | ||||
|                 def run_quant(): | ||||
|                     a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) | ||||
|                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) | ||||
|  | ||||
|             elif "channel-w-token-a" in provider: | ||||
|                 # Static per-channel quantization for weights, per-token | ||||
|                 # quant for A | ||||
|                 scale_b = torch.tensor((N,), device=device, dtype=torch.float32) | ||||
|                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) | ||||
|                 scale_b_fp8 = scale_b_fp8.expand(N).contiguous() | ||||
|                 assert scale_b_fp8.numel() == N | ||||
|  | ||||
|                 def run_quant(): | ||||
|                     a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant( | ||||
|                         a, use_per_token_if_dynamic=True | ||||
|                     ) | ||||
|                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) | ||||
|  | ||||
|             elif "channel-w-tensor-a" in provider: | ||||
|                 # Static per-channel quantization for weights, per-tensor | ||||
|                 # quant for A | ||||
|                 scale_a = torch.tensor([1.0], device=device, dtype=torch.float32) | ||||
|                 scale_b = torch.tensor((N,), device=device, dtype=torch.float32) | ||||
|                 b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) | ||||
|                 scale_b_fp8 = scale_b_fp8.expand(N).contiguous() | ||||
|                 assert scale_b_fp8.numel() == N | ||||
|  | ||||
|                 def run_quant(): | ||||
|                     a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) | ||||
|                     return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) | ||||
|  | ||||
|         b_fp8 = b_fp8.t() | ||||
|  | ||||
|     else: | ||||
|         cfg = PROVIDER_CFGS[provider] | ||||
|         run_quant = build_fp8_runner(cfg, a, b, dtype, device) | ||||
|         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( | ||||
|             lambda: run_quant(), quantiles=quantiles | ||||
|         ) | ||||
|  | ||||
|     # Calculate TFLOP/s, two flops per multiply-add | ||||
|     tflops = lambda ms: (2 * M * N * K) * 1e-12 / (ms * 1e-3) | ||||
|     return tflops(ms), tflops(max_ms), tflops(min_ms) | ||||
|     to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) | ||||
|     return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) | ||||
|  | ||||
|  | ||||
| def prepare_shapes(args): | ||||
|     KN_model_names = [] | ||||
|     models_tps = list(itertools.product(args.models, args.tp_sizes)) | ||||
|     for model, tp_size in models_tps: | ||||
|         assert model in WEIGHT_SHAPES | ||||
|         for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]): | ||||
|             KN[tp_split_dim] = KN[tp_split_dim] // tp_size | ||||
|     out = [] | ||||
|     for model, tp_size in itertools.product(args.models, args.tp_sizes): | ||||
|         for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]): | ||||
|             KN[tp_dim] //= tp_size | ||||
|             KN.append(model) | ||||
|             KN_model_names.append(KN) | ||||
|     return KN_model_names | ||||
|             out.append(KN) | ||||
|     return out | ||||
|  | ||||
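Each row returned here is a [K, N, model] triple, with the tensor-parallel split already applied along the recorded dimension; a worked example under a hypothetical WEIGHT_SHAPES entry:

    # Hypothetical entry: ([4096, 12288], 1) with --tp-sizes 2
    KN, tp_dim, tp_size = [4096, 12288], 1, 2
    KN[tp_dim] //= tp_size
    KN.append("example/model")
    print(KN)  # [4096, 6144, 'example/model'], later unpacked as K, N, model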
|  | ||||
| if __name__ == "__main__": | ||||
| @ -197,21 +140,13 @@ if __name__ == "__main__": | ||||
|         nargs="+", | ||||
|         type=str, | ||||
|         default=["meta-llama/Llama-3.1-8B-Instruct"], | ||||
|         choices=[*WEIGHT_SHAPES.keys()], | ||||
|         help="List of models to benchmark", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--tp-sizes", | ||||
|         nargs="+", | ||||
|         type=int, | ||||
|         default=[1], | ||||
|         help="List of tensor parallel sizes", | ||||
|         choices=list(WEIGHT_SHAPES.keys()), | ||||
|     ) | ||||
|     parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1]) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     KN_model_names = prepare_shapes(args) | ||||
|     for K, N, model_name in KN_model_names: | ||||
|         print(f"{model_name}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:") | ||||
|     for K, N, model in prepare_shapes(args): | ||||
|         print(f"{model}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:") | ||||
|         benchmark.run( | ||||
|             print_data=True, | ||||
|             show_plots=True, | ||||
|  | ||||
							
								
								
									
benchmarks/kernels/bench_int8_gemm.py (new file, 169 lines)
							| @ -0,0 +1,169 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
| import argparse | ||||
| import copy | ||||
| import itertools | ||||
|  | ||||
| import torch | ||||
| from weight_shapes import WEIGHT_SHAPES | ||||
|  | ||||
| from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm | ||||
| from vllm._custom_ops import scaled_int8_quant as vllm_scaled_int8_quant | ||||
| from vllm.triton_utils import triton | ||||
|  | ||||
| PROVIDER_CFGS = { | ||||
|     "torch-bf16": dict(enabled=True), | ||||
|     "int8-tensor-w-token-a": dict( | ||||
|         w="tensor", a="token", no_a_quant=False, enabled=False | ||||
|     ), | ||||
|     "int8-tensor-w-tensor-a": dict( | ||||
|         w="tensor", a="tensor", no_a_quant=False, enabled=True | ||||
|     ), | ||||
|     "int8-channel-w-token-a": dict( | ||||
|         w="channel", a="token", no_a_quant=False, enabled=True | ||||
|     ), | ||||
|     "int8-channel-w-tensor-a": dict( | ||||
|         w="channel", a="tensor", no_a_quant=False, enabled=False | ||||
|     ), | ||||
|     "int8-tensor-w-token-a-noquant": dict( | ||||
|         w="tensor", a="token", no_a_quant=True, enabled=False | ||||
|     ), | ||||
|     "int8-tensor-w-tensor-a-noquant": dict( | ||||
|         w="tensor", a="tensor", no_a_quant=True, enabled=True | ||||
|     ), | ||||
|     "int8-channel-w-token-a-noquant": dict( | ||||
|         w="channel", a="token", no_a_quant=True, enabled=True | ||||
|     ), | ||||
|     "int8-channel-w-tensor-a-noquant": dict( | ||||
|         w="channel", a="tensor", no_a_quant=True, enabled=False | ||||
|     ), | ||||
| } | ||||
|  | ||||
|  | ||||
| def _quant_weight(b, w_type, device): | ||||
|     if w_type == "tensor": | ||||
|         scale_b = torch.ones(1, device=device, dtype=torch.float32) | ||||
|         b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b, scale_b) | ||||
|         assert scale_b_int8.numel() == 1 | ||||
|     else:  # channel | ||||
|         b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b) | ||||
|         assert scale_b_int8.numel() == b.shape[0] | ||||
|     return b_int8.t(), scale_b_int8 | ||||
|  | ||||
|  | ||||
| def build_int8_runner(cfg, a, b, dtype, device): | ||||
|     # Quantize the weights once, ahead of the benchmarked kernel | ||||
|     b_int8, scale_b_int8 = _quant_weight(b, cfg["w"], device) | ||||
|  | ||||
|     scale_a_const = None | ||||
|     if cfg["a"] == "tensor": | ||||
|         scale_a_const = torch.ones(1, device=device, dtype=torch.float32) | ||||
|  | ||||
|     # No activation quant inside the timed region: quantize the activation ahead of time | ||||
|     if cfg["no_a_quant"]: | ||||
|         if cfg["a"] == "tensor": | ||||
|             a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a_const) | ||||
|         else:  # token | ||||
|             a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a) | ||||
|  | ||||
|         def run_quant(): | ||||
|             return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype) | ||||
|  | ||||
|         return run_quant | ||||
|  | ||||
|     # Dynamic quant: the activation is quantized inside the timed region | ||||
|     if cfg["a"] == "tensor": | ||||
|  | ||||
|         def run_quant(): | ||||
|             a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a_const) | ||||
|             return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype) | ||||
|  | ||||
|     else:  # token | ||||
|  | ||||
|         def run_quant(): | ||||
|             a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a) | ||||
|             return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype) | ||||
|  | ||||
|     return run_quant | ||||
|  | ||||
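As with the FP8 benchmark above, a runner can be built and called directly; the sketch assumes small illustrative shapes and one of the enabled provider keys:

    import torch

    M, N, K = 16, 4096, 4096
    a = torch.randn((M, K), device="cuda", dtype=torch.bfloat16)
    b = torch.randn((N, K), device="cuda", dtype=torch.bfloat16)

    run = build_int8_runner(
        PROVIDER_CFGS["int8-channel-w-token-a"], a, b, torch.bfloat16, "cuda"
    )
    out = run()  # bf16 (M, N) result of the INT8 GEMM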
|  | ||||
| _enabled = [k for k, v in PROVIDER_CFGS.items() if v.get("enabled")] | ||||
|  | ||||
|  | ||||
| @triton.testing.perf_report( | ||||
|     triton.testing.Benchmark( | ||||
|         x_names=["batch_size"], | ||||
|         x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], | ||||
|         x_log=False, | ||||
|         line_arg="provider", | ||||
|         line_vals=_enabled, | ||||
|         line_names=_enabled, | ||||
|         ylabel="TFLOP/s (larger is better)", | ||||
|         plot_name="BF16 vs INT8 GEMMs", | ||||
|         args={}, | ||||
|     ) | ||||
| ) | ||||
| def benchmark(batch_size, provider, N, K): | ||||
|     M = batch_size | ||||
|     device = "cuda" | ||||
|     dtype = torch.bfloat16 | ||||
|     a = torch.randn((M, K), device=device, dtype=dtype) | ||||
|     b = torch.randn((N, K), device=device, dtype=dtype) | ||||
|  | ||||
|     quantiles = [0.5, 0.2, 0.8] | ||||
|  | ||||
|     if provider == "torch-bf16": | ||||
|         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( | ||||
|             lambda: torch.nn.functional.linear(a, b), quantiles=quantiles | ||||
|         ) | ||||
|     else: | ||||
|         cfg = PROVIDER_CFGS[provider] | ||||
|         run_quant = build_int8_runner(cfg, a, b, dtype, device) | ||||
|         ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( | ||||
|             lambda: run_quant(), quantiles=quantiles | ||||
|         ) | ||||
|  | ||||
|     to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) | ||||
|     return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) | ||||
|  | ||||
|  | ||||
| def prepare_shapes(args): | ||||
|     KN_model_names = [] | ||||
|     for model, tp_size in itertools.product(args.models, args.tp_sizes): | ||||
|         for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]): | ||||
|             KN[tp_dim] //= tp_size | ||||
|             KN.append(model) | ||||
|             KN_model_names.append(KN) | ||||
|     return KN_model_names | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser() | ||||
|     parser.add_argument( | ||||
|         "--models", | ||||
|         nargs="+", | ||||
|         type=str, | ||||
|         default=["meta-llama/Llama-3.1-8B-Instruct"], | ||||
|         choices=list(WEIGHT_SHAPES.keys()), | ||||
|         help="List of models to benchmark", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--tp-sizes", | ||||
|         nargs="+", | ||||
|         type=int, | ||||
|         default=[1], | ||||
|         help="List of tensor parallel sizes", | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     for K, N, model in prepare_shapes(args): | ||||
|         print(f"{model}, N={N} K={K}, BF16 vs INT8 GEMMs TFLOP/s:") | ||||
|         benchmark.run( | ||||
|             print_data=True, | ||||
|             show_plots=True, | ||||
|             save_path=f"bench_int8_res_n{N}_k{K}", | ||||
|             N=N, | ||||
|             K=K, | ||||
|         ) | ||||
|  | ||||
|     print("Benchmark finished!") | ||||
| @ -113,6 +113,7 @@ def bench_run( | ||||
|         w2_scale: torch.Tensor, | ||||
|         topk_weights: torch.Tensor, | ||||
|         topk_ids: torch.Tensor, | ||||
|         per_act_token: bool, | ||||
|         num_repeats: int, | ||||
|     ): | ||||
|         for _ in range(num_repeats): | ||||
| @ -124,7 +125,8 @@ def bench_run( | ||||
|                 topk_ids, | ||||
|                 w1_scale, | ||||
|                 w2_scale, | ||||
|                 a1_scale=a_scale, | ||||
|                 per_act_token, | ||||
|                 a1_scale=None, | ||||
|             ) | ||||
|  | ||||
|     def run_cutlass_from_graph( | ||||
| @ -148,7 +150,8 @@ def bench_run( | ||||
|                 topk_ids, | ||||
|                 w1_scale, | ||||
|                 w2_scale, | ||||
|                 a1_scale=a_scale, | ||||
|                 per_act_token, | ||||
|                 a1_scale=None, | ||||
|             ) | ||||
|  | ||||
|     def run_triton_from_graph( | ||||
| @ -227,6 +230,7 @@ def bench_run( | ||||
|         "w2_q": w2_q, | ||||
|         "w1_scale": w1_scale, | ||||
|         "w2_scale": w2_scale, | ||||
|         "per_act_token": per_act_token, | ||||
|         # cuda graph params | ||||
|         "cutlass_graph": cutlass_graph, | ||||
|         "triton_graph": triton_graph, | ||||
| @ -287,12 +291,13 @@ def bench_run( | ||||
|         w2_scale, | ||||
|         topk_weights, | ||||
|         topk_ids, | ||||
|         per_act_token, | ||||
|         num_warmup, | ||||
|     ) | ||||
|  | ||||
|     results.append( | ||||
|         benchmark.Timer( | ||||
|             stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)",  # noqa: E501 | ||||
|             stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501 | ||||
|             globals=globals, | ||||
|             label=label, | ||||
|             sub_label=sub_label, | ||||
|  | ||||
| @ -234,8 +234,10 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: | ||||
|  | ||||
|         fn = lambda: ops.gptq_marlin_gemm( | ||||
|             a=bt.a, | ||||
|             c=None, | ||||
|             b_q_weight=w_q, | ||||
|             b_scales=w_s, | ||||
|             global_scale=None, | ||||
|             b_zeros=w_zp, | ||||
|             g_idx=g_idx, | ||||
|             perm=sort_indices, | ||||
|  | ||||
| @ -22,8 +22,16 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( | ||||
|     MARLIN_SUPPORTED_GROUP_SIZES, | ||||
|     query_marlin_supported_quant_types, | ||||
| ) | ||||
| from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( | ||||
|     FP4_MARLIN_SUPPORTED_GROUP_SIZES, | ||||
|     rand_marlin_weight_fp4_like, | ||||
| ) | ||||
| from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( | ||||
|     marlin_quant_fp8_torch, | ||||
| ) | ||||
| from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( | ||||
|     MarlinWorkspace, | ||||
|     awq_marlin_quantize, | ||||
|     marlin_quantize, | ||||
| ) | ||||
| from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( | ||||
| @ -35,7 +43,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( | ||||
|     quantize_weights, | ||||
|     sort_weights, | ||||
| ) | ||||
| from vllm.scalar_type import ScalarType | ||||
| from vllm.scalar_type import ScalarType, scalar_types | ||||
| from vllm.utils import FlexibleArgumentParser | ||||
|  | ||||
| DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] | ||||
| @ -57,80 +65,144 @@ def bench_run( | ||||
|     size_n: int, | ||||
| ): | ||||
|     label = "Quant Matmul" | ||||
|  | ||||
|     sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format( | ||||
|         model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n | ||||
|     ) | ||||
|  | ||||
|     print(f"Testing: {sub_label}") | ||||
|  | ||||
|     a = torch.randn(size_m, size_k).to(torch.half).cuda() | ||||
|     b = torch.rand(size_k, size_n).to(torch.half).cuda() | ||||
|     has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] | ||||
|     if act_order and (group_size == -1 or group_size == size_k or has_zp): | ||||
|         return | ||||
|     if size_k % group_size != 0: | ||||
|         return | ||||
|  | ||||
|     a_tmp = torch.zeros(size_m, size_k).to(torch.half).cuda() | ||||
|  | ||||
|     # Marlin quant | ||||
|     ( | ||||
|         marlin_w_ref, | ||||
|         marlin_q_w, | ||||
|         marlin_s, | ||||
|         marlin_g_idx, | ||||
|         marlin_sort_indices, | ||||
|         marlin_rand_perm, | ||||
|     ) = marlin_quantize(b, quant_type, group_size, act_order) | ||||
|  | ||||
|     # Marlin_24 quant | ||||
|     (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = ( | ||||
|         marlin_24_quantize(b, quant_type, group_size) | ||||
|     marlin_24_supported = ( | ||||
|         quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES | ||||
|         and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES | ||||
|     ) | ||||
|  | ||||
|     marlin_zp = torch.empty(0, dtype=torch.int, device=b.device) | ||||
|  | ||||
|     # GPTQ quant | ||||
|     (w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights( | ||||
|         b, quant_type, group_size, act_order | ||||
|     repack_supported = ( | ||||
|         quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES | ||||
|         and group_size in MARLIN_SUPPORTED_GROUP_SIZES | ||||
|     ) | ||||
|     q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) | ||||
|  | ||||
|     # For act_order, sort the "weights" and "g_idx" | ||||
|     # so that group ids are increasing | ||||
|     repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device) | ||||
|     if act_order: | ||||
|         (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx) | ||||
|  | ||||
|     # Prepare | ||||
|     marlin_workspace = MarlinWorkspace( | ||||
|         size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL | ||||
|     ) | ||||
|  | ||||
|     marlin_24_workspace = MarlinWorkspace( | ||||
|         size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL | ||||
|     ) | ||||
|     marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int) | ||||
|  | ||||
|     # AllSpark W8A16 quant | ||||
|     as_supported_case = ( | ||||
|     allspark_supported = ( | ||||
|         quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES | ||||
|         and group_size == -1 | ||||
|         and not act_order | ||||
|         and is_k_full | ||||
|     ) | ||||
|     if as_supported_case: | ||||
|         properties = torch.cuda.get_device_properties(b.device.index) | ||||
|         sm_count = properties.multi_processor_count | ||||
|         sm_version = properties.major * 10 + properties.minor | ||||
|  | ||||
|         supported_arch = sm_version >= 80 and sm_version < 90 | ||||
|         as_supported_case = as_supported_case and supported_arch | ||||
|         if supported_arch: | ||||
|             has_zp = False | ||||
|             w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp) | ||||
|             qw = qw.to(torch.uint8) | ||||
|  | ||||
|             qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight( | ||||
|                 qw, s, zp, has_zp | ||||
|     def gen_marlin_params(): | ||||
|         # Marlin quant | ||||
|         marlin_g_idx = marlin_sort_indices = marlin_zp = marlin_s2 = None | ||||
|         if quant_type == scalar_types.float4_e2m1f: | ||||
|             if group_size != 16 or act_order: | ||||
|                 return | ||||
|             marlin_w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like( | ||||
|                 b.T, group_size | ||||
|             ) | ||||
|             CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD | ||||
|         elif quant_type == scalar_types.float8_e4m3fn: | ||||
|             if group_size not in [-1, 128] or act_order: | ||||
|                 return | ||||
|             marlin_w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b.T, group_size) | ||||
|         elif group_size == 16: | ||||
|             return | ||||
|         elif has_zp: | ||||
|             marlin_w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( | ||||
|                 b, quant_type, group_size | ||||
|             ) | ||||
|         else: | ||||
|             marlin_w_ref, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, _ = ( | ||||
|                 marlin_quantize(b, quant_type, group_size, act_order) | ||||
|             ) | ||||
|         return ( | ||||
|             marlin_w_ref, | ||||
|             marlin_q_w, | ||||
|             marlin_s, | ||||
|             marlin_s2, | ||||
|             marlin_zp, | ||||
|             marlin_g_idx, | ||||
|             marlin_sort_indices, | ||||
|         ) | ||||
|  | ||||
|     def gen_marlin_24_params(): | ||||
|         marlin_24_w_ref = marlin_24_q_w_comp = marlin_24_meta = marlin_24_s = None | ||||
|         if marlin_24_supported: | ||||
|             (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = ( | ||||
|                 marlin_24_quantize(b, quant_type, group_size) | ||||
|             ) | ||||
|         return (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) | ||||
|  | ||||
|     def gen_repack_params(): | ||||
|         q_w_gptq = None | ||||
|         repack_sort_indices = None | ||||
|         if repack_supported: | ||||
|             (w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights( | ||||
|                 b, quant_type, group_size, act_order | ||||
|             ) | ||||
|             q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) | ||||
|  | ||||
|             # For act_order, sort the "weights" and "g_idx" | ||||
|             # so that group ids are increasing | ||||
|             repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device) | ||||
|             if act_order: | ||||
|                 (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx) | ||||
|         return q_w_gptq, repack_sort_indices | ||||
|  | ||||
|     def gen_allspark_params(): | ||||
|         qw_reorder = s_reorder = zp_reorder = sm_count = sm_version = ( | ||||
|             CUBLAS_M_THRESHOLD | ||||
|         ) = None | ||||
|         nonlocal allspark_supported | ||||
|         if allspark_supported: | ||||
|             properties = torch.cuda.get_device_properties(b.device.index) | ||||
|             sm_count = properties.multi_processor_count | ||||
|             sm_version = properties.major * 10 + properties.minor | ||||
|  | ||||
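|             # AllSpark W8A16 kernels only target SM 8.x GPUs | ||||
|             # (cf. ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD), so gate the benchmark on that | ||||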
|             supported_arch = sm_version >= 80 and sm_version < 90 | ||||
|             allspark_supported = allspark_supported and supported_arch | ||||
|             if supported_arch: | ||||
|                 w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp) | ||||
|                 qw = qw.to(torch.uint8) | ||||
|  | ||||
|                 qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight( | ||||
|                     qw, s, zp, has_zp | ||||
|                 ) | ||||
|                 CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD | ||||
|         return ( | ||||
|             qw_reorder, | ||||
|             s_reorder, | ||||
|             zp_reorder, | ||||
|             sm_count, | ||||
|             sm_version, | ||||
|             CUBLAS_M_THRESHOLD, | ||||
|         ) | ||||
|  | ||||
|     ( | ||||
|         marlin_w_ref, | ||||
|         marlin_q_w, | ||||
|         marlin_s, | ||||
|         marlin_s2, | ||||
|         marlin_zp, | ||||
|         marlin_g_idx, | ||||
|         marlin_sort_indices, | ||||
|     ) = gen_marlin_params() | ||||
|     marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s = ( | ||||
|         gen_marlin_24_params() | ||||
|     ) | ||||
|     q_w_gptq, repack_sort_indices = gen_repack_params() | ||||
|     qw_reorder, s_reorder, zp_reorder, sm_count, sm_version, CUBLAS_M_THRESHOLD = ( | ||||
|         gen_allspark_params() | ||||
|     ) | ||||
|  | ||||
|     # Prepare | ||||
|     marlin_workspace = MarlinWorkspace( | ||||
|         size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL | ||||
|     ) | ||||
|     marlin_24_workspace = MarlinWorkspace( | ||||
|         size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL | ||||
|     ) | ||||
|  | ||||
|     globals = { | ||||
|         # Gen params | ||||
| @ -140,15 +212,14 @@ def bench_run( | ||||
|         "size_n": size_n, | ||||
|         "size_k": size_k, | ||||
|         "a": a, | ||||
|         "a_tmp": a_tmp, | ||||
|         # Marlin params | ||||
|         "marlin_w_ref": marlin_w_ref, | ||||
|         "marlin_q_w": marlin_q_w, | ||||
|         "marlin_s": marlin_s, | ||||
|         "marlin_s2": marlin_s2, | ||||
|         "marlin_zp": marlin_zp, | ||||
|         "marlin_g_idx": marlin_g_idx, | ||||
|         "marlin_sort_indices": marlin_sort_indices, | ||||
|         "marlin_rand_perm": marlin_rand_perm, | ||||
|         "marlin_workspace": marlin_workspace, | ||||
|         "is_k_full": is_k_full, | ||||
|         # Marlin_24 params | ||||
| @ -161,12 +232,12 @@ def bench_run( | ||||
|         "q_w_gptq": q_w_gptq, | ||||
|         "repack_sort_indices": repack_sort_indices, | ||||
|         # AllSpark W8A16 params | ||||
|         "qw_reorder": qw_reorder if as_supported_case else None, | ||||
|         "s_reorder": s_reorder if as_supported_case else None, | ||||
|         "zp_reorder": zp_reorder if as_supported_case else None, | ||||
|         "sm_count": sm_count if as_supported_case else None, | ||||
|         "sm_version": sm_version if as_supported_case else None, | ||||
|         "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD if as_supported_case else None, | ||||
|         "qw_reorder": qw_reorder, | ||||
|         "s_reorder": s_reorder, | ||||
|         "zp_reorder": zp_reorder, | ||||
|         "sm_count": sm_count, | ||||
|         "sm_version": sm_version, | ||||
|         "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD, | ||||
|         # Kernels | ||||
|         "gptq_marlin_gemm": ops.gptq_marlin_gemm, | ||||
|         "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm, | ||||
| @ -177,7 +248,7 @@ def bench_run( | ||||
|     min_run_time = 1 | ||||
|  | ||||
|     # Warmup pytorch | ||||
|     for i in range(5): | ||||
|     for _ in range(5): | ||||
|         torch.matmul(a, marlin_w_ref) | ||||
|  | ||||
|     results.append( | ||||
| @ -192,17 +263,17 @@ def bench_run( | ||||
|  | ||||
|     results.append( | ||||
|         benchmark.Timer( | ||||
|             stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)",  # noqa: E501 | ||||
|             stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)",  # noqa: E501 | ||||
|             globals=globals, | ||||
|             label=label, | ||||
|             sub_label=sub_label, | ||||
|             description="gptq_marlin_gemm_fp16", | ||||
|             description="gptq_marlin_gemm", | ||||
|         ).blocked_autorange(min_run_time=min_run_time) | ||||
|     ) | ||||
|  | ||||
|     results.append( | ||||
|         benchmark.Timer( | ||||
|             stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)",  # noqa: E501 | ||||
|             stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)",  # noqa: E501 | ||||
|             globals=globals, | ||||
|             label=label, | ||||
|             sub_label=sub_label, | ||||
| @ -210,10 +281,7 @@ def bench_run( | ||||
|         ).blocked_autorange(min_run_time=min_run_time) | ||||
|     ) | ||||
|  | ||||
|     if ( | ||||
|         quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES | ||||
|         and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES | ||||
|     ): | ||||
|     if marlin_24_supported: | ||||
|         results.append( | ||||
|             benchmark.Timer( | ||||
|                 stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)",  # noqa: E501 | ||||
| @ -224,17 +292,18 @@ def bench_run( | ||||
|             ).blocked_autorange(min_run_time=min_run_time) | ||||
|         ) | ||||
|  | ||||
|     results.append( | ||||
|         benchmark.Timer( | ||||
|             stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)",  # noqa: E501 | ||||
|             globals=globals, | ||||
|             label=label, | ||||
|             sub_label=sub_label, | ||||
|             description="gptq_marlin_repack", | ||||
|         ).blocked_autorange(min_run_time=min_run_time) | ||||
|     ) | ||||
|     if repack_supported: | ||||
|         results.append( | ||||
|             benchmark.Timer( | ||||
|                 stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)",  # noqa: E501 | ||||
|                 globals=globals, | ||||
|                 label=label, | ||||
|                 sub_label=sub_label, | ||||
|                 description="gptq_marlin_repack", | ||||
|             ).blocked_autorange(min_run_time=min_run_time) | ||||
|         ) | ||||
|  | ||||
|     if allspark_supported: | ||||
|         results.append( | ||||
|             benchmark.Timer( | ||||
|                 stmt="output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)",  # noqa: E501 | ||||
| @ -250,7 +319,6 @@ def main(args): | ||||
|     print("Benchmarking models:") | ||||
|     for i, model in enumerate(args.models): | ||||
|         print(f"[{i}]  {model}") | ||||
|  | ||||
|     results: list[benchmark.Measurement] = [] | ||||
|  | ||||
|     for model in args.models: | ||||
| @ -278,14 +346,17 @@ def main(args): | ||||
|                     ): | ||||
|                         continue | ||||
|  | ||||
|                     for quant_type in query_marlin_supported_quant_types(False): | ||||
|                     for quant_type in query_marlin_supported_quant_types(): | ||||
|                         if ( | ||||
|                             len(args.limit_num_bits) > 0 | ||||
|                             and quant_type.size_bits not in args.limit_num_bits | ||||
|                         ): | ||||
|                             continue | ||||
|  | ||||
|                         for group_size in MARLIN_SUPPORTED_GROUP_SIZES: | ||||
|                         for group_size in ( | ||||
|                             MARLIN_SUPPORTED_GROUP_SIZES | ||||
|                             + FP4_MARLIN_SUPPORTED_GROUP_SIZES | ||||
|                         ): | ||||
|                             if ( | ||||
|                                 len(args.limit_group_size) > 0 | ||||
|                                 and group_size not in args.limit_group_size | ||||
|  | ||||
benchmarks/kernels/benchmark_moe_align_block_size.py (new file, 159 lines)
| @ -0,0 +1,159 @@ | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
| import argparse | ||||
| import itertools | ||||
|  | ||||
| import torch | ||||
|  | ||||
| from vllm import _custom_ops as ops | ||||
| from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( | ||||
|     moe_align_block_size_triton, | ||||
| ) | ||||
| from vllm.triton_utils import triton | ||||
|  | ||||
|  | ||||
| def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor: | ||||
|     return torch.stack( | ||||
|         [ | ||||
|             torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk] | ||||
|             for _ in range(num_tokens) | ||||
|         ] | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8): | ||||
|     """ | ||||
|     Verify that the vLLM moe_align_block_size op matches the Triton implementation. | ||||
|     """ | ||||
|     topk_ids = get_topk_ids(num_tokens, num_experts, topk) | ||||
|  | ||||
|     # 1. malloc space for triton and vllm | ||||
|     # malloc enough space (max_num_tokens_padded) for the sorted ids | ||||
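|     # worst case: each expert's slice of sorted ids is padded up to a multiple of | ||||
|     # block_size, e.g. num_tokens=4, topk=8, num_experts=256, block_size=256 | ||||
|     #   -> 4 * 8 + 256 * 255 = 65312 slots | ||||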
|     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) | ||||
|     sorted_ids_triton = torch.empty( | ||||
|         (max_num_tokens_padded,), dtype=torch.int32, device="cuda" | ||||
|     ) | ||||
|     sorted_ids_triton.fill_(topk_ids.numel())  # fill with sentinel value | ||||
|     expert_ids_triton = torch.zeros( | ||||
|         (max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda" | ||||
|     ) | ||||
|     num_tokens_post_pad_triton = torch.empty((1,), dtype=torch.int32, device="cuda") | ||||
|  | ||||
|     sorted_ids_vllm = torch.empty_like(sorted_ids_triton) | ||||
|     sorted_ids_vllm.fill_(topk_ids.numel()) | ||||
|     expert_ids_vllm = torch.zeros_like(expert_ids_triton) | ||||
|     num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_triton) | ||||
|  | ||||
|     # 2. run implementations | ||||
|     moe_align_block_size_triton( | ||||
|         topk_ids, | ||||
|         num_experts, | ||||
|         block_size, | ||||
|         sorted_ids_triton, | ||||
|         expert_ids_triton, | ||||
|         num_tokens_post_pad_triton, | ||||
|     ) | ||||
|  | ||||
|     ops.moe_align_block_size( | ||||
|         topk_ids, | ||||
|         num_experts, | ||||
|         block_size, | ||||
|         sorted_ids_vllm, | ||||
|         expert_ids_vllm, | ||||
|         num_tokens_post_pad_vllm, | ||||
|     ) | ||||
|     print(f"✅ VLLM implementation works with {num_experts} experts!") | ||||
|  | ||||
|     # 3. compare results | ||||
|     if torch.allclose(expert_ids_triton, expert_ids_vllm) and torch.allclose( | ||||
|         num_tokens_post_pad_triton, num_tokens_post_pad_vllm | ||||
|     ): | ||||
|         print("✅ Triton and VLLM implementations match.") | ||||
|     else: | ||||
|         print("❌ Triton and VLLM implementations DO NOT match.") | ||||
|         print("Triton expert_ids:", expert_ids_triton) | ||||
|         print("VLLM expert_ids:", expert_ids_vllm) | ||||
|         print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton) | ||||
|         print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm) | ||||
|  | ||||
|  | ||||
| # test configurations | ||||
| num_tokens_range = [1, 16, 256, 4096] | ||||
| num_experts_range = [16, 64, 224, 256, 280, 512] | ||||
| topk_range = [1, 2, 8] | ||||
| configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range)) | ||||
|  | ||||
|  | ||||
| @triton.testing.perf_report( | ||||
|     triton.testing.Benchmark( | ||||
|         x_names=["num_tokens", "num_experts", "topk"], | ||||
|         x_vals=configs, | ||||
|         line_arg="provider", | ||||
|         line_vals=["vllm", "triton"],  # "triton" | ||||
|         line_names=["VLLM", "Triton"],  # "Triton" | ||||
|         plot_name="moe-align-block-size-performance", | ||||
|         args={}, | ||||
|     ) | ||||
| ) | ||||
| def benchmark(num_tokens, num_experts, topk, provider): | ||||
|     """Benchmark function for Triton.""" | ||||
|     block_size = 256 | ||||
|     topk_ids = get_topk_ids(num_tokens, num_experts, topk) | ||||
|  | ||||
|     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) | ||||
|     sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda") | ||||
|     sorted_ids.fill_(topk_ids.numel()) | ||||
|     max_num_m_blocks = max_num_tokens_padded // block_size | ||||
|     expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda") | ||||
|     num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda") | ||||
|  | ||||
|     quantiles = [0.5, 0.2, 0.8] | ||||
|  | ||||
|     if provider == "vllm": | ||||
|         ms, min_ms, max_ms = triton.testing.do_bench( | ||||
|             lambda: ops.moe_align_block_size( | ||||
|                 topk_ids, | ||||
|                 num_experts, | ||||
|                 block_size, | ||||
|                 sorted_ids.clone(), | ||||
|                 expert_ids.clone(), | ||||
|                 num_tokens_post_pad.clone(), | ||||
|             ), | ||||
|             quantiles=quantiles, | ||||
|         ) | ||||
|     elif provider == "triton": | ||||
|         ms, min_ms, max_ms = triton.testing.do_bench( | ||||
|             lambda: moe_align_block_size_triton( | ||||
|                 topk_ids, | ||||
|                 num_experts, | ||||
|                 block_size, | ||||
|                 sorted_ids.clone(), | ||||
|                 expert_ids.clone(), | ||||
|                 num_tokens_post_pad.clone(), | ||||
|             ), | ||||
|             quantiles=quantiles, | ||||
|         ) | ||||
|  | ||||
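|     # do_bench reports milliseconds; scale to microseconds for the perf report | ||||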
|     return 1000 * ms, 1000 * max_ms, 1000 * min_ms | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser() | ||||
|     parser.add_argument( | ||||
|         "--num_experts", | ||||
|         type=int, | ||||
|         default=64, | ||||
|         choices=[8, 16, 32, 64, 128, 256], | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--topk", | ||||
|         type=int, | ||||
|         default=8, | ||||
|         choices=[2, 4, 8], | ||||
|         help="Top-k value for correctness check.", | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     print("Running correctness check...") | ||||
|     check_correctness(num_tokens=1024, num_experts=args.num_experts, topk=args.topk) | ||||
|     benchmark.run(print_data=True, show_plots=True) | ||||
| @ -85,12 +85,6 @@ def benchmark_shape(m: int, | ||||
|  | ||||
|     # === DeepGEMM Implementation === | ||||
|     def deepgemm_gemm(): | ||||
|         # A quantization is inside the loop as it depends on activations | ||||
|         # A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A) | ||||
|         # A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8( | ||||
|         #     A, block_size[1]) | ||||
|         # A_scale_aligned = get_col_major_tma_aligned_tensor(A_scale_deepgemm) | ||||
|         # C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) | ||||
|         deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm), | ||||
|                                        (B_deepgemm, B_scale_deepgemm), | ||||
|                                        C_deepgemm) | ||||
| @ -98,8 +92,6 @@ def benchmark_shape(m: int, | ||||
|  | ||||
|     # === vLLM Triton Implementation === | ||||
|     def vllm_triton_gemm(): | ||||
|         # A quantization is inside the loop as it depends on activations | ||||
|         # A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) | ||||
|         return w8a8_block_fp8_matmul(A_vllm, | ||||
|                                      B_vllm, | ||||
|                                      A_scale_vllm, | ||||
| @ -109,9 +101,6 @@ def benchmark_shape(m: int, | ||||
|  | ||||
|     # === vLLM CUTLASS Implementation === | ||||
|     def vllm_cutlass_gemm(): | ||||
|         # A quantization is inside the loop as it depends on activations | ||||
|         # A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( | ||||
|         #     A, block_size[1], column_major_scales=True) | ||||
|         return ops.cutlass_scaled_mm(A_vllm_cutlass, | ||||
|                                      B_vllm.T, | ||||
|                                      scale_a=A_scale_vllm_cutlass, | ||||
|  | ||||
| @ -12,9 +12,8 @@ endif() | ||||
| # | ||||
| # Define environment variables for special configurations | ||||
| # | ||||
| if(DEFINED ENV{VLLM_CPU_AVX512BF16}) | ||||
|     set(ENABLE_AVX512BF16 ON) | ||||
| endif() | ||||
| set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16}) | ||||
| set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI}) | ||||
|  | ||||
| include_directories("${CMAKE_SOURCE_DIR}/csrc") | ||||
|  | ||||
| @ -96,12 +95,30 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) | ||||
|         if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND | ||||
|             CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) | ||||
|             list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16") | ||||
|             set(ENABLE_AVX512BF16 ON) | ||||
|         else() | ||||
|             set(ENABLE_AVX512BF16 OFF) | ||||
|             message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3") | ||||
|         endif() | ||||
|     else() | ||||
|         set(ENABLE_AVX512BF16 OFF) | ||||
|         message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.") | ||||
|     endif() | ||||
|  | ||||
|     find_isa(${CPUINFO} "avx512_vnni" AVX512VNNI_FOUND) | ||||
|     if (AVX512VNNI_FOUND OR ENABLE_AVX512VNNI) | ||||
|         if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND | ||||
|             CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) | ||||
|             list(APPEND CXX_COMPILE_FLAGS "-mavx512vnni") | ||||
|             set(ENABLE_AVX512VNNI ON) | ||||
|         else() | ||||
|             set(ENABLE_AVX512VNNI OFF) | ||||
|             message(WARNING "Disable AVX512-VNNI ISA support, requires gcc/g++ >= 12.3") | ||||
|         endif() | ||||
|     else() | ||||
|         set(ENABLE_AVX512VNNI OFF) | ||||
|         message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.") | ||||
|     endif() | ||||
|      | ||||
| elseif (AVX2_FOUND) | ||||
|     list(APPEND CXX_COMPILE_FLAGS "-mavx2") | ||||
| @ -231,12 +248,25 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) | ||||
|         "csrc/cpu/quant.cpp" | ||||
|         "csrc/cpu/shm.cpp" | ||||
|         ${VLLM_EXT_SRC}) | ||||
|     if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI) | ||||
|         set(VLLM_EXT_SRC | ||||
|             "csrc/cpu/sgl-kernels/gemm.cpp" | ||||
|             "csrc/cpu/sgl-kernels/gemm_int8.cpp" | ||||
|             "csrc/cpu/sgl-kernels/gemm_fp8.cpp" | ||||
|             "csrc/cpu/sgl-kernels/moe.cpp" | ||||
|             "csrc/cpu/sgl-kernels/moe_int8.cpp" | ||||
|             "csrc/cpu/sgl-kernels/moe_fp8.cpp" | ||||
|             ${VLLM_EXT_SRC}) | ||||
|         add_compile_definitions(-DCPU_CAPABILITY_AVX512) | ||||
|     endif() | ||||
| elseif(POWER10_FOUND) | ||||
|     set(VLLM_EXT_SRC | ||||
|         "csrc/cpu/quant.cpp" | ||||
|         ${VLLM_EXT_SRC}) | ||||
| endif() | ||||
|  | ||||
| message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}") | ||||
|  | ||||
| # | ||||
| # Define extension targets | ||||
| # | ||||
|  | ||||
| @ -38,7 +38,7 @@ else() | ||||
|   FetchContent_Declare( | ||||
|           vllm-flash-attn | ||||
|           GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git | ||||
|           GIT_TAG 8798f27777fb57f447070301bf33a9f9c607f491 | ||||
|           GIT_TAG 1c2624e53c078854e0637ee566c72fe2107e75f4 | ||||
|           GIT_PROGRESS TRUE | ||||
|           # Don't share the vllm-flash-attn build between build types | ||||
|           BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn | ||||
|  | ||||
| @ -122,6 +122,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) | ||||
|       "-DENABLE_FP8" | ||||
|       "-U__HIP_NO_HALF_CONVERSIONS__" | ||||
|       "-U__HIP_NO_HALF_OPERATORS__" | ||||
|       "-Werror=unused-variable" | ||||
|       "-fno-gpu-rdc") | ||||
|  | ||||
|   endif() | ||||
| @ -264,8 +265,8 @@ macro(set_gencode_flags_for_srcs) | ||||
| endmacro() | ||||
|  | ||||
| # | ||||
| # For the given `SRC_CUDA_ARCHS` list of gencode versions in the form  | ||||
| #  `<major>.<minor>[letter]` compute the "loose intersection" with the  | ||||
| # For the given `SRC_CUDA_ARCHS` list of gencode versions in the form | ||||
| #  `<major>.<minor>[letter]` compute the "loose intersection" with the | ||||
| #  `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in | ||||
| #  `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there | ||||
| #  is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the | ||||
| @ -277,7 +278,7 @@ endmacro() | ||||
| #  in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`. | ||||
| # We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is | ||||
| #  in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add | ||||
| #  x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).  | ||||
| #  x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS). | ||||
| # The result is stored in `OUT_CUDA_ARCHS`. | ||||
| # | ||||
| # Example: | ||||
| @ -312,21 +313,16 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR | ||||
|   # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should | ||||
|   # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS | ||||
|   set(_CUDA_ARCHS) | ||||
|   if ("9.0a" IN_LIST _SRC_CUDA_ARCHS) | ||||
|     list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a") | ||||
|     if ("9.0" IN_LIST TGT_CUDA_ARCHS) | ||||
|       list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0") | ||||
|       set(_CUDA_ARCHS "9.0a") | ||||
|   foreach(_arch ${_SRC_CUDA_ARCHS}) | ||||
|     if(_arch MATCHES "\\a$") | ||||
|       list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}") | ||||
|       string(REPLACE "a" "" _base "${_arch}") | ||||
|       if ("${_base}" IN_LIST TGT_CUDA_ARCHS) | ||||
|         list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}") | ||||
|         list(APPEND _CUDA_ARCHS "${_arch}") | ||||
|       endif() | ||||
|     endif() | ||||
|   endif() | ||||
|  | ||||
|   if ("10.0a" IN_LIST _SRC_CUDA_ARCHS) | ||||
|     list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a") | ||||
|     if ("10.0" IN_LIST TGT_CUDA_ARCHS) | ||||
|       list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0") | ||||
|       set(_CUDA_ARCHS "10.0a") | ||||
|     endif() | ||||
|   endif() | ||||
|   endforeach() | ||||
|  | ||||
|   list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) | ||||
|  | ||||
| @ -358,7 +354,7 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR | ||||
|   endforeach() | ||||
|  | ||||
|   list(REMOVE_DUPLICATES _CUDA_ARCHS) | ||||
|    | ||||
|  | ||||
|   # reapply +PTX suffix to architectures that requested PTX | ||||
|   set(_FINAL_ARCHS) | ||||
|   foreach(_arch ${_CUDA_ARCHS}) | ||||
| @ -369,7 +365,7 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR | ||||
|     endif() | ||||
|   endforeach() | ||||
|   set(_CUDA_ARCHS ${_FINAL_ARCHS}) | ||||
|    | ||||
|  | ||||
|   set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE) | ||||
| endfunction() | ||||
|  | ||||
|  | ||||
| @ -207,7 +207,7 @@ void cutlass_mla_decode_sm100a(torch::Tensor const& out, | ||||
|               "page_table must be a 32-bit integer tensor"); | ||||
|  | ||||
|   auto in_dtype = q_nope.dtype(); | ||||
|   at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()}; | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(q_nope)); | ||||
|   const cudaStream_t stream = | ||||
|       at::cuda::getCurrentCUDAStream(q_nope.get_device()); | ||||
|   if (in_dtype == at::ScalarType::Half) { | ||||
|  | ||||
| @ -65,9 +65,6 @@ void paged_attention_v1_launcher( | ||||
|   int kv_block_stride = key_cache.stride(0); | ||||
|   int kv_head_stride = key_cache.stride(1); | ||||
|  | ||||
|   [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); | ||||
|   assert(head_size % thread_group_size == 0); | ||||
|  | ||||
|   // NOTE: alibi_slopes is optional. | ||||
|   const float* alibi_slopes_ptr = | ||||
|       alibi_slopes | ||||
| @ -193,4 +190,4 @@ void paged_attention_v1( | ||||
| #undef WARP_SIZE | ||||
| #undef MAX | ||||
| #undef MIN | ||||
| #undef DIVIDE_ROUND_UP | ||||
|  | ||||
| @ -66,9 +66,6 @@ void paged_attention_v2_launcher( | ||||
|   int kv_block_stride = key_cache.stride(0); | ||||
|   int kv_head_stride = key_cache.stride(1); | ||||
|  | ||||
|   [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); | ||||
|   assert(head_size % thread_group_size == 0); | ||||
|  | ||||
|   // NOTE: alibi_slopes is optional. | ||||
|   const float* alibi_slopes_ptr = | ||||
|       alibi_slopes | ||||
| @ -203,4 +200,4 @@ void paged_attention_v2( | ||||
| #undef WARP_SIZE | ||||
| #undef MAX | ||||
| #undef MIN | ||||
| #undef DIVIDE_ROUND_UP | ||||
|  | ||||
| @ -137,8 +137,8 @@ FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size, | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| FORCE_INLINE void reducePartitonSoftmax(const T* max_data, T* sum_data, | ||||
|                                         const int size) { | ||||
| FORCE_INLINE void reducePartitionSoftmax(const T* max_data, T* sum_data, | ||||
|                                          const int size) { | ||||
|   T max = max_data[0]; | ||||
|   for (int i = 1; i < size; ++i) { | ||||
|     max = max >= max_data[i] ? max : max_data[i]; | ||||
| @ -634,7 +634,7 @@ struct paged_attention_v2_impl { | ||||
|  | ||||
|         if (partition_num == 1) continue; | ||||
|  | ||||
|         reducePartitonSoftmax( | ||||
|         reducePartitionSoftmax( | ||||
|             max_logits + seq_idx * num_heads * max_num_partitions + | ||||
|                 head_idx * max_num_partitions, | ||||
|             exp_sums + seq_idx * num_heads * max_num_partitions + | ||||
|  | ||||
| @ -83,7 +83,7 @@ struct FP16Vec16 : public Vec<FP16Vec16> { | ||||
|   explicit FP16Vec16(const void* ptr) | ||||
|       : reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {} | ||||
|  | ||||
|   // non-temproal load | ||||
|   // non-temporal load | ||||
|   explicit FP16Vec16(bool, void* ptr) | ||||
|       : reg(_mm256_stream_load_si256((__m256i*)ptr)) {} | ||||
|  | ||||
| @ -120,7 +120,7 @@ struct BF16Vec16 : public Vec<BF16Vec16> { | ||||
|   explicit BF16Vec16(const void* ptr) | ||||
|       : reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {} | ||||
|  | ||||
|   // non-temproal load | ||||
|   // non-temporal load | ||||
|   explicit BF16Vec16(bool, void* ptr) | ||||
|       : reg(_mm256_stream_load_si256((__m256i*)ptr)) {} | ||||
|  | ||||
| @ -327,7 +327,7 @@ struct FP32Vec16 : public Vec<FP32Vec16> { | ||||
|   // normal load | ||||
|   explicit FP32Vec16(const float* ptr) : reg(_mm512_loadu_ps(ptr)) {} | ||||
|  | ||||
|   // non-temproal load | ||||
|   // non-temporal load | ||||
|   explicit FP32Vec16(bool, void* ptr) | ||||
|       : reg((__m512)_mm512_stream_load_si512(ptr)) {} | ||||
|  | ||||
| @ -576,7 +576,7 @@ struct INT8Vec64 : public Vec<INT8Vec64> { | ||||
|   // normal load | ||||
|   explicit INT8Vec64(void* ptr) : reg(_mm512_loadu_epi8(ptr)) {} | ||||
|  | ||||
|   // non-temproal load | ||||
|   // non-temporal load | ||||
|   explicit INT8Vec64(bool, void* ptr) : reg(_mm512_stream_load_si512(ptr)) {} | ||||
|  | ||||
|   void save(void* ptr) const { _mm512_storeu_epi8(ptr, reg); } | ||||
| @ -587,7 +587,7 @@ struct INT8Vec64 : public Vec<INT8Vec64> { | ||||
|     _mm512_mask_storeu_epi8(ptr, mask, reg); | ||||
|   } | ||||
|  | ||||
|   // non-temproal save | ||||
|   // non-temporal save | ||||
|   void nt_save(int8_t* ptr) { _mm512_stream_si512((__m512i*)ptr, reg); } | ||||
| }; | ||||
| #endif | ||||
|  | ||||
csrc/cpu/sgl-kernels/common.h (new file, 238 lines)
| @ -0,0 +1,238 @@ | ||||
| // Adapted from | ||||
| // https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu | ||||
|  | ||||
| #pragma once | ||||
|  | ||||
| #include <ATen/ATen.h> | ||||
| #include <ATen/Parallel.h> | ||||
| #include <ATen/record_function.h> | ||||
|  | ||||
| // clang-format off | ||||
|  | ||||
| #if defined(_OPENMP) | ||||
| #include <omp.h> | ||||
| #endif | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| // dispatch bool | ||||
| #define AT_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...)                                 \ | ||||
|   [&] {                                                                          \ | ||||
|     if (BOOL_V) {                                                                \ | ||||
|       constexpr bool BOOL_NAME = true;                                           \ | ||||
|       return __VA_ARGS__();                                                      \ | ||||
|     } else {                                                                     \ | ||||
|       constexpr bool BOOL_NAME = false;                                          \ | ||||
|       return __VA_ARGS__();                                                      \ | ||||
|     }                                                                            \ | ||||
|   }() | ||||
|  | ||||
| // dispatch: bfloat16, float16, int8_t, fp8_e4m3 | ||||
| #define CPU_DISPATCH_PACKED_TYPES(TYPE, ...)                                    \ | ||||
|   [&] {                                                                         \ | ||||
|     switch (TYPE) {                                                             \ | ||||
|       case at::ScalarType::BFloat16 : {                                         \ | ||||
|         using packed_t = at::BFloat16;                                          \ | ||||
|         return __VA_ARGS__();                                                   \ | ||||
|       }                                                                         \ | ||||
|       case at::ScalarType::Half: {                                              \ | ||||
|         using packed_t = at::Half;                                              \ | ||||
|         return __VA_ARGS__();                                                   \ | ||||
|       }                                                                         \ | ||||
|       case at::ScalarType::Char : {                                             \ | ||||
|         using packed_t = int8_t;                                                \ | ||||
|         return __VA_ARGS__();                                                   \ | ||||
|       }                                                                         \ | ||||
|       case at::ScalarType::Float8_e4m3fn : {                                    \ | ||||
|         using packed_t = at::Float8_e4m3fn;                                     \ | ||||
|         return __VA_ARGS__();                                                   \ | ||||
|       }                                                                         \ | ||||
|       default:                                                                  \ | ||||
|         TORCH_CHECK(false, "Unsupported floating data type.\n");                \ | ||||
|     }                                                                           \ | ||||
|   }() | ||||
|  | ||||
| #define UNUSED(x) (void)(x) | ||||
|  | ||||
| #define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor") | ||||
|  | ||||
| #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") | ||||
| #define CHECK_LAST_DIM_CONTIGUOUS(x) \ | ||||
|   TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x " must be contiguous at the last dimension") | ||||
|  | ||||
| #define CHECK_INPUT(x) \ | ||||
|   CHECK_CPU(x);        \ | ||||
|   CHECK_CONTIGUOUS(x) | ||||
| #define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \ | ||||
|   CHECK_CPU(x);                            \ | ||||
|   CHECK_LAST_DIM_CONTIGUOUS(x) | ||||
|  | ||||
| #define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") | ||||
|  | ||||
| #define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) | ||||
|  | ||||
| // parallel routines | ||||
| constexpr int GRAIN_SIZE = 1024; | ||||
|  | ||||
| template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0> | ||||
| inline T div_up(T x, T y) { return (x + y - 1) / y; } | ||||
|  | ||||
| template <typename T> | ||||
| inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) { | ||||
| #if 0 | ||||
|     // onednn partition pattern | ||||
|     T& n_my = n_end; | ||||
|     if (nth <= 1 || n == 0) { | ||||
|         n_start = 0; | ||||
|         n_my = n; | ||||
|     } else { | ||||
|         T n1 = div_up(n, nth); | ||||
|         T n2 = n1 - 1; | ||||
|         T T1 = n - n2 * nth; | ||||
|         n_my = ith < T1 ? n1 : n2; | ||||
|         n_start = ith <= T1 ? ith*n1 : T1 * n1 + (ith - T1) * n2; | ||||
|     } | ||||
|     n_end += n_start; | ||||
| #else | ||||
|     // pytorch aten partition pattern | ||||
|     T n_my = div_up(n, nth); | ||||
|     n_start = ith * n_my; | ||||
|     n_end = std::min(n_start + n_my, n); | ||||
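|     // e.g. n = 10, nth = 4: n_my = 3 -> ranges [0,3) [3,6) [6,9) [9,10) | ||||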
| #endif | ||||
| } | ||||
|  | ||||
| template <typename func_t> | ||||
| inline void parallel_for(int n, const func_t& f) { | ||||
| #if defined(_OPENMP) | ||||
| #pragma omp parallel | ||||
| { | ||||
|     int nth = omp_get_num_threads(); | ||||
|     int ith = omp_get_thread_num(); | ||||
|     int tbegin, tend; | ||||
|     balance211(n, nth, ith, tbegin, tend); | ||||
|     f(tbegin, tend); | ||||
| } | ||||
| #else | ||||
|     f(0, n); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| // for 1d parallel, use `actual_nth` | ||||
| // for 2d parallel, use even nths, e.g. 43->42 | ||||
| int inline adjust_num_threads(int m) { | ||||
|   int actual_nth = at::get_num_threads(); | ||||
|   if (m == 1) { | ||||
|     return actual_nth; | ||||
|   } | ||||
|   return std::max(1, (actual_nth >> 1) * 2); | ||||
| } | ||||
|  | ||||
| template <typename func_t> | ||||
| inline void parallel_2d(int m, int n, const func_t& f) { | ||||
|  | ||||
|   // make sure we have even num_threads | ||||
|   int nth = adjust_num_threads(m); | ||||
|  | ||||
|   // [NOTE] thread blocking: | ||||
|   // | ||||
|   //   1) prefer square block per thread | ||||
|   //   2) use even number of CPU cores | ||||
|   //   3) use all `num_threads` cores | ||||
|   // | ||||
|   //   we have: | ||||
|   //     TM * TN = T | ||||
|   //     BM / TM = BN / TN | ||||
|   //   then: | ||||
|   //     TM = ((BM / BN) * T) ^ 0.5 | ||||
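|   //   e.g. m = 64, n = 256, nth = 8: | ||||
|   //     r = 0.25, nth_m = ceil(sqrt(0.25 * 8)) = 2, nth_n = 8 / 2 = 4 -> 2 x 4 grid | ||||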
|   // | ||||
|   float r = float(m) / n; | ||||
|   int nth_m = std::ceil(std::sqrt(r * nth)); | ||||
|   int nth_n = 1; | ||||
|   for (; nth_m > 0; --nth_m) { | ||||
|     nth_n = nth / nth_m; | ||||
|     if (nth_m * nth_n == nth) { | ||||
|       break; | ||||
|     } | ||||
|   } | ||||
|  | ||||
| #if defined(_OPENMP) | ||||
| #pragma omp parallel num_threads(nth) | ||||
| { | ||||
|   int ith = omp_get_thread_num(); | ||||
|   int ith_m = ith / nth_n; | ||||
|   int ith_n = ith % nth_n; | ||||
|  | ||||
|   int thread_block_m = div_up(m, nth_m); | ||||
|   int thread_block_n = div_up(n, nth_n); | ||||
|  | ||||
|   int begin_m = ith_m * thread_block_m; | ||||
|   int end_m = std::min(m, begin_m + thread_block_m); | ||||
|   int begin_n = ith_n * thread_block_n; | ||||
|   int end_n = std::min(n, begin_n + thread_block_n); | ||||
|  | ||||
|   f(begin_m, end_m, begin_n, end_n); | ||||
| } | ||||
| #else | ||||
|   f(0, m, 0, n); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| int get_cache_blocks(int BLOCK_SIZE, int K) { | ||||
|   // L2 2MB and ratio of 50% | ||||
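|   // e.g. BLOCK_SIZE = 32, K = 4096, bf16 (2 bytes): 1 MB / 256 KB -> 4 cache blocks | ||||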
|   const int L2_size = 2048 * 1024 >> 1; | ||||
|   return std::max(1, int(L2_size / (BLOCK_SIZE * K * sizeof(T)))); | ||||
| } | ||||
|  | ||||
| // data indexing for dimension collapse | ||||
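| // e.g. for a flattened loop over (e, nb): data_index_init(begin, e, E, nb, NB) splits | ||||
| // the flat offset into indices and data_index_step(e, E, nb, NB) advances them, with | ||||
| // the last pair (nb) varying fastest | ||||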
| template <typename T> | ||||
| inline T data_index_init(T offset) { | ||||
|   return offset; | ||||
| } | ||||
|  | ||||
| template <typename T, typename... Args> | ||||
| inline T data_index_init(T offset, T& x, const T& X, Args&&... args) { | ||||
|   offset = data_index_init(offset, std::forward<Args>(args)...); | ||||
|   x = offset % X; | ||||
|   return offset / X; | ||||
| } | ||||
|  | ||||
| inline bool data_index_step() { | ||||
|   return true; | ||||
| } | ||||
|  | ||||
| template <typename T, typename... Args> | ||||
| inline bool data_index_step(T& x, const T& X, Args&&... args) { | ||||
|   if (data_index_step(std::forward<Args>(args)...)) { | ||||
|     x = ((x + 1) == X) ? 0 : (x + 1); | ||||
|     return x == 0; | ||||
|   } | ||||
|   return false; | ||||
| } | ||||
|  | ||||
| // forced unroll for perf critical path | ||||
|  | ||||
| #if __has_attribute(always_inline) | ||||
| #define ALWAYS_INLINE __attribute__((__always_inline__)) inline | ||||
| #else | ||||
| #define ALWAYS_INLINE inline | ||||
| #endif | ||||
|  | ||||
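| // Unroll<N>{}(f, args...) expands to f(0, args...), ..., f(N-1, args...) with | ||||
| // compile-time (integral_constant) indices | ||||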
| template <int n> | ||||
| struct Unroll { | ||||
|   template <typename Func, typename... Args> | ||||
|   ALWAYS_INLINE void operator()(const Func& f, Args... args) const { | ||||
|     Unroll<n - 1>{}(f, args...); | ||||
|     f(std::integral_constant<int, n - 1>{}, args...); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| struct Unroll<1> { | ||||
|   template <typename Func, typename... Args> | ||||
|   ALWAYS_INLINE void operator()(const Func& f, Args... args) const { | ||||
|     f(std::integral_constant<int, 0>{}, args...); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| } // anonymous namespace | ||||
csrc/cpu/sgl-kernels/gemm.cpp (new file, 464 lines)
| @ -0,0 +1,464 @@ | ||||
| // Adapted from | ||||
| // https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu | ||||
|  | ||||
| #include "common.h" | ||||
| #include "vec.h" | ||||
| #include "gemm.h" | ||||
|  | ||||
| // clang-format off | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| // packed   layout: | ||||
| //   quants {N, K}  int8_t | ||||
| //   comp   {N}     int32_t | ||||
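| // comp[n] = 128 * sum_k quants[n][k] (dpbusd with a constant 0x80 multiplicand); | ||||
| // presumably subtracted later to undo the +128 activation bias of the u8*s8 path | ||||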
| template <int BLOCK_N> | ||||
| inline void s8s8_compensation(int8_t* __restrict__ packed, int K) { | ||||
| #if defined(CPU_CAPABILITY_AVX512) | ||||
|   constexpr int COLS = BLOCK_N / 16; | ||||
|   __m512i vcomp[COLS]; | ||||
|  | ||||
|   for (int col = 0; col < COLS; ++col) { | ||||
|     vcomp[col] = _mm512_setzero_si512(); | ||||
|   } | ||||
|  | ||||
|   const int64_t offset = BLOCK_N * K; | ||||
|   const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80)); | ||||
|   for (int k = 0; k < K / 4; ++k) { | ||||
|     for (int col = 0; col < COLS; ++col) { | ||||
|       __m512i vb = _mm512_loadu_si512((const __m512i *)(packed + k * BLOCK_N * 4 + col * 64)); | ||||
|       vcomp[col] = _mm512_dpbusd_epi32(vcomp[col], off, vb); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   for (int col = 0; col < COLS; ++col) { | ||||
|     _mm512_storeu_si512((__m512i *)(packed + offset + col * 64), vcomp[col]); | ||||
|   } | ||||
| #else | ||||
|   TORCH_CHECK(false, "s8s8_compensation not implemented!"); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| // convert to vnni format | ||||
| // from [N, K] to [K/2, N, 2] for bfloat16 and float16 | ||||
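| // i.e. weight[n][k] -> packed[(k / 2) * N * 2 + n * 2 + (k % 2)] | ||||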
| template <typename packed_t> | ||||
| inline void pack_vnni(packed_t* __restrict__ packed, const packed_t* __restrict__ weight, int N, int K) { | ||||
|   const int VNNI_BLK = 2; | ||||
|   for (int n = 0; n < N; ++n) { | ||||
|     for (int k = 0; k < K / VNNI_BLK; ++k) { | ||||
|       for (int d = 0; d < VNNI_BLK; ++d) { | ||||
|         packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d]; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <> | ||||
| inline void pack_vnni<int8_t>(int8_t* __restrict__ packed, const int8_t* __restrict__ weight, int N, int K) { | ||||
|   constexpr int BLOCK_N = block_size_n(); | ||||
|   TORCH_CHECK(N == BLOCK_N); | ||||
|  | ||||
|   const int VNNI_BLK = 4; | ||||
|   for (int n = 0; n < N; ++n) { | ||||
|     for (int k = 0; k < K / VNNI_BLK; ++k) { | ||||
|       for (int d = 0; d < VNNI_BLK; ++d) { | ||||
|         packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d]; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|   s8s8_compensation<BLOCK_N>(packed, K); | ||||
| } | ||||
|  | ||||
| template <typename scalar_t> | ||||
| inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) { | ||||
|   using bVec = at::vec::Vectorized<scalar_t>; | ||||
|   using fVec = at::vec::Vectorized<float>; | ||||
|   constexpr int kVecSize = bVec::size(); | ||||
|  | ||||
|   int64_t d; | ||||
|   #pragma GCC unroll 4 | ||||
|   for (d = 0; d <= size - kVecSize; d += kVecSize) { | ||||
|     fVec data0 = fVec::loadu(input + d); | ||||
|     fVec data1 = fVec::loadu(input + d + fVec::size()); | ||||
|     bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1); | ||||
|     out_vec.store(out + d); | ||||
|   } | ||||
|   for (; d < size; ++d) { | ||||
|     out[d] = static_cast<scalar_t>(input[d]); | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename scalar_t> | ||||
| inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) { | ||||
|   using bVec = at::vec::Vectorized<scalar_t>; | ||||
|   using fVec = at::vec::Vectorized<float>; | ||||
|   constexpr int kVecSize = bVec::size(); | ||||
|  | ||||
|   int64_t d; | ||||
|   #pragma GCC unroll 4 | ||||
|   for (d = 0; d <= size - kVecSize; d += kVecSize) { | ||||
|     fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d); | ||||
|     fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size()); | ||||
|     bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1); | ||||
|     out_vec.store(out + d); | ||||
|   } | ||||
|   for (; d < size; ++d) { | ||||
|     out[d] = static_cast<scalar_t>(input[d] + bias[d]); | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N> | ||||
| struct tinygemm_kernel_nn { | ||||
|   static inline void apply( | ||||
|       const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C, | ||||
|       const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { | ||||
|     TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| #if defined(CPU_CAPABILITY_AVX512) | ||||
| template <bool has_bias, int BLOCK_M, int BLOCK_N> | ||||
| struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> { | ||||
|   static inline void apply( | ||||
|       const at::BFloat16* __restrict__ A, const at::BFloat16* __restrict__ B, at::BFloat16* __restrict__ C, | ||||
|       const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { | ||||
|  | ||||
|     constexpr int ROWS = BLOCK_M; | ||||
|     constexpr int COLS = BLOCK_N / 16; | ||||
|  | ||||
|     // prefetch distance | ||||
|     constexpr int PREFETCH_SIZE_K = 0; | ||||
|  | ||||
|     __m512bh va; | ||||
|     __m512bh vb[COLS]; | ||||
|     __m512 vc[ROWS * COLS]; | ||||
|  | ||||
|     auto loadc = [&](auto i) { | ||||
|       constexpr int col = i % COLS; | ||||
|       if constexpr (has_bias) { | ||||
|         vc[i] = _mm512_loadu_ps(bias + col * 16); | ||||
|       } else { | ||||
|         vc[i] = _mm512_set1_ps(0.f); | ||||
|       } | ||||
|     }; | ||||
|     Unroll<ROWS * COLS>{}(loadc); | ||||
|  | ||||
|     const int64_t K2 = K >> 1; | ||||
|     const int64_t lda2 = lda >> 1; | ||||
|     const int64_t ldb2 = ldb; // ldb * 2 >> 1; | ||||
|     const float* a_ptr = reinterpret_cast<const float*>(A); | ||||
|     const float* b_ptr = reinterpret_cast<const float*>(B); | ||||
|  | ||||
|     auto compute = [&](auto i, int64_t k) { | ||||
|       constexpr int row = i / COLS; | ||||
|       constexpr int col = i % COLS; | ||||
|  | ||||
|       if constexpr (col == 0) { | ||||
|         va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); | ||||
|       } | ||||
|       if constexpr (row == 0) { | ||||
|         vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16)); | ||||
|         if constexpr (PREFETCH_SIZE_K > 0) { | ||||
|           _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); | ||||
|         } | ||||
|       } | ||||
|       vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); | ||||
|     }; | ||||
|     for (int64_t k = 0; k < K2; ++k) { | ||||
|       Unroll<ROWS * COLS>{}(compute, k); | ||||
|     } | ||||
|  | ||||
|     auto storec = [&](auto i) { | ||||
|       constexpr int row = i / COLS; | ||||
|       constexpr int col = i % COLS; | ||||
|       // for COLS = 2, 4 use 512bit store | ||||
|       // for COLS = 1, 3 use 256bit store | ||||
|       if constexpr (COLS % 2 == 0) { | ||||
|         if constexpr (col % 2 == 0) { | ||||
|           _mm512_storeu_si512( | ||||
|               reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), | ||||
|               (__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col]))); | ||||
|         } | ||||
|       } else { | ||||
|         _mm256_storeu_si256( | ||||
|             reinterpret_cast<__m256i*>(C + row * ldc + col * 16), | ||||
|             (__m256i)(_mm512_cvtneps_pbh(vc[i]))); | ||||
|       } | ||||
|     }; | ||||
|     Unroll<ROWS * COLS>{}(storec); | ||||
|   } | ||||
| }; | ||||
| #endif | ||||
|  | ||||
| #define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE)                          \ | ||||
|     tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply(         \ | ||||
|         A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \ | ||||
|         has_bias ? bias + nb_start : nullptr, K, lda, ldb, ldc); | ||||
|  | ||||
| template <typename scalar_t, bool has_bias> | ||||
| struct brgemm { | ||||
|   static inline void apply( | ||||
|       const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C, | ||||
|       float* __restrict__ Ctmp, const float* __restrict__ bias, | ||||
|       int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { | ||||
|  | ||||
|     constexpr int BLOCK_N = block_size_n(); | ||||
|     at::native::cpublas::brgemm( | ||||
|         M, N, K, lda, ldb, BLOCK_N, /* add_C */false, | ||||
|         A, B, Ctmp); | ||||
|  | ||||
|     // copy from Ctmp to C | ||||
|     for (int64_t m = 0; m < M; ++m) { | ||||
|       if constexpr (has_bias) { | ||||
|         copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N); | ||||
|       } else { | ||||
|         copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <typename scalar_t, bool has_bias> | ||||
| void tinygemm_kernel( | ||||
|     const scalar_t* __restrict__ A, | ||||
|     const scalar_t* __restrict__ B, | ||||
|     scalar_t* __restrict__ C, | ||||
|     float* __restrict__ Ctmp, | ||||
|     const float* __restrict__ bias, | ||||
|     int64_t M, | ||||
|     int64_t N, | ||||
|     int64_t K, | ||||
|     int64_t lda, | ||||
|     int64_t ldb, | ||||
|     int64_t ldc, | ||||
|     bool brg) { | ||||
|  | ||||
|   if (brg) { | ||||
|     brgemm<scalar_t, has_bias>::apply( | ||||
|         A, B, C, Ctmp, bias, | ||||
|         M, N, K, lda, ldb, ldc); | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   // pattern: 1-4-16 | ||||
|   constexpr int64_t BLOCK_M = 4; | ||||
|   constexpr int64_t BLOCK_N = 64; | ||||
|   const int64_t MB = div_up(M, BLOCK_M); | ||||
|   const int64_t NB = div_up(N, BLOCK_N); | ||||
|   for (int mb = 0; mb < MB; ++mb) { | ||||
|     int64_t mb_start = mb * BLOCK_M; | ||||
|     int64_t mb_size = std::min(BLOCK_M, M - mb_start); | ||||
|     for (int64_t nb = 0; nb < NB; ++nb) { | ||||
|       int64_t nb_start = nb * BLOCK_N; | ||||
|       int64_t nb_size = std::min(BLOCK_N, N - nb_start); | ||||
|  | ||||
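|       // dispatch key: high nibble = mb_size (1..4), low nibble = nb_size / 16 (2 or 4) | ||||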
|       switch(mb_size << 4 | nb_size >> 4) { | ||||
|         // mb_size = 1 | ||||
|         case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break; | ||||
|         case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break; | ||||
|         // mb_size = 2 | ||||
|         case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break; | ||||
|         case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break; | ||||
|         // mb_size = 3 | ||||
|         case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break; | ||||
|         case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break; | ||||
|         // mb_size = 4 | ||||
|         case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break; | ||||
|         case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break; | ||||
|         default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename scalar_t> | ||||
| void weight_packed_linear_kernel_impl( | ||||
|     scalar_t* __restrict__ out, | ||||
|     const scalar_t* __restrict__ mat1, | ||||
|     const scalar_t* __restrict__ mat2, | ||||
|     const float* __restrict__ bias, | ||||
|     int64_t M, | ||||
|     int64_t N, | ||||
|     int64_t K, | ||||
|     int64_t mat1_strideM, | ||||
|     int64_t out_strideM) { | ||||
|  | ||||
|   constexpr int64_t BLOCK_M = block_size_m(); | ||||
|   constexpr int64_t BLOCK_N = block_size_n(); | ||||
|   const int64_t MB = div_up(M, BLOCK_M); | ||||
|   const int64_t NB = div_up(N, BLOCK_N); | ||||
|  | ||||
|   // use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx | ||||
|   const bool use_brgemm = (M > 4) || (!std::is_same_v<scalar_t, at::BFloat16>); | ||||
|  | ||||
|   // l2 cache block for n | ||||
|   int64_t cache_blocks_nb = get_cache_blocks<scalar_t>(BLOCK_N, K); | ||||
|  | ||||
|   // parallel on [MB, NB] | ||||
|   AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { | ||||
|     parallel_2d(MB, NB, [&](int64_t begin_mb, int64_t end_mb, int64_t begin_nb, int64_t end_nb) { | ||||
|  | ||||
|       // for brgemm, use float32 for accumulate | ||||
|       alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; | ||||
|  | ||||
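|       // loop order: N cache blocks outermost so each packed B panel is reused across | ||||
|       // all M blocks while it is (likely) still resident in L2 (see get_cache_blocks) | ||||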
|       for (int64_t nbb = begin_nb; nbb < end_nb; nbb += cache_blocks_nb) { | ||||
|       for (int64_t mb = begin_mb; mb < end_mb; ++mb) { | ||||
|       for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, end_nb); ++nb) { | ||||
|  | ||||
|         int64_t mb_start = mb * BLOCK_M; | ||||
|         int64_t mb_size = std::min(M - mb_start, BLOCK_M); | ||||
|         int64_t nb_start = nb * BLOCK_N; | ||||
|         int64_t nb_size = std::min(N - nb_start, BLOCK_N); | ||||
|  | ||||
|         tinygemm_kernel<scalar_t, has_bias>( | ||||
|             /*   A */ mat1 + mb_start * mat1_strideM, | ||||
|             /*   B */ mat2 + nb_start * K /* nb * BLOCK_N * K */, | ||||
|             /*   C */ out + mb_start * out_strideM + nb_start, | ||||
|             /* Ctmp*/ Ctmp, | ||||
|             /* bias*/ bias + nb_start, | ||||
|             /*   M */ mb_size, | ||||
|             /*   N */ nb_size, | ||||
|             /*   K */ K, | ||||
|             /* lda */ mat1_strideM, | ||||
|             /* ldb */ nb_size, | ||||
|             /* ldc */ out_strideM, | ||||
|             /* brg */ use_brgemm); | ||||
|       }}} | ||||
|  | ||||
|       if (use_brgemm) { | ||||
|         at::native::cpublas::brgemm_release(); | ||||
|       } | ||||
|     }); | ||||
|   }); | ||||
| } | ||||
|  | ||||
| } // anonymous namespace | ||||
|  | ||||
| // tinygemm interface | ||||
| template <typename scalar_t> | ||||
| void tinygemm_kernel(const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C, | ||||
|     float* __restrict__ Ctmp, int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) { | ||||
|   tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, nullptr, M, N, K, lda, ldb, ldc, brg); | ||||
| } | ||||
|  | ||||
| #define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE)                                             \ | ||||
|     template void tinygemm_kernel<TYPE>(                                                \ | ||||
|         const TYPE* __restrict__ A, const TYPE* __restrict__ B, TYPE* __restrict__ C,   \ | ||||
|         float* __restrict__ Ctmp, int64_t M, int64_t N, int64_t K, int64_t lda,         \ | ||||
|         int64_t ldb, int64_t ldc, bool brg) | ||||
|  | ||||
| INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); | ||||
| INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); | ||||
|  | ||||
| at::Tensor convert_weight_packed(at::Tensor& weight) { | ||||
|   // for 3d moe weights | ||||
|   // weight : [E, OC, IC] | ||||
|   //     w1 : [E, 2N,  K] | ||||
|   //     w2 : [E,  K,  N] | ||||
|   CHECK_INPUT(weight); | ||||
|  | ||||
|   const int64_t ndim = weight.ndimension(); | ||||
|   TORCH_CHECK(ndim == 2 || ndim == 3, "expect weight to be 2d or 3d, got ", ndim, "d tensor."); | ||||
|   const auto st = weight.scalar_type(); | ||||
|   const int64_t E = ndim == 3 ? weight.size(0) : 1; | ||||
|   const int64_t OC = ndim == 3 ? weight.size(1) : weight.size(0); | ||||
|   const int64_t IC = ndim == 3 ? weight.size(2) : weight.size(1); | ||||
|  | ||||
|   // we handle 2 TILE_N at a time. | ||||
|   TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC); | ||||
|   TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC); | ||||
|  | ||||
|   constexpr int64_t BLOCK_N = block_size_n(); | ||||
|   const int64_t NB = div_up(OC, BLOCK_N); | ||||
|  | ||||
|   // allocate with placeholder sizes for now; logically [E, OC, IC], and each [OC, IC] slice is repacked as [IC / 2, OC, 2] | ||||
|   auto packed_weight = at::empty({}, weight.options()); | ||||
|   const int64_t stride = OC * IC; | ||||
|  | ||||
|   TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn, | ||||
|       "expect weight to be bfloat16, float16, int8 or fp8_e4m3."); | ||||
|  | ||||
|   CPU_DISPATCH_PACKED_TYPES(st, [&] { | ||||
|     // adjust the innermost dimension size | ||||
|     const int packed_row_size = get_row_size<packed_t>(IC); | ||||
|     auto sizes = weight.sizes().vec(); | ||||
|     sizes[ndim - 1] = packed_row_size; | ||||
|     packed_weight.resize_(sizes); | ||||
|  | ||||
|     const packed_t* w_data = weight.data_ptr<packed_t>(); | ||||
|     packed_t* packed_data = packed_weight.data_ptr<packed_t>(); | ||||
|  | ||||
|     // parallel on {E, NB} | ||||
|     at::parallel_for(0, E * NB, 0, [&](int64_t begin, int64_t end) { | ||||
|       int64_t e{0}, nb{0}; | ||||
|       data_index_init(begin, e, E, nb, NB); | ||||
|  | ||||
|       for (int64_t i = begin; i < end; ++i) { | ||||
|         UNUSED(i); | ||||
|  | ||||
|         int64_t n = nb * BLOCK_N; | ||||
|         int64_t n_size = std::min(BLOCK_N, OC - n); | ||||
|         pack_vnni<packed_t>( | ||||
|             packed_data + e * OC * packed_row_size + n * packed_row_size, | ||||
|             w_data + e * stride + n * IC, | ||||
|             n_size, | ||||
|             IC); | ||||
|  | ||||
|         // move to the next index | ||||
|         data_index_step(e, E, nb, NB); | ||||
|       } | ||||
|     }); | ||||
|   }); | ||||
|   return packed_weight; | ||||
| } | ||||
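The packing comment above is terse: per expert, every BLOCK_N-column group of a row-major [OC, IC] weight is repacked so that adjacent pairs along the reduction dimension become the innermost axis, which is the operand shape the AMX/VNNI dot-product instructions consume. A minimal, unvectorized sketch of that index mapping for one 16-bit tile (hypothetical helper name; pack_vnni itself is defined elsewhere and also covers the int8/fp8 cases):

#include <cstdint>
#include <vector>

// Reference VNNI-2 repack of one weight tile: w is row-major [N, K] with K even,
// output is [K/2, N, 2] (N would be BLOCK_N = 32 in the kernel above).
template <typename T>
std::vector<T> pack_vnni_reference(const std::vector<T>& w, int64_t N, int64_t K) {
  std::vector<T> packed(N * K);
  for (int64_t k2 = 0; k2 < K / 2; ++k2) {    // pair index along the reduction dim
    for (int64_t n = 0; n < N; ++n) {         // output channel within the tile
      for (int64_t e = 0; e < 2; ++e) {       // element inside the K pair
        packed[(k2 * N + n) * 2 + e] = w[n * K + k2 * 2 + e];
      }
    }
  }
  return packed;
}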
|  | ||||
| // mat1 : [M, K] | ||||
| // mat2 : [N, K] | ||||
| // bias : [N] | ||||
| // out  : [M, N] | ||||
| // | ||||
| at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, | ||||
|     const std::optional<at::Tensor>& bias, bool is_vnni) { | ||||
|   RECORD_FUNCTION( | ||||
|     "sgl-kernel::weight_packed_linear", std::vector<c10::IValue>({mat1, mat2, bias})); | ||||
|  | ||||
|   auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); | ||||
|  | ||||
|   CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1); | ||||
|   CHECK_INPUT(mat2); | ||||
|  | ||||
|   int64_t M = mat1.size(0); | ||||
|   int64_t N = mat2.size(0); | ||||
|   int64_t K = mat2.size(1); | ||||
|   CHECK_EQ(mat1.size(1), K); | ||||
|   CHECK_DIM(2, mat1); | ||||
|   CHECK_DIM(2, mat2); | ||||
|  | ||||
|   auto out = at::empty({M, N}, mat1.options()); | ||||
|  | ||||
|   // strides | ||||
|   int64_t mat1_strideM = mat1.stride(0); | ||||
|   int64_t out_strideM = out.stride(0); | ||||
|  | ||||
|   const bool has_bias = bias.has_value(); | ||||
|   const float* bias_data = nullptr; | ||||
|   if (has_bias) { | ||||
|     CHECK_EQ(bias.value().size(0), N); | ||||
|     bias_data = bias.value().data_ptr<float>(); | ||||
|   } | ||||
|  | ||||
|   AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "weight_packed_linear_kernel_impl", [&] { | ||||
|     weight_packed_linear_kernel_impl<scalar_t>( | ||||
|         out.data_ptr<scalar_t>(), | ||||
|         mat1.data_ptr<scalar_t>(), | ||||
|         packed_w.data_ptr<scalar_t>(), | ||||
|         bias_data, | ||||
|         M, | ||||
|         N, | ||||
|         K, | ||||
|         mat1_strideM, | ||||
|         out_strideM); | ||||
|   }); | ||||
|  | ||||
|   return out; | ||||
| } | ||||
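A minimal call sketch for the pack-once, run-many pattern these two entry points implement (assuming both declarations are visible to the caller; tensor construction and dtype handling omitted):

#include <optional>
#include <ATen/ATen.h>

// x: [M, K] bf16/fp16 activations, w: [N, K] weights of the same dtype.
at::Tensor linear_example(at::Tensor x, at::Tensor w) {
  at::Tensor packed = convert_weight_packed(w);   // pack once, at weight-load time
  std::optional<at::Tensor> bias;                 // optional fp32 bias of shape [N]
  // is_vnni = true tells the op the weight is already packed, so nothing is repacked per call.
  return weight_packed_linear(x, packed, bias, /*is_vnni=*/true);
}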
							
								
								
									
266  csrc/cpu/sgl-kernels/gemm.h  Normal file
									
								
							| @ -0,0 +1,266 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <ATen/native/CPUBlas.h> | ||||
|  | ||||
| // clang-format off | ||||
|  | ||||
| // amx-bf16 | ||||
| #define TILE_M 16 | ||||
| #define TILE_N 16 | ||||
| #define TILE_K 32 | ||||
|  | ||||
| // block size for AMX gemm | ||||
| constexpr int block_size_m() { return 2 * TILE_M; } | ||||
| constexpr int block_size_n() { return 2 * TILE_N; } | ||||
|  | ||||
| // thresholds for choosing the brgemm (Intel AMX) path | ||||
| template <typename T> inline bool can_use_brgemm(int M); | ||||
| template <> inline bool can_use_brgemm<at::BFloat16>(int M) { return M > 4; } | ||||
| template <> inline bool can_use_brgemm<at::Half>(int M) { return true; } | ||||
| // TODO: add u8s8 brgemm, this requires PyTorch 2.7 | ||||
| template <> inline bool can_use_brgemm<int8_t>(int M) { return false; } | ||||
| template <> inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) { return M > 4; } | ||||
| template <> inline bool can_use_brgemm<at::quint4x2>(int M) { return M > 4; } | ||||
|  | ||||
| // work around compiler internal error | ||||
| #define BLOCK_K 128 // 4 * TILE_K | ||||
|  | ||||
| // adjust leading dimension size for K | ||||
| template <typename T> | ||||
| inline int64_t get_row_size(int64_t K) { | ||||
|   return K; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| inline int64_t get_row_size<int8_t>(int64_t K) { | ||||
|   return K + sizeof(int32_t); | ||||
| } | ||||
|  | ||||
| inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) { | ||||
|   return use_int8_w8a8 ? K + sizeof(int32_t) : K; | ||||
| } | ||||
|  | ||||
| // pack weight to vnni format | ||||
| at::Tensor convert_weight_packed(at::Tensor& weight); | ||||
|  | ||||
| // moe implementations for int8 w8a8 | ||||
| template <typename scalar_t> | ||||
| void fused_experts_int8_kernel_impl( | ||||
|     scalar_t* __restrict__ output, | ||||
|     scalar_t* __restrict__ ic1, | ||||
|     scalar_t* __restrict__ ic2, | ||||
|     uint8_t* __restrict__ A_tmp, | ||||
|     float* __restrict__ C_tmp, | ||||
|     uint8_t* __restrict__ Aq_tmp, | ||||
|     float* __restrict__ As_tmp, | ||||
|     const scalar_t* __restrict__ input, | ||||
|     const int8_t* __restrict__ packed_w1, | ||||
|     const int8_t* __restrict__ packed_w2, | ||||
|     const float* __restrict__ w1s, | ||||
|     const float* __restrict__ w2s, | ||||
|     const float* __restrict__ topk_weights, | ||||
|     const int32_t* __restrict__ sorted_ids, | ||||
|     const int32_t* __restrict__ expert_ids, | ||||
|     const int32_t* __restrict__ offsets, | ||||
|     int64_t M, | ||||
|     int64_t N, | ||||
|     int64_t K, | ||||
|     int64_t E, | ||||
|     int64_t topk, | ||||
|     int64_t num_tokens_post_pad); | ||||
|  | ||||
| // moe implementations for fp8 w8a16 | ||||
| template <typename scalar_t> | ||||
| void fused_experts_fp8_kernel_impl( | ||||
|     scalar_t* __restrict__ output, | ||||
|     scalar_t* __restrict__ ic0, | ||||
|     scalar_t* __restrict__ ic1, | ||||
|     scalar_t* __restrict__ ic2, | ||||
|     scalar_t* __restrict__ A_tmp, | ||||
|     scalar_t* __restrict__ B_tmp, | ||||
|     float* __restrict__ C_tmp, | ||||
|     const scalar_t* __restrict__ input, | ||||
|     const at::Float8_e4m3fn* __restrict__ packed_w1, | ||||
|     const at::Float8_e4m3fn* __restrict__ packed_w2, | ||||
|     const float* __restrict__ w1s, | ||||
|     const float* __restrict__ w2s, | ||||
|     int64_t block_size_N, | ||||
|     int64_t block_size_K, | ||||
|     const float* __restrict__ topk_weights, | ||||
|     const int32_t* __restrict__ sorted_ids, | ||||
|     const int32_t* __restrict__ expert_ids, | ||||
|     const int32_t* __restrict__ offsets, | ||||
|     int64_t M, | ||||
|     int64_t N, | ||||
|     int64_t K, | ||||
|     int64_t E, | ||||
|     int64_t topk, | ||||
|     int64_t num_tokens_post_pad); | ||||
|  | ||||
| // moe implementations for int4 w4a16 | ||||
| template <typename scalar_t> | ||||
| void fused_experts_int4_w4a16_kernel_impl( | ||||
|     scalar_t* __restrict__ output, | ||||
|     scalar_t* __restrict__ ic0, | ||||
|     scalar_t* __restrict__ ic1, | ||||
|     scalar_t* __restrict__ ic2, | ||||
|     scalar_t* __restrict__ A_tmp, | ||||
|     scalar_t* __restrict__ B_tmp, | ||||
|     float* __restrict__ C_tmp, | ||||
|     const scalar_t* __restrict__ input, | ||||
|     const at::quint4x2* __restrict__ packed_w1, | ||||
|     const at::quint4x2* __restrict__ packed_w2, | ||||
|     const uint8_t* __restrict__ w1z, | ||||
|     const uint8_t* __restrict__ w2z, | ||||
|     const scalar_t* __restrict__ w1s, | ||||
|     const scalar_t* __restrict__ w2s, | ||||
|     int group_size, | ||||
|     const float* __restrict__ topk_weights, | ||||
|     const int32_t* __restrict__ sorted_ids, | ||||
|     const int32_t* __restrict__ expert_ids, | ||||
|     const int32_t* __restrict__ offsets, | ||||
|     int64_t M, | ||||
|     int64_t N, | ||||
|     int64_t K, | ||||
|     int64_t E, | ||||
|     int64_t topk, | ||||
|     int64_t num_tokens_post_pad); | ||||
|  | ||||
| // shared expert implementation for int8 w8a8 | ||||
| template <typename scalar_t> | ||||
| void shared_expert_int8_kernel_impl( | ||||
|     scalar_t* __restrict__ output, | ||||
|     scalar_t* __restrict__ ic1, | ||||
|     float* __restrict__ C_tmp, | ||||
|     uint8_t* __restrict__ Aq_tmp, | ||||
|     float* __restrict__ As_tmp, | ||||
|     const scalar_t* __restrict__ input, | ||||
|     const int8_t* __restrict__ packed_w1, | ||||
|     const int8_t* __restrict__ packed_w2, | ||||
|     const float* __restrict__ w1s, | ||||
|     const float* __restrict__ w2s, | ||||
|     const scalar_t* __restrict__ fused_experts_out, | ||||
|     float routed_scaling_factor, | ||||
|     int64_t M, | ||||
|     int64_t N, | ||||
|     int64_t K); | ||||
|  | ||||
| template <typename scalar_t> | ||||
| void shared_expert_fp8_kernel_impl( | ||||
|     scalar_t* __restrict__ output, | ||||
|     scalar_t* __restrict__ ic0, | ||||
|     scalar_t* __restrict__ ic1, | ||||
|     scalar_t* __restrict__ B_tmp, | ||||
|     float* __restrict__ C_tmp, | ||||
|     const scalar_t* __restrict__ input, | ||||
|     const at::Float8_e4m3fn* __restrict__ packed_w1, | ||||
|     const at::Float8_e4m3fn* __restrict__ packed_w2, | ||||
|     const float* __restrict__ w1s, | ||||
|     const float* __restrict__ w2s, | ||||
|     int64_t block_size_N, | ||||
|     int64_t block_size_K, | ||||
|     const scalar_t* __restrict__ fused_experts_out, | ||||
|     float routed_scaling_factor, | ||||
|     int64_t M, | ||||
|     int64_t N, | ||||
|     int64_t K); | ||||
|  | ||||
| // tinygemm interface | ||||
| template <typename scalar_t> | ||||
| void tinygemm_kernel( | ||||
|     const scalar_t* __restrict__ A, | ||||
|     const scalar_t* __restrict__ B, | ||||
|     scalar_t* __restrict__ C, | ||||
|     float* __restrict__ Ctmp, | ||||
|     int64_t M, | ||||
|     int64_t N, | ||||
|     int64_t K, | ||||
|     int64_t lda, | ||||
|     int64_t ldb, | ||||
|     int64_t ldc, | ||||
|     bool brg); | ||||
|  | ||||
| template <typename scalar_t> | ||||
| void tinygemm_kernel( | ||||
|     const uint8_t* __restrict__ A, | ||||
|     const int8_t* __restrict__ B, | ||||
|     scalar_t* __restrict__ C, | ||||
|     int32_t* __restrict__ Ctmp, | ||||
|     const float* __restrict__ As, | ||||
|     const float* __restrict__ Bs, | ||||
|     int64_t M, | ||||
|     int64_t N, | ||||
|     int64_t K, | ||||
|     int64_t lda, | ||||
|     int64_t ldb, | ||||
|     int64_t ldc, | ||||
|     bool brg); | ||||
|  | ||||
| template <typename scalar_t> | ||||
| void tinygemm_kernel( | ||||
|     const scalar_t* __restrict__ A, | ||||
|     const at::Float8_e4m3fn* __restrict__ B, | ||||
|     scalar_t* __restrict__ C, | ||||
|     scalar_t* __restrict__ Btmp, | ||||
|     float* __restrict__ Ctmp, | ||||
|     const float* __restrict__ scale, | ||||
|     int64_t M, | ||||
|     int64_t N, | ||||
|     int64_t K, | ||||
|     int64_t lda, | ||||
|     int64_t ldb, | ||||
|     int64_t ldc, | ||||
|     bool brg, | ||||
|     int64_t block_size_K); | ||||
|  | ||||
| template <typename scalar_t> | ||||
| void tinygemm_kernel( | ||||
|     const scalar_t* __restrict__ A, | ||||
|     const at::quint4x2* __restrict__ B, | ||||
|     scalar_t* __restrict__ C, | ||||
|     const uint8_t* __restrict__ Bz, | ||||
|     const scalar_t* __restrict__ Bs, | ||||
|     scalar_t* __restrict__ Btmp, | ||||
|     float* __restrict__ Ctmp, | ||||
|     int64_t M, | ||||
|     int64_t N, | ||||
|     int64_t K, | ||||
|     int group_size, | ||||
|     int64_t lda, | ||||
|     int64_t ldb, | ||||
|     int64_t ldc, | ||||
|     int64_t strideBz, | ||||
|     int64_t strideBs, | ||||
|     bool brg); | ||||
|  | ||||
| // TODO: debug print, remove me later | ||||
| inline void print_16x32i(const __m512i x) { | ||||
|   int32_t a[16]; | ||||
|   _mm512_storeu_si512((__m512i *)a, x); | ||||
|  | ||||
|   for (int i = 0; i < 16; i++){ | ||||
|     std::cout << a[i] << " "; | ||||
|   } | ||||
|   std::cout << std::endl; | ||||
| } | ||||
|  | ||||
| inline void print_16x32(const __m512 x) { | ||||
|   float a[16]; | ||||
|   _mm512_storeu_ps((__m512 *)a, x); | ||||
|  | ||||
|   for (int i = 0; i < 16; i++){ | ||||
|     std::cout << a[i] << " "; | ||||
|   } | ||||
|   std::cout << std::endl; | ||||
| } | ||||
|  | ||||
|  | ||||
| inline void print_32x8u(const __m256i x) { | ||||
|   uint8_t a[32]; | ||||
|   _mm256_storeu_si256((__m256i *)a, x); | ||||
|  | ||||
|   for (int i = 0; i < 32; ++i) { | ||||
|     std::cout << int32_t(a[i]) << " "; | ||||
|   } | ||||
|   std::cout << std::endl; | ||||
| } | ||||
							
								
								
									
530  csrc/cpu/sgl-kernels/gemm_fp8.cpp  Normal file
									
								
							| @ -0,0 +1,530 @@ | ||||
| // Adapted from | ||||
| // https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu | ||||
|  | ||||
| #include "common.h" | ||||
| #include "vec.h" | ||||
| #include "gemm.h" | ||||
|  | ||||
| // clang-format off | ||||
|  | ||||
| // we use 4x32 for BLOCK_M | ||||
| #define BLOCK_SIZE_M_SCALE 4 | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| template <typename scalar_t> | ||||
| inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) { | ||||
|   using bVec = at::vec::Vectorized<scalar_t>; | ||||
|   using fVec = at::vec::Vectorized<float>; | ||||
|   constexpr int kVecSize = bVec::size(); | ||||
|  | ||||
|   int64_t d; | ||||
|   #pragma GCC unroll 4 | ||||
|   for (d = 0; d <= size - kVecSize; d += kVecSize) { | ||||
|     fVec data0 = fVec::loadu(input + d); | ||||
|     fVec data1 = fVec::loadu(input + d + fVec::size()); | ||||
|     bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1); | ||||
|     out_vec.store(out + d); | ||||
|   } | ||||
|   for (; d < size; ++d) { | ||||
|     out[d] = static_cast<scalar_t>(input[d]); | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename scalar_t> | ||||
| inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) { | ||||
|   using bVec = at::vec::Vectorized<scalar_t>; | ||||
|   using fVec = at::vec::Vectorized<float>; | ||||
|   constexpr int kVecSize = bVec::size(); | ||||
|  | ||||
|   int64_t d; | ||||
|   #pragma GCC unroll 4 | ||||
|   for (d = 0; d <= size - kVecSize; d += kVecSize) { | ||||
|     fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d); | ||||
|     fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size()); | ||||
|     bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1); | ||||
|     out_vec.store(out + d); | ||||
|   } | ||||
|   for (; d < size; ++d) { | ||||
|     out[d] = static_cast<scalar_t>(input[d] + bias[d]); | ||||
|   } | ||||
| } | ||||
|  | ||||
| inline void unpack_B( | ||||
|     at::BFloat16* __restrict__ Btmp, | ||||
|     const at::Float8_e4m3fn* __restrict__ packed_B, | ||||
|     int N, | ||||
|     int K, | ||||
|     int ldb, | ||||
|     int ldb_tmp, | ||||
|     float scale) { | ||||
| #if defined(CPU_CAPABILITY_AVX512) | ||||
|   // [K/2, N, 2] | ||||
|   const int K2 = K >> 1; | ||||
|   const int ldb2 = ldb; // ldb * 2 >> 1; | ||||
|   const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(packed_B); | ||||
|   const __m512 vd = _mm512_set1_ps(scale); | ||||
|  | ||||
|   constexpr int BLOCK_N = block_size_n(); | ||||
|   static_assert(BLOCK_N == 32); | ||||
|  | ||||
|   // prefetch distance | ||||
|   constexpr int PREFETCH_SIZE_K = 64; | ||||
|  | ||||
| #pragma GCC unroll 4 | ||||
|   for (int k = 0; k < K2; ++k) { | ||||
|     __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2); | ||||
|     if constexpr (PREFETCH_SIZE_K > 0) { | ||||
|       _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2, _MM_HINT_T0); | ||||
|     } | ||||
|  | ||||
|     __m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0); | ||||
|     __m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1); | ||||
|  | ||||
|     __m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0); | ||||
|     __m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1); | ||||
|  | ||||
|     // Apply scale | ||||
|     __m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0)); | ||||
|     __m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1)); | ||||
|     __m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0)); | ||||
|     __m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1)); | ||||
|  | ||||
|     f0_lo = _mm512_mul_ps(f0_lo, vd); | ||||
|     f0_hi = _mm512_mul_ps(f0_hi, vd); | ||||
|     f1_lo = _mm512_mul_ps(f1_lo, vd); | ||||
|     f1_hi = _mm512_mul_ps(f1_hi, vd); | ||||
|  | ||||
|     bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo); | ||||
|     bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo); | ||||
|  | ||||
|     _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 0, (__m512i)bf16_0); | ||||
|     _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 32, (__m512i)bf16_1); | ||||
|   } | ||||
| #else | ||||
|   TORCH_CHECK(false, "unpack_B: scalar path not implemented!"); | ||||
| #endif | ||||
| } | ||||
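For readers without the intrinsics in their head, the loop above does the following per K-pair row; this unvectorized sketch covers only the N used columns (the fp8 -> bf16 -> fp32 round trip in the vector code is lossless for e4m3 values, so converting straight to float here matches up to the final bf16 rounding):

#include <ATen/ATen.h>

// Scalar reference of unpack_B: packed_B is [K/2, N, 2] fp8, Btmp is [K/2, BLOCK_N, 2] bf16.
inline void unpack_B_scalar(at::BFloat16* Btmp, const at::Float8_e4m3fn* packed_B,
                            int N, int K, int ldb, int ldb_tmp, float scale) {
  for (int k2 = 0; k2 < K / 2; ++k2) {
    for (int n = 0; n < N; ++n) {
      for (int e = 0; e < 2; ++e) {
        float v = static_cast<float>(packed_B[k2 * ldb * 2 + n * 2 + e]) * scale;
        Btmp[k2 * ldb_tmp * 2 + n * 2 + e] = static_cast<at::BFloat16>(v);
      }
    }
  }
}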
|  | ||||
| template <typename scalar_t, typename packed_t, bool has_bias, int BLOCK_M, int BLOCK_N> | ||||
| struct tinygemm_kernel_nn { | ||||
|   static inline void apply( | ||||
|       const scalar_t* __restrict__ A, const packed_t* __restrict__ B, scalar_t* __restrict__ C, | ||||
|       const float* __restrict__ bias, const float* __restrict__ scale, int K, int lda, int ldb, int ldc, int64_t block_size_K) { | ||||
|     TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| #if defined(CPU_CAPABILITY_AVX512) | ||||
| template <bool has_bias, int BLOCK_M, int BLOCK_N> | ||||
| struct tinygemm_kernel_nn<at::BFloat16, at::Float8_e4m3fn, has_bias, BLOCK_M, BLOCK_N> { | ||||
|   static inline void apply( | ||||
|       const at::BFloat16* __restrict__ A, const at::Float8_e4m3fn* __restrict__ B, at::BFloat16* __restrict__ C, | ||||
|       const float* __restrict__ bias, const float* __restrict__ scale, int K, int lda, int ldb, int ldc, int64_t block_size_K) { | ||||
|  | ||||
|     constexpr int ROWS = BLOCK_M; | ||||
|     constexpr int COLS = BLOCK_N / 16; | ||||
|  | ||||
|     const int KB = div_up(K, BLOCK_K); | ||||
|  | ||||
|     // prefetch distance | ||||
|     constexpr int PREFETCH_SIZE_K = 64; | ||||
|     constexpr int PREFETCH_SIZE_KB = 1; | ||||
|  | ||||
|     __m512bh va; | ||||
|     __m512bh vb[COLS]; | ||||
|     __m512 vc[ROWS * COLS]; | ||||
|     __m512 vsum[ROWS * COLS]; | ||||
|  | ||||
|     // block quant scale | ||||
|     __m512 vscale; | ||||
|  | ||||
|     auto loadc = [&](auto i) { | ||||
|       constexpr int col = i % COLS; | ||||
|       if constexpr (has_bias) { | ||||
|         vc[i] = _mm512_loadu_ps(bias + col * 16); | ||||
|       } else { | ||||
|         vc[i] = _mm512_setzero_ps(); | ||||
|       } | ||||
|     }; | ||||
|     Unroll<ROWS * COLS>{}(loadc); | ||||
|  | ||||
|     const int lda2 = lda >> 1; | ||||
|     const int ldb2 = ldb; // ldb * 2 >> 1; | ||||
|     const float* a_ptr = reinterpret_cast<const float*>(A); | ||||
|     const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(B); | ||||
|  | ||||
|     auto compute = [&](auto i, int k) { | ||||
|       constexpr int row = i / COLS; | ||||
|       constexpr int col = i % COLS; | ||||
|  | ||||
|       if constexpr (col == 0) { | ||||
|         va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); | ||||
|         if constexpr (PREFETCH_SIZE_K > 0) { | ||||
|           _mm_prefetch(a_ptr + row * lda2 + k + PREFETCH_SIZE_K, _MM_HINT_T0); | ||||
|         } | ||||
|       } | ||||
|       if constexpr (row == 0) { | ||||
|         if constexpr (col % 2 == 0) { | ||||
|           __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + col * 16); | ||||
|           if constexpr (PREFETCH_SIZE_K > 0) { | ||||
|             _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); | ||||
|           } | ||||
|           vb[col + 0] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 0)); | ||||
|           vb[col + 1] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 1)); | ||||
|         } | ||||
|       } | ||||
|       vsum[i] = _mm512_dpbf16_ps(vsum[i], va, vb[col]); | ||||
|     }; | ||||
|  | ||||
|     constexpr int BLOCK_K2 = BLOCK_K >> 1; | ||||
|     for (int kb = 0; kb < KB; ++kb) { | ||||
|       int kb_start = kb * BLOCK_K2; | ||||
|       int kb_end = std::min(K, kb_start + BLOCK_K2); | ||||
|       // 1. load scale vector | ||||
|       vscale = _mm512_set1_ps(scale[kb]); | ||||
|       if constexpr (PREFETCH_SIZE_KB > 0) { | ||||
|         _mm_prefetch(scale + kb + PREFETCH_SIZE_KB, _MM_HINT_T0); | ||||
|       } | ||||
|       // 2. zero vsum for each block | ||||
|       Unroll<ROWS * COLS>{}([&](auto i) { | ||||
|         vsum[i] = _mm512_setzero_ps(); | ||||
|       }); | ||||
|       // 3. accumulate across each block | ||||
|       for (int k = kb_start; k < kb_end; ++k) { | ||||
|         Unroll<ROWS * COLS>{}(compute, k); | ||||
|       } | ||||
|       // 4. apply scale | ||||
|       Unroll<ROWS * COLS>{}([&](auto i) { | ||||
|         vc[i] = _mm512_fmadd_ps(vsum[i], vscale, vc[i]); | ||||
|       }); | ||||
|     } | ||||
|  | ||||
|     auto storec = [&](auto i) { | ||||
|       constexpr int row = i / COLS; | ||||
|       constexpr int col = i % COLS; | ||||
|       // for COLS = 2,4 use 512bit store | ||||
|       if constexpr (col % 2 == 0) { | ||||
|         _mm512_storeu_si512( | ||||
|             reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), | ||||
|             (__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col]))); | ||||
|       } | ||||
|     }; | ||||
|     Unroll<ROWS * COLS>{}(storec); | ||||
|   } | ||||
| }; | ||||
| #endif | ||||
|  | ||||
| #define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE)                          \ | ||||
|     tinygemm_kernel_nn<scalar_t, at::Float8_e4m3fn, has_bias, MB_SIZE, NB_SIZE>::apply(         \ | ||||
|         A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \ | ||||
|         has_bias ? bias + nb_start : nullptr, scale, K, lda, ldb, ldc, block_size_K); | ||||
|  | ||||
| template <typename scalar_t, typename packed_t, bool has_bias> | ||||
| struct brgemm { | ||||
|   static inline void apply( | ||||
|       const scalar_t* __restrict__ A, | ||||
|       const packed_t* __restrict__ B, | ||||
|       scalar_t* __restrict__ C, | ||||
|       scalar_t* __restrict__ Btmp, | ||||
|       float* __restrict__ Ctmp, | ||||
|       const float* __restrict__ bias, | ||||
|       const float* __restrict__ scale, | ||||
|       int M, | ||||
|       int N, | ||||
|       int K, | ||||
|       int lda, | ||||
|       int ldb, | ||||
|       int ldc) { | ||||
|     TORCH_CHECK(false, "struct brgemm: primary template not implemented!"); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <bool has_bias> | ||||
| struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> { | ||||
|   static inline void apply( | ||||
|       const at::BFloat16* __restrict__ A, | ||||
|       const at::Float8_e4m3fn* __restrict__ B, | ||||
|       at::BFloat16* __restrict__ C, | ||||
|       at::BFloat16* __restrict__ Btmp, | ||||
|       float* __restrict__ Ctmp, | ||||
|       const float* __restrict__ bias, | ||||
|       const float* __restrict__ scale, | ||||
|       int M, | ||||
|       int N, | ||||
|       int K, | ||||
|       int lda, | ||||
|       int ldb, | ||||
|       int ldc) { | ||||
|  | ||||
|     constexpr int BLOCK_N = block_size_n(); | ||||
|  | ||||
|     // [K, BLOCK_N] -> [K / 2, BLOCK_N * 2] | ||||
|     const int ldb_tmp = BLOCK_N; | ||||
|  | ||||
|     for (int k = 0; k < K; k += BLOCK_K) { | ||||
|       int kb_size = std::min(BLOCK_K, K - k); | ||||
|  | ||||
|       int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128 | ||||
|       unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]); | ||||
|     } | ||||
|  | ||||
|     at::native::cpublas::brgemm( | ||||
|         M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp); | ||||
|  | ||||
|     // copy from Ctmp to C | ||||
|     for (int m = 0; m < M; ++m) { | ||||
|       if constexpr (has_bias) { | ||||
|         copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N); | ||||
|       } else { | ||||
|         copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <typename scalar_t, bool has_bias> | ||||
| void tinygemm_kernel( | ||||
|     const scalar_t* __restrict__ A, | ||||
|     const at::Float8_e4m3fn* __restrict__ B, | ||||
|     scalar_t* __restrict__ C, | ||||
|     scalar_t* __restrict__ Btmp, | ||||
|     float* __restrict__ Ctmp, | ||||
|     const float* __restrict__ scale, | ||||
|     const float* __restrict__ bias, | ||||
|     int64_t M, | ||||
|     int64_t N, | ||||
|     int64_t K, | ||||
|     int64_t lda, | ||||
|     int64_t ldb, | ||||
|     int64_t ldc, | ||||
|     bool brg, | ||||
|     int64_t block_size_K) { | ||||
|  | ||||
|   if (brg) { | ||||
|     brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply( | ||||
|         A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc); | ||||
|     return; | ||||
|   } | ||||
|  | ||||
|   // pattern: 1-4-16 | ||||
|   constexpr int64_t BLOCK_M = 4; | ||||
|   constexpr int64_t BLOCK_N = 64; | ||||
|   const int64_t MB = div_up(M, BLOCK_M); | ||||
|   const int64_t NB = div_up(N, BLOCK_N); | ||||
|   for (int mb = 0; mb < MB; ++mb) { | ||||
|     int64_t mb_start = mb * BLOCK_M; | ||||
|     int64_t mb_size = std::min(BLOCK_M, M - mb_start); | ||||
|     for (int64_t nb = 0; nb < NB; ++nb) { | ||||
|       int64_t nb_start = nb * BLOCK_N; | ||||
|       int64_t nb_size = std::min(BLOCK_N, N - nb_start); | ||||
|  | ||||
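|       // dispatch key: mb_size in the high nibble, nb_size/16 in the low nibble, e.g. 4x32 -> 0x42 | ||||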
|       switch(mb_size << 4 | nb_size >> 4) { | ||||
|         case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break; | ||||
|         case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break; | ||||
|         case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break; | ||||
|         case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break; | ||||
|         default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename scalar_t> | ||||
| void fp8_scaled_mm_kernel_impl( | ||||
|     scalar_t* __restrict__ out, | ||||
|     const scalar_t* __restrict__ mat1, | ||||
|     const at::Float8_e4m3fn* __restrict__ mat2, | ||||
|     const float* __restrict__ scales2, | ||||
|     const float* __restrict__ bias, | ||||
|     scalar_t* __restrict__ buffer, | ||||
|     int64_t M, | ||||
|     int64_t N, | ||||
|     int64_t K, | ||||
|     int64_t mat1_strideM, | ||||
|     int64_t out_strideM, | ||||
|     int64_t block_size_N, | ||||
|     int64_t block_size_K, | ||||
|     int64_t buffer_size_per_thread) { | ||||
|  | ||||
|   constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE; | ||||
|   constexpr int64_t BLOCK_N = block_size_n(); | ||||
|   const int64_t MB = div_up(M, BLOCK_M); | ||||
|   const int64_t NB = div_up(N, BLOCK_N); | ||||
|  | ||||
|   const int64_t scale_size_K = div_up(K, block_size_K); | ||||
|   const int64_t blocks_n_per_group = block_size_N / BLOCK_N; | ||||
|  | ||||
|   const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M); | ||||
|  | ||||
|   // parallel on [MB, NB] | ||||
|   AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { | ||||
|     at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { | ||||
|       int64_t mb{0}, nb{0}; | ||||
|       data_index_init(begin, mb, MB, nb, NB); | ||||
|  | ||||
|       int tid = at::get_thread_num(); | ||||
|       scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread; | ||||
|       float* __restrict__ Ctmp = (float*)((void*)(Btmp + BLOCK_N * K)); | ||||
|  | ||||
|       for (int64_t i = begin; i < end; ++i) { | ||||
|         UNUSED(i); | ||||
|         const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K; | ||||
|  | ||||
|         int64_t mb_start = mb * BLOCK_M; | ||||
|         int64_t mb_size = std::min(M - mb_start, BLOCK_M); | ||||
|         int64_t nb_start = nb * BLOCK_N; | ||||
|         int64_t nb_size = std::min(N - nb_start, BLOCK_N); | ||||
|  | ||||
|         tinygemm_kernel<scalar_t, has_bias>( | ||||
|             /*   A            */ mat1 + mb_start * mat1_strideM, | ||||
|             /*   B            */ mat2 + nb_start * K, // nb * BLOCK_N * K | ||||
|             /*   C            */ out + mb_start * out_strideM + nb_start, | ||||
|             /*   Btmp         */ Btmp, | ||||
|             /*   Ctmp         */ Ctmp, | ||||
|             /*   scale        */ scale_ptr, | ||||
|             /*   bias         */ bias + nb_start, | ||||
|             /*   M            */ mb_size, | ||||
|             /*   N            */ nb_size, | ||||
|             /*   K            */ K, | ||||
|             /*   lda          */ mat1_strideM, | ||||
|             /*   ldb          */ nb_size, | ||||
|             /*   ldc          */ out_strideM, | ||||
|             /*   brg          */ use_brgemm, | ||||
|             /*   block_size_K */ block_size_K); | ||||
|  | ||||
|         // move to the next index | ||||
|         data_index_step(mb, MB, nb, NB); | ||||
|       } | ||||
|  | ||||
|       if (use_brgemm) { | ||||
|         at::native::cpublas::brgemm_release(); | ||||
|       } | ||||
|     }); | ||||
|   }); | ||||
| } | ||||
|  | ||||
| } // anonymous namespace | ||||
|  | ||||
| // tinygemm interface | ||||
| template <typename scalar_t> | ||||
| void tinygemm_kernel( | ||||
|     const scalar_t* __restrict__ A, | ||||
|     const at::Float8_e4m3fn* __restrict__ B, | ||||
|     scalar_t* __restrict__ C, | ||||
|     scalar_t* __restrict__ Btmp, | ||||
|     float* __restrict__ Ctmp, | ||||
|     const float* __restrict__ scale, | ||||
|     int64_t M, | ||||
|     int64_t N, | ||||
|     int64_t K, | ||||
|     int64_t lda, | ||||
|     int64_t ldb, | ||||
|     int64_t ldc, | ||||
|     bool brg, | ||||
|     int64_t block_size_K) { | ||||
|   tinygemm_kernel<scalar_t, false>(A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K); | ||||
| } | ||||
|  | ||||
| #define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE)    \ | ||||
|   template void tinygemm_kernel<TYPE>(         \ | ||||
|       const TYPE* __restrict__ A,              \ | ||||
|       const at::Float8_e4m3fn* __restrict__ B, \ | ||||
|       TYPE* __restrict__ C,                    \ | ||||
|       TYPE* __restrict__ Btmp,                 \ | ||||
|       float* __restrict__ Ctmp,                \ | ||||
|       const float* __restrict__ scale,         \ | ||||
|       int64_t M,                               \ | ||||
|       int64_t N,                               \ | ||||
|       int64_t K,                               \ | ||||
|       int64_t lda,                             \ | ||||
|       int64_t ldb,                             \ | ||||
|       int64_t ldc,                             \ | ||||
|       bool brg,                                \ | ||||
|       int64_t block_size_K) | ||||
|  | ||||
| INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); | ||||
| INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); | ||||
|  | ||||
| at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2, | ||||
|     std::vector<int64_t> block_size, std::optional<at::Tensor>& bias, | ||||
|     at::ScalarType out_dtype, bool is_vnni) { | ||||
|   RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, block_size, bias})); | ||||
|  | ||||
|   auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); | ||||
|  | ||||
|   CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1); | ||||
|   CHECK_INPUT(mat2); | ||||
|   CHECK_INPUT(scales2); | ||||
|   TORCH_CHECK(scales2.scalar_type() == at::kFloat, | ||||
|       "fp8_scaled_mm_cpu: expect scales2 to be float32."); | ||||
|  | ||||
|   int64_t M = mat1.size(0); | ||||
|   int64_t N = mat2.size(0); | ||||
|   int64_t K = mat2.size(1); | ||||
|  | ||||
|   CHECK_EQ(mat1.size(1), K); | ||||
|   CHECK_DIM(2, mat1); | ||||
|   CHECK_DIM(2, mat2); | ||||
|  | ||||
|   TORCH_CHECK(block_size.size() == 2, | ||||
|       "fp8_scaled_mm_cpu: expect block_size.size() to be 2."); | ||||
|  | ||||
|   int64_t block_size_N = block_size[0]; | ||||
|   int64_t block_size_K = block_size[1]; | ||||
|  | ||||
|   constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE; | ||||
|   constexpr int64_t BLOCK_N = block_size_n(); | ||||
|   TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N"); | ||||
|   TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K"); | ||||
|   CHECK_EQ(scales2.size(0), div_up(N, block_size_N)); | ||||
|   CHECK_EQ(scales2.size(1), div_up(K, block_size_K)); | ||||
|  | ||||
|   const auto st = mat1.scalar_type(); | ||||
|   TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, | ||||
|       "fp8_scaled_mm_cpu: expect A to be bfloat16 or half."); | ||||
|   TORCH_CHECK(st == out_dtype, | ||||
|       "fp8_scaled_mm_cpu: expect A to have the same dtype as out_dtype."); | ||||
|   TORCH_CHECK(mat2.scalar_type() == at::kFloat8_e4m3fn, | ||||
|       "fp8_scaled_mm_cpu: expect mat2 to be fp8_e4m3."); | ||||
|   TORCH_CHECK(scales2.scalar_type() == at::kFloat, | ||||
|       "fp8_scaled_mm_cpu: expect scales to be float32."); | ||||
|   auto out = at::empty({M, N}, mat1.options().dtype(out_dtype)); | ||||
|  | ||||
|   // strides | ||||
|   int64_t mat1_strideM = mat1.stride(0); | ||||
|   int64_t out_strideM = out.stride(0); | ||||
|  | ||||
|   const bool has_bias = bias.has_value(); | ||||
|   const float* bias_data = nullptr; | ||||
|   if (has_bias) { | ||||
|     CHECK_EQ(bias.value().size(0), N); | ||||
|     bias_data = bias.value().data_ptr<float>(); | ||||
|   } | ||||
|  | ||||
|   // Btmp : [T, BLOCK_N * K] scalar_t elements | ||||
|   // Ctmp : [T, BLOCK_M * BLOCK_N] float accumulators (counted as 2x scalar_t elements in size_per_thread below) | ||||
|   int num_threads = at::get_num_threads(); | ||||
|   int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2; | ||||
|   auto buffer = at::empty({num_threads, size_per_thread}, mat1.options()); | ||||
|  | ||||
|   AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] { | ||||
|     fp8_scaled_mm_kernel_impl<scalar_t>( | ||||
|         out.data_ptr<scalar_t>(), | ||||
|         mat1.data_ptr<scalar_t>(), | ||||
|         packed_w.data_ptr<at::Float8_e4m3fn>(), | ||||
|         scales2.data_ptr<float>(), | ||||
|         bias_data, | ||||
|         buffer.data_ptr<scalar_t>(), | ||||
|         M, | ||||
|         N, | ||||
|         K, | ||||
|         mat1_strideM, | ||||
|         out_strideM, | ||||
|         block_size_N, | ||||
|         block_size_K, | ||||
|         size_per_thread); | ||||
|   }); | ||||
|  | ||||
|   return out; | ||||
| } | ||||
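A call sketch showing the expected shapes (names as in this file; the {128, 128} quantization block matches the BLOCK_K requirement checked above, and the weight would normally be packed once up front rather than per call):

#include <optional>
#include <vector>
#include <ATen/ATen.h>

// x: [M, K] bf16, w_fp8: [N, K] e4m3, scales: [ceil(N/128), ceil(K/128)] fp32.
at::Tensor fp8_linear_example(at::Tensor x, at::Tensor w_fp8, at::Tensor scales) {
  std::optional<at::Tensor> bias;  // no bias in this sketch
  return fp8_scaled_mm_cpu(x, w_fp8, scales, /*block_size=*/{128, 128},
                           bias, /*out_dtype=*/at::kBFloat16, /*is_vnni=*/false);
}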
							
								
								
									
440  csrc/cpu/sgl-kernels/gemm_int8.cpp  Normal file
									
								
							| @ -0,0 +1,440 @@ | ||||
| // Adapted from | ||||
| // https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu | ||||
|  | ||||
| #include "common.h" | ||||
| #include "vec.h" | ||||
| #include "gemm.h" | ||||
|  | ||||
| // clang-format off | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N> | ||||
| struct tinygemm_kernel_nn { | ||||
|   static inline void apply( | ||||
|       const uint8_t* __restrict__ A, const int8_t* __restrict__ B, scalar_t* __restrict__ C, | ||||
|       const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, | ||||
|       const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { | ||||
|     TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| #if defined(CPU_CAPABILITY_AVX512) | ||||
| template <bool has_bias, int BLOCK_M, int BLOCK_N> | ||||
| struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> { | ||||
|   static inline void apply( | ||||
|       const uint8_t* __restrict__ A, const int8_t* __restrict__ B, at::BFloat16* __restrict__ C, | ||||
|       const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, | ||||
|       const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { | ||||
|  | ||||
|     constexpr int ROWS = BLOCK_M; | ||||
|     constexpr int COLS = BLOCK_N / 16; | ||||
|     static_assert(COLS % 2 == 0); | ||||
|  | ||||
|     // prefetch distance | ||||
|     constexpr int PREFETCH_SIZE_K = 0; | ||||
|  | ||||
|     __m512i va; | ||||
|     __m512i vb[COLS]; | ||||
|     __m512i vc[ROWS * COLS]; | ||||
|     __m512i vcomp[COLS]; | ||||
|     __m512  vd0; | ||||
|     __m512  vd1[COLS]; | ||||
|  | ||||
|     // oops! 4x4 spills, but luckily we use 4x2 | ||||
|     __m512 vbias[COLS]; | ||||
|  | ||||
|     // [NOTE]: s8s8 igemm compensation in avx512-vnni | ||||
|     // | ||||
|     // avx512-vnni has no s8s8, so we turn s8s8 into u8s8 with a compensation term: | ||||
|     // | ||||
|     //   a * b = (a + 128) * b - 128 * b | ||||
|     //   s   s       u       s    u    s | ||||
|     // | ||||
|     // 1) 128 * b is pre-computed when packing B to vnni formats | ||||
|     // 2) a + 128 is fused into the dynamic quantization of A | ||||
|     // | ||||
|     auto loadc = [&](auto i) { | ||||
|       vc[i] = _mm512_set1_epi32(0); | ||||
|     }; | ||||
|     Unroll<ROWS * COLS>{}(loadc); | ||||
|  | ||||
|     const int64_t K4 = K >> 2; | ||||
|     const int64_t lda4 = lda >> 2; | ||||
|     const int64_t ldb4 = ldb; // ldb * 4 >> 2; | ||||
|     const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A); | ||||
|     const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B); | ||||
|  | ||||
|     auto compute = [&](auto i, int64_t k) { | ||||
|       constexpr int row = i / COLS; | ||||
|       constexpr int col = i % COLS; | ||||
|  | ||||
|       if constexpr (col == 0) { | ||||
|         va = _mm512_set1_epi32(a_ptr[row * lda4 + k]); | ||||
|       } | ||||
|       if constexpr (row == 0) { | ||||
|         vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16); | ||||
|         if constexpr (PREFETCH_SIZE_K > 0) { | ||||
|           _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb4 + col * 16, _MM_HINT_T0); | ||||
|         } | ||||
|       } | ||||
|       vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]); | ||||
|     }; | ||||
|     for (int64_t k = 0; k < K4; ++k) { | ||||
|       Unroll<ROWS * COLS>{}(compute, k); | ||||
|     } | ||||
|  | ||||
|     auto storec = [&](auto i) { | ||||
|       constexpr int row = i / COLS; | ||||
|       constexpr int col = i % COLS; | ||||
|  | ||||
|       // load a scale | ||||
|       if constexpr(col == 0) { | ||||
|         vd0 = _mm512_set1_ps(As[row]); | ||||
|       } | ||||
|       // load b scale and vcomp per 2 vectors | ||||
|       // also load bias if any | ||||
|       if constexpr (row == 0) { | ||||
|         if constexpr (col % 2 == 0) { | ||||
|           vd1[col + 0] = _mm512_loadu_ps(Bs + col * 16); | ||||
|           vd1[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16); | ||||
|           vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16); | ||||
|           vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16); | ||||
|           if constexpr (has_bias) { | ||||
|             vbias[col + 0] = _mm512_loadu_ps(bias + col * 16); | ||||
|             vbias[col + 1] = _mm512_loadu_ps(bias + col * 16 + 16); | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|  | ||||
|       // for COLS = 2, 4 use 512bit store | ||||
|       if constexpr (col % 2 == 0) { | ||||
|         __m512 vc0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 0], vcomp[col + 0])); | ||||
|         __m512 vc1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 1], vcomp[col + 1])); | ||||
|         if constexpr (has_bias) { | ||||
|           vc0 = _mm512_fmadd_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0], vbias[col + 0]); | ||||
|           vc1 = _mm512_fmadd_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1], vbias[col + 1]); | ||||
|         } else { | ||||
|           vc0 = _mm512_mul_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0]); | ||||
|           vc1 = _mm512_mul_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1]); | ||||
|         } | ||||
|  | ||||
|         _mm512_storeu_si512( | ||||
|             reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), | ||||
|             (__m512i)(_mm512_cvtne2ps_pbh(vc1, vc0))); | ||||
|       } | ||||
|     }; | ||||
|     Unroll<ROWS * COLS>{}(storec); | ||||
|   } | ||||
| }; | ||||
| #endif | ||||
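A scalar illustration of the compensation identity from the [NOTE] above, independent of the AVX512 intrinsics (self-contained check; the kernel does the same thing per 32-bit lane):

#include <cassert>
#include <cstdint>

int main() {
  // s8s8 via u8s8 plus compensation: a * b == (a + 128) * b - 128 * b, accumulated in int32.
  const int8_t a = -3, b = 57;
  const uint8_t a_u8 = static_cast<uint8_t>(static_cast<int32_t>(a) + 128);  // folded into A's dynamic quantization
  const int32_t comp = 128 * static_cast<int32_t>(b);                        // precomputed at weight-pack time
  const int32_t u8s8 = static_cast<int32_t>(a_u8) * static_cast<int32_t>(b);
  assert(u8s8 - comp == static_cast<int32_t>(a) * static_cast<int32_t>(b));  // both sides are -171
  return 0;
}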
|  | ||||
| #define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE)                          \ | ||||
|     tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply(         \ | ||||
|         A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \ | ||||
|         As + mb_start, Bs + nb_start, Bcomp + nb_start,                      \ | ||||
|         has_bias ? bias + nb_start : nullptr, K, lda, ldb, ldc); | ||||
|  | ||||
| template <typename scalar_t, bool has_bias> | ||||
| void tinygemm_kernel( | ||||
|     const uint8_t* __restrict__ A, | ||||
|     const int8_t* __restrict__ B, | ||||
|     scalar_t* __restrict__ C, | ||||
|     int32_t* __restrict__ Ctmp, | ||||
|     const float* __restrict__ As, | ||||
|     const float* __restrict__ Bs, | ||||
|     const float* __restrict__ bias, | ||||
|     int64_t M, | ||||
|     int64_t N, | ||||
|     int64_t K, | ||||
|     int64_t lda, | ||||
|     int64_t ldb, | ||||
|     int64_t ldc, | ||||
|     bool brg) { | ||||
|  | ||||
|   // B compensation | ||||
|   const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K); | ||||
|  | ||||
|   // pattern: 1-4-16 | ||||
|   constexpr int64_t BLOCK_M = 4; | ||||
|   constexpr int64_t BLOCK_N = 64; | ||||
|   const int64_t MB = div_up(M, BLOCK_M); | ||||
|   const int64_t NB = div_up(N, BLOCK_N); | ||||
|   for (int64_t mb = 0; mb < MB; ++mb) { | ||||
|     int64_t mb_start = mb * BLOCK_M; | ||||
|     int64_t mb_size = std::min(BLOCK_M, M - mb_start); | ||||
|     for (int64_t nb = 0; nb < NB; ++nb) { | ||||
|       int64_t nb_start = nb * BLOCK_N; | ||||
|       int64_t nb_size = std::min(BLOCK_N, N - nb_start); | ||||
|  | ||||
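|       // dispatch key: mb_size in the high nibble, nb_size/16 in the low nibble, e.g. 4x64 -> 0x44 | ||||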
|       switch(mb_size << 4 | nb_size >> 4) { | ||||
|         // mb_size = 1 | ||||
|         case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break; | ||||
|         case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break; | ||||
|         // mb_size = 2 | ||||
|         case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break; | ||||
|         case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break; | ||||
|         // mb_size = 3 | ||||
|         case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break; | ||||
|         case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break; | ||||
|         // mb_size = 4 | ||||
|         case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break; | ||||
|         case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break; | ||||
|         default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| template<typename scalar_t> | ||||
| void int8_scaled_mm_kernel_impl( | ||||
|     scalar_t* __restrict__ out, | ||||
|     const uint8_t* __restrict__ mat1, | ||||
|     const int8_t* __restrict__ mat2, | ||||
|     const float* __restrict__ scales1, | ||||
|     const float* __restrict__ scales2, | ||||
|     const float* __restrict__ bias, | ||||
|     int64_t M, | ||||
|     int64_t N, | ||||
|     int64_t K) { | ||||
|  | ||||
|   constexpr int64_t BLOCK_M = block_size_m(); | ||||
|   constexpr int64_t BLOCK_N = block_size_n(); | ||||
|   const int64_t MB = div_up(M, BLOCK_M); | ||||
|   const int64_t NB = div_up(N, BLOCK_N); | ||||
|  | ||||
|   // TODO: brgemm u8s8 depends on PyTorch 2.7 release. | ||||
|   const bool use_brgemm = false; | ||||
|  | ||||
|   // packed row size is K + 4: K int8 weights plus one int32 compensation term per output channel | ||||
|   const int64_t packed_row_size = get_row_size<int8_t>(K); | ||||
|  | ||||
|   AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { | ||||
|     at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { | ||||
|       int64_t mb{0}, nb{0}; | ||||
|       data_index_init(begin, mb, MB, nb, NB); | ||||
|  | ||||
|       // for brgemm, use int32_t for accumulation | ||||
|       alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N]; | ||||
|  | ||||
|       for (int i = begin; i < end; ++i) { | ||||
|         UNUSED(i); | ||||
|         int mb_start = mb * BLOCK_M; | ||||
|         int mb_size = std::min(M - mb_start, BLOCK_M); | ||||
|         int nb_start = nb * BLOCK_N; | ||||
|         int nb_size = std::min(N - nb_start, BLOCK_N); | ||||
|  | ||||
|         tinygemm_kernel<scalar_t, has_bias>( | ||||
|             /*   A */ mat1 + mb_start * K, | ||||
|             /*   B */ mat2 + nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */, | ||||
|             /*   C */ out + mb_start * N + nb_start, | ||||
|             /* Ctmp*/ Ctmp, | ||||
|             /*  As */ scales1 + mb_start, | ||||
|             /*  Bs */ scales2 + nb_start, | ||||
|             /* bias*/ bias + nb_start, | ||||
|             /*   M */ mb_size, | ||||
|             /*   N */ nb_size, | ||||
|             /*   K */ K, | ||||
|             /* lda */ K, | ||||
|             /* ldb */ nb_size, | ||||
|             /* ldc */ N, | ||||
|             /* brg */ use_brgemm); | ||||
|  | ||||
|         // move to the next index | ||||
|         data_index_step(mb, MB, nb, NB); | ||||
|       } | ||||
|  | ||||
|       if (use_brgemm) { | ||||
|         at::native::cpublas::brgemm_release(); | ||||
|       } | ||||
|     }); | ||||
|   }); | ||||
| } | ||||
|  | ||||
| } // anonymous namespace | ||||
|  | ||||
| // tinygemm interface | ||||
| template <typename scalar_t> | ||||
| void tinygemm_kernel(const uint8_t* __restrict__ A, const int8_t* __restrict__ B, scalar_t* __restrict__ C, | ||||
|     int32_t* __restrict__ Ctmp,  const float* __restrict__ As, const float* __restrict__ Bs, | ||||
|     int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) { | ||||
|   tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, As, Bs, nullptr, M, N, K, lda, ldb, ldc, brg); | ||||
| } | ||||
|  | ||||
| #define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE)                                                     \ | ||||
|     template void tinygemm_kernel<TYPE>(                                                        \ | ||||
|         const uint8_t* __restrict__ A, const int8_t* __restrict__ B, TYPE* __restrict__ C,      \ | ||||
|         int32_t* __restrict__ Ctmp, const float* __restrict__ As, const float* __restrict__ Bs, \ | ||||
|         int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) | ||||
|  | ||||
| INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); | ||||
| INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); | ||||
|  | ||||
| std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A) { | ||||
|   RECORD_FUNCTION("sgl-kernel::per_token_quant_int8_cpu", std::vector<c10::IValue>({A})); | ||||
|  | ||||
|   CHECK_LAST_DIM_CONTIGUOUS_INPUT(A); | ||||
|   CHECK_DIM(2, A); | ||||
|  | ||||
|   int64_t M = A.size(0); | ||||
|   int64_t K = A.size(1); | ||||
|   int64_t lda = A.stride(0); | ||||
|  | ||||
|   const auto st = A.scalar_type(); | ||||
|   TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, | ||||
|       "per_token_quant_int8: expect A to be bfloat16 or half."); | ||||
|  | ||||
|   auto Aq = at::empty({M, K}, A.options().dtype(at::kByte)); | ||||
|   auto As = at::empty({M}, A.options().dtype(at::kFloat)); | ||||
|  | ||||
|   AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "per_token_quant_int8", [&] { | ||||
|     uint8_t* __restrict__ Aq_data = Aq.data_ptr<uint8_t>(); | ||||
|     float* __restrict__ As_data = As.data_ptr<float>(); | ||||
|     const scalar_t* __restrict__ A_data = A.data_ptr<scalar_t>(); | ||||
|  | ||||
|     at::parallel_for(0, M, 0, [&] (int64_t begin, int64_t end) { | ||||
|       for (int64_t m = begin; m < end; ++m) { | ||||
|         quantize_row_int8<scalar_t>( | ||||
|             Aq_data + m * K, | ||||
|             As_data[m], | ||||
|             A_data + m * lda, | ||||
|             K); | ||||
|       } | ||||
|     }); | ||||
|   }); | ||||
|   return std::make_tuple(Aq, As); | ||||
| } | ||||
|  | ||||
| // weight     :  static, per-channel, symmetric | ||||
| // activation : dynamic,   per-token, symmetric | ||||
| // | ||||
| // mat1    : [M, K] | ||||
| // mat2    : [N, K] | ||||
| // scales1 : [M] | ||||
| // scales2 : [N] | ||||
| // bias    : [N] | ||||
| // out     : [M, N] | ||||
| // | ||||
| at::Tensor int8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, | ||||
|     at::Tensor& scales1, at::Tensor& scales2, | ||||
|     std::optional<at::Tensor>& bias, at::ScalarType out_dtype, bool is_vnni) { | ||||
|   RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales1, scales2, bias})); | ||||
|  | ||||
|   auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); | ||||
|  | ||||
|   CHECK_INPUT(mat1); | ||||
|   CHECK_INPUT(mat2); | ||||
|   CHECK_INPUT(scales1); | ||||
|   CHECK_INPUT(scales2); | ||||
|   CHECK_DIM(2, mat1); | ||||
|   CHECK_DIM(2, mat2); | ||||
|  | ||||
|   int64_t M = mat1.size(0); | ||||
|   int64_t N = mat2.size(0); | ||||
|   int64_t K = mat1.size(1); | ||||
|  | ||||
|   // see [NOTE]: s8s8 igemm compensation in avx512-vnni | ||||
|   CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K)); | ||||
|   CHECK_EQ(scales1.numel(), M); | ||||
|   CHECK_EQ(scales2.numel(), N); | ||||
|  | ||||
|   TORCH_CHECK(mat1.scalar_type() == at::kByte, "int8_scaled_mm: expect mat1 to be uint8."); | ||||
|   TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm: expect mat2 to be int8."); | ||||
|   TORCH_CHECK(scales1.scalar_type() == at::kFloat && scales2.scalar_type() == at::kFloat, | ||||
|       "int8_scaled_mm: expect scales to be float32."); | ||||
|  | ||||
|   auto out = at::empty({M, N}, mat1.options().dtype(out_dtype)); | ||||
|  | ||||
|   const bool has_bias = bias.has_value(); | ||||
|   const float* bias_data = nullptr; | ||||
|   if (has_bias) { | ||||
|     CHECK_EQ(bias.value().size(0), N); | ||||
|     bias_data = bias.value().data_ptr<float>(); | ||||
|   } | ||||
|  | ||||
|   AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_kernel_impl", [&] { | ||||
|     int8_scaled_mm_kernel_impl<scalar_t>( | ||||
|         out.data_ptr<scalar_t>(), | ||||
|         mat1.data_ptr<uint8_t>(), | ||||
|         packed_w.data_ptr<int8_t>(), | ||||
|         scales1.data_ptr<float>(), | ||||
|         scales2.data_ptr<float>(), | ||||
|         bias_data, | ||||
|         M, | ||||
|         N, | ||||
|         K); | ||||
|   }); | ||||
|   return out; | ||||
| } | ||||
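A sketch of the two-step dynamic-quantization path these ops implement (the fused variant below folds step 1 into the GEMM call); names as in this file, shapes simplified, activation dtype assumed bf16:

#include <optional>
#include <tuple>
#include <ATen/ATen.h>

// x: [M, K] bf16, w_packed: VNNI-packed int8 weight of shape [N, K + 4] (from convert_weight_packed),
// w_scales: [N] fp32 per-channel weight scales.
at::Tensor int8_linear_example(at::Tensor x, at::Tensor w_packed, at::Tensor w_scales) {
  // 1) dynamic per-token activation quantization: uint8 values (with the +128 shift) plus one fp32 scale per row
  auto [xq, xs] = per_token_quant_int8_cpu(x);
  // 2) u8s8 GEMM; the per-token and per-channel scales are applied in the epilogue
  std::optional<at::Tensor> bias;
  return int8_scaled_mm_cpu(xq, w_packed, xs, w_scales, bias,
                            /*out_dtype=*/at::kBFloat16, /*is_vnni=*/true);
}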
|  | ||||
| // fused `per_token_quant_int8_cpu` and `int8_scaled_mm_cpu` | ||||
| at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2, | ||||
|     const std::optional<at::Tensor>& bias, at::ScalarType out_dtype, bool is_vnni) { | ||||
|   RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_with_quant", std::vector<c10::IValue>({mat1, mat2, scales2, bias})); | ||||
|  | ||||
|   auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); | ||||
|  | ||||
|   CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1); | ||||
|   CHECK_INPUT(mat2); | ||||
|   CHECK_INPUT(scales2); | ||||
|   CHECK_DIM(2, mat1); | ||||
|   CHECK_DIM(2, mat2); | ||||
|  | ||||
|   int64_t M = mat1.size(0); | ||||
|   int64_t N = mat2.size(0); | ||||
|   int64_t K = mat1.size(1); | ||||
|   int64_t lda = mat1.stride(0); | ||||
|  | ||||
|   // see [NOTE]: s8s8 igemm compensation in avx512-vnni | ||||
|   CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K)); | ||||
|   CHECK_EQ(scales2.numel(), N); | ||||
|  | ||||
|   const auto st = mat1.scalar_type(); | ||||
|   TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, | ||||
|       "int8_scaled_mm_with_quant: expect A to be bfloat16 or half."); | ||||
|   TORCH_CHECK(st == out_dtype, | ||||
|       "int8_scaled_mm_with_quant: expect A to have the same dtype as out_dtype."); | ||||
|   TORCH_CHECK(mat2.scalar_type() == at::kChar, | ||||
|       "int8_scaled_mm_with_quant: expect mat2 to be int8."); | ||||
|   TORCH_CHECK(scales2.scalar_type() == at::kFloat, | ||||
|       "int8_scaled_mm_with_quant: expect scales to be float32."); | ||||
|  | ||||
|   const int64_t buffer_size = M * K + M * sizeof(float); | ||||
|   auto buffer = at::empty({buffer_size}, mat1.options().dtype(at::kByte)); | ||||
|   auto out = at::empty({M, N}, mat1.options().dtype(out_dtype)); | ||||
|  | ||||
|   const bool has_bias = bias.has_value(); | ||||
|   const float* bias_data = nullptr; | ||||
|   if (has_bias) { | ||||
|     CHECK_EQ(bias.value().size(0), N); | ||||
|     bias_data = bias.value().data_ptr<float>(); | ||||
|   } | ||||
|  | ||||
|   AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_with_quant_kernel_impl", [&] { | ||||
|     uint8_t* __restrict__ Aq_data = buffer.data_ptr<uint8_t>(); | ||||
|     float* __restrict__ As_data = (float*)((void*)(Aq_data + M * K)); | ||||
|     const scalar_t* __restrict__ A_data = mat1.data_ptr<scalar_t>(); | ||||
|  | ||||
|     at::parallel_for(0, M, 0, [&] (int64_t begin, int64_t end) { | ||||
|       for (int64_t m = begin; m < end; ++m) { | ||||
|         quantize_row_int8<scalar_t>( | ||||
|             Aq_data + m * K, | ||||
|             As_data[m], | ||||
|             A_data + m * lda, | ||||
|             K); | ||||
|       } | ||||
|     }); | ||||
|  | ||||
|     int8_scaled_mm_kernel_impl<scalar_t>( | ||||
|         out.data_ptr<scalar_t>(), | ||||
|         Aq_data, | ||||
|         packed_w.data_ptr<int8_t>(), | ||||
|         As_data, | ||||
|         scales2.data_ptr<float>(), | ||||
|         bias_data, | ||||
|         M, | ||||
|         N, | ||||
|         K); | ||||
|   }); | ||||
|   return out; | ||||
| } | ||||
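|  | ||||
| // [Illustrative sketch, not part of the original file] Judging from how its results are | ||||
| // consumed above, `quantize_row_int8` is assumed to perform dynamic per-row quantization: | ||||
| // scale = absmax(row) / 127, symmetric int8 quantization, then a +128 shift into uint8 so | ||||
| // the activations can feed the u8*s8 AVX512-VNNI dot product (the matching 128 * column-sum | ||||
| // compensation is packed after each weight row, see the s8s8 NOTE above). A minimal scalar | ||||
| // reference under these assumptions, kept as a comment so it does not affect the build: | ||||
| // | ||||
| //   template <typename scalar_t> | ||||
| //   void quantize_row_int8_ref(uint8_t* q, float& scale, const scalar_t* x, int64_t K) { | ||||
| //     float absmax = 0.f; | ||||
| //     for (int64_t k = 0; k < K; ++k) absmax = std::max(absmax, std::abs((float)x[k])); | ||||
| //     scale = absmax / 127.f; | ||||
| //     const float inv = (scale == 0.f) ? 0.f : 1.f / scale; | ||||
| //     for (int64_t k = 0; k < K; ++k) { | ||||
| //       int32_t v = (int32_t)std::nearbyint((float)x[k] * inv); | ||||
| //       v = std::min(127, std::max(-127, v)); | ||||
| //       q[k] = (uint8_t)(v + 128); | ||||
| //     } | ||||
| //   } | ||||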
							
								
								
									
csrc/cpu/sgl-kernels/moe.cpp (1330 lines, new file): file diff suppressed because it is too large.

csrc/cpu/sgl-kernels/moe_fp8.cpp (502 lines, new file):
							| @ -0,0 +1,502 @@ | ||||
| // Adapted from | ||||
| // https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu | ||||
|  | ||||
| #include "common.h" | ||||
| #include "gemm.h" | ||||
| #include "vec.h" | ||||
|  | ||||
| // clang-format off | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| template <typename scalar_t> | ||||
| inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) { | ||||
|   using Vec = at::vec::Vectorized<scalar_t>; | ||||
|   // no remainder | ||||
|   #pragma GCC unroll 4 | ||||
|   for (int64_t d = 0; d < size; d += Vec::size()) { | ||||
|     Vec data = Vec::loadu(input + d); | ||||
|     data.store(out + d); | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename scalar_t> | ||||
| inline void copy_mul_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, float weight, int64_t size) { | ||||
|   using bVec = at::vec::Vectorized<scalar_t>; | ||||
|   using fVec = at::vec::Vectorized<float>; | ||||
|   constexpr int kVecSize = bVec::size(); | ||||
|   const fVec weight_vec = fVec(weight); | ||||
|   int64_t d; | ||||
|   #pragma GCC unroll 4 | ||||
|   for (d = 0; d <= size - kVecSize; d += kVecSize) { | ||||
|     bVec x = bVec::loadu(input + d); | ||||
|     fVec x0, x1; | ||||
|     std::tie(x0, x1) = at::vec::convert_to_float(x); | ||||
|     x0 = x0 * weight_vec; | ||||
|     x1 = x1 * weight_vec; | ||||
|     bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1); | ||||
|     out_vec.store(out + d); | ||||
|   } | ||||
|   for (; d < size; ++d) { | ||||
|     out[d] = static_cast<scalar_t>(input[d] * weight); | ||||
|   } | ||||
| } | ||||
|  | ||||
| // acc from [topk, K] to [K] | ||||
| template <typename scalar_t> | ||||
| inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) { | ||||
|   using bVec = at::vec::Vectorized<scalar_t>; | ||||
|   using fVec = at::vec::Vectorized<float>; | ||||
|   constexpr int kVecSize = bVec::size(); | ||||
|   if (topk == 1) { | ||||
|     // do copy for topk = 1 | ||||
|     copy_stub(out, input, K); | ||||
|   } else { | ||||
|     // do sum for topk != 1 | ||||
|     int64_t d; | ||||
|     #pragma GCC unroll 4 | ||||
|     for (d = 0; d <= K - kVecSize; d += kVecSize) { | ||||
|       fVec sum_fvec0 = fVec(0.f); | ||||
|       fVec sum_fvec1 = fVec(0.f); | ||||
|       for (int t = 0; t < topk; ++t) { | ||||
|         bVec x_bvec = bVec::loadu(input + t * K + d); | ||||
|         fVec x_fvec0, x_fvec1; | ||||
|         std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); | ||||
|  | ||||
|         sum_fvec0 += x_fvec0; | ||||
|         sum_fvec1 += x_fvec1; | ||||
|       } | ||||
|       bVec out_bvec = convert_from_float_ext<scalar_t>(sum_fvec0, sum_fvec1); | ||||
|       out_bvec.store(out + d); | ||||
|     } | ||||
|     for (; d < K; ++d) { | ||||
|       float sum_val = 0.f; | ||||
|       for (int t = 0; t < topk; ++t) { | ||||
|         sum_val += static_cast<float>(input[t * K + d]); | ||||
|       } | ||||
|       out[d] = static_cast<scalar_t>(sum_val); | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| // out = input + input2 * scale | ||||
| template <typename scalar_t> | ||||
| inline void add_mul_stub( | ||||
|     scalar_t* __restrict__ out, | ||||
|     const scalar_t* __restrict__ input, | ||||
|     const scalar_t* __restrict__ input2, | ||||
|     float scale, | ||||
|     int64_t size) { | ||||
|   using bVec = at::vec::Vectorized<scalar_t>; | ||||
|   using fVec = at::vec::Vectorized<float>; | ||||
|   constexpr int kVecSize = bVec::size(); | ||||
|   const fVec s_vec = fVec(scale); | ||||
|  | ||||
|   int64_t d; | ||||
| #pragma GCC unroll 4 | ||||
|   for (d = 0; d <= size - kVecSize; d += kVecSize) { | ||||
|     bVec x_bvec = bVec::loadu(input + d); | ||||
|     fVec x0, x1; | ||||
|     std::tie(x0, x1) = at::vec::convert_to_float(x_bvec); | ||||
|  | ||||
|     bVec y_bvec = bVec::loadu(input2 + d); | ||||
|     fVec y0, y1; | ||||
|     std::tie(y0, y1) = at::vec::convert_to_float(y_bvec); | ||||
|  | ||||
|     x0 = x0 + y0 * s_vec; | ||||
|     x1 = x1 + y1 * s_vec; | ||||
|     bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1); | ||||
|     out_vec.store(out + d); | ||||
|   } | ||||
|   for (; d < size; ++d) { | ||||
|     out[d] = static_cast<scalar_t>(input[d] + float(input2[d]) * scale); | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename scalar_t> | ||||
| inline void silu_and_mul_stub( | ||||
|     scalar_t* __restrict__ out, | ||||
|     const scalar_t* __restrict__ input, | ||||
|     const scalar_t* __restrict__ input2, | ||||
|     int64_t size) { | ||||
|   using bVec = at::vec::Vectorized<scalar_t>; | ||||
|   using fVec = at::vec::Vectorized<float>; | ||||
|   const fVec one = fVec(1.f); | ||||
|  | ||||
|   // no remainder | ||||
| #pragma GCC unroll 4 | ||||
|   for (int64_t d = 0; d < size; d += bVec::size()) { | ||||
|     bVec x = bVec::loadu(input + d); | ||||
|     fVec x0, x1; | ||||
|     std::tie(x0, x1) = at::vec::convert_to_float(x); | ||||
|     bVec y = bVec::loadu(input2 + d); | ||||
|     fVec y0, y1; | ||||
|     std::tie(y0, y1) = at::vec::convert_to_float(y); | ||||
|     x0 = x0 / (one + x0.neg().exp_u20()); | ||||
|     x1 = x1 / (one + x1.neg().exp_u20()); | ||||
|     x0 = x0 * y0; | ||||
|     x1 = x1 * y1; | ||||
|     bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1); | ||||
|     out_vec.store(out + d); | ||||
|   } | ||||
| } | ||||
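|  | ||||
| // [Illustrative sketch, not part of the original file] Scalar reference of the stub above: | ||||
| // out[d] = silu(input[d]) * input2[d], with silu(x) = x / (1 + exp(-x)). Assumes <cmath> | ||||
| // is reachable through the ATen headers pulled in by "vec.h". | ||||
| template <typename scalar_t> | ||||
| inline void silu_and_mul_stub_ref( | ||||
|     scalar_t* __restrict__ out, | ||||
|     const scalar_t* __restrict__ input, | ||||
|     const scalar_t* __restrict__ input2, | ||||
|     int64_t size) { | ||||
|   for (int64_t d = 0; d < size; ++d) { | ||||
|     float x = static_cast<float>(input[d]); | ||||
|     float y = static_cast<float>(input2[d]); | ||||
|     out[d] = static_cast<scalar_t>(x / (1.f + std::exp(-x)) * y); | ||||
|   } | ||||
| } | ||||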
|  | ||||
| } // anonymous namespace | ||||
|  | ||||
| template <typename scalar_t> | ||||
| void fused_experts_fp8_kernel_impl( | ||||
|     scalar_t* __restrict__ output, | ||||
|     scalar_t* __restrict__ ic0, | ||||
|     scalar_t* __restrict__ ic1, | ||||
|     scalar_t* __restrict__ ic2, | ||||
|     scalar_t* __restrict__ A_tmp, | ||||
|     scalar_t* __restrict__ B_tmp, | ||||
|     float* __restrict__ C_tmp, | ||||
|     const scalar_t* __restrict__ input, | ||||
|     const at::Float8_e4m3fn* __restrict__ packed_w1, | ||||
|     const at::Float8_e4m3fn* __restrict__ packed_w2, | ||||
|     const float* __restrict__ w1s, | ||||
|     const float* __restrict__ w2s, | ||||
|     int64_t block_size_N, | ||||
|     int64_t block_size_K, | ||||
|     const float* __restrict__ topk_weights, | ||||
|     const int32_t* __restrict__ sorted_ids, | ||||
|     const int32_t* __restrict__ expert_ids, | ||||
|     const int32_t* __restrict__ offsets, | ||||
|     int64_t M, | ||||
|     int64_t N, | ||||
|     int64_t K, | ||||
|     int64_t E, | ||||
|     int64_t topk, | ||||
|     int64_t num_tokens_post_pad) { | ||||
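|   // [Added note, inferred from the indexing below] Scratch buffer shapes: | ||||
|   //   ic0 : [M * topk, 2 * N]  gemm1 output, sorted (expert-grouped) order | ||||
|   //   ic1 : [M * topk, N]      silu_and_mul output, sorted order | ||||
|   //   ic2 : [M, topk, K]       gemm2 output, original token order | ||||
|   //   A_tmp / B_tmp / C_tmp    per-thread gemm scratch | ||||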
|  | ||||
|   constexpr int64_t BLOCK_M = block_size_m(); | ||||
|   constexpr int64_t BLOCK_N = block_size_n(); | ||||
|  | ||||
|   // stage 1: intermediate_cache0 = hidden_states @ w1 | ||||
|   const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M); | ||||
|   const int64_t NB = div_up(2 * N, BLOCK_N); | ||||
|   int64_t scale_size_N = div_up(2 * N, block_size_N); | ||||
|   int64_t scale_size_K = div_up(K, block_size_K); | ||||
|   int64_t blocks_n_per_group = block_size_N / BLOCK_N; | ||||
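|   // [Added note] Weight scales are per [block_size_N, block_size_K] block, so adjacent | ||||
|   // BLOCK_N output tiles share one scale row: tile nb reads scale row (nb / blocks_n_per_group), | ||||
|   // and each scale row holds scale_size_K per-K-block scales. For example, with hypothetical | ||||
|   // values block_size_N = 128 and BLOCK_N = 32, tiles 0..3 read scale row 0, tiles 4..7 read | ||||
|   // scale row 1, and so on. | ||||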
|  | ||||
|   const int64_t stride_e = 2 * N * K; | ||||
|   const int64_t stride_n = K; | ||||
|  | ||||
|   // parallel on [MB, NB] over the full 2N; silu_and_mul is applied separately in stage 1.5 | ||||
|   at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { | ||||
|     // get local pointers | ||||
|     int tid = at::get_thread_num(); | ||||
|     scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; | ||||
|  | ||||
|     bool is_brgemm_used = false; | ||||
|  | ||||
|     for (int64_t i = begin; i < end; ++i) { | ||||
|       int64_t mb = i / NB; | ||||
|       int64_t nb = i % NB; | ||||
|  | ||||
|       int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N); | ||||
|  | ||||
|       // B shape [K, n_size] in vnni format | ||||
|       int32_t expert_id = expert_ids[mb]; | ||||
|       const at::Float8_e4m3fn* __restrict__ B = packed_w1 + expert_id * stride_e + nb * BLOCK_N * stride_n; | ||||
|       const float* __restrict__ Bs = w1s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K; | ||||
|  | ||||
|       // 1.a load A | ||||
|       const int32_t* A_ids = sorted_ids + mb * BLOCK_M; | ||||
|       int64_t m_size = offsets[mb + 1] - offsets[mb]; | ||||
|  | ||||
|       const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size); | ||||
|       is_brgemm_used = is_brgemm_used || use_brgemm; | ||||
|  | ||||
|       for (int64_t m = 0; m < m_size; ++m) { | ||||
|         int32_t index = A_ids[m] / topk; | ||||
|         copy_stub(A + m * K, input + index * K, K); | ||||
|       } | ||||
|  | ||||
|       const int64_t offset = offsets[mb]; | ||||
|       tinygemm_kernel<scalar_t>( | ||||
|           /*   A            */ A, | ||||
|           /*   B            */ B, | ||||
|           /*   C            */ ic0 + offset * 2 * N + nb * BLOCK_N, | ||||
|           /*   Btmp         */ B_tmp + tid * BLOCK_N * std::max(K, N), | ||||
|           /*   Ctmp         */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, | ||||
|           /*   scale        */ Bs, | ||||
|           /*   M            */ m_size, | ||||
|           /*   N            */ n_size, | ||||
|           /*   K            */ K, | ||||
|           /*   lda          */ K, | ||||
|           /*   ldb          */ n_size, | ||||
|           /*   ldc          */ 2 * N, | ||||
|           /*   brg          */ use_brgemm, | ||||
|           /*   block_size_K */ block_size_K); | ||||
|     } | ||||
|  | ||||
|     if (is_brgemm_used) { | ||||
|       at::native::cpublas::brgemm_release(); | ||||
|     } | ||||
|   }); | ||||
|  | ||||
|   // stage 1.5: intermediate_cache1 = silu(intermediate_cache0[:, 0:N]) * intermediate_cache0[:, N:2N] | ||||
|   at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) { | ||||
|     for (int64_t m = begin; m < end; ++m) { | ||||
|       silu_and_mul_stub( | ||||
|           ic1 + m * N, | ||||
|           ic0 + m * 2 * N, | ||||
|           ic0 + m * 2 * N + N, | ||||
|           N); | ||||
|     } | ||||
|   }); | ||||
|  | ||||
|   // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 | ||||
|   //   w2 : [E, K, N] as [E, OC, IC] | ||||
|   const int64_t OC = K;  // rename K as OC | ||||
|   const int64_t IC = N;  // rename N as IC | ||||
|   const int64_t MB2 = MB; | ||||
|   const int64_t NB2 = div_up(OC, BLOCK_N); | ||||
|   scale_size_N = div_up(K, block_size_N); | ||||
|   scale_size_K = div_up(N, block_size_K); | ||||
|   const int64_t stride_e2 = OC * IC; | ||||
|   const int64_t stride_oc = IC; | ||||
|  | ||||
|   // parallel on [MB2, NB2] | ||||
|   at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { | ||||
|     int tid = at::get_thread_num(); | ||||
|     alignas(64) scalar_t C[BLOCK_M * BLOCK_K]; | ||||
|  | ||||
|     bool is_brgemm_used = false; | ||||
|  | ||||
|     for (int64_t i = begin; i < end; ++i) { | ||||
|       int64_t mb = i / NB2; | ||||
|       int64_t nb = i % NB2; | ||||
|  | ||||
|       int64_t m_size = offsets[mb + 1] - offsets[mb]; | ||||
|       int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); | ||||
|  | ||||
|       const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size); | ||||
|       is_brgemm_used = is_brgemm_used || use_brgemm; | ||||
|  | ||||
|       // A ptr from ic1 of [M * topk, N] in sorted order | ||||
|       // so as to avoid copy A to tmp buffer again | ||||
|       const scalar_t* __restrict__ A = ic1 + offsets[mb] * N; | ||||
|       const int32_t* A_ids = sorted_ids + mb * BLOCK_M; | ||||
|  | ||||
|       // B shape [IC, n_size] in vnni format | ||||
|       int32_t expert_id = expert_ids[mb]; | ||||
|       const at::Float8_e4m3fn* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc; | ||||
|       const float* __restrict__ Bs = w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K; | ||||
|  | ||||
|       tinygemm_kernel<scalar_t>( | ||||
|           /*   A            */ A, | ||||
|           /*   B            */ B, | ||||
|           /*   C            */ C, | ||||
|           /*   Btmp         */ B_tmp + tid * BLOCK_N * std::max(K, N), | ||||
|           /*   Ctmp         */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, | ||||
|           /*   scale        */ Bs, | ||||
|           /*   M            */ m_size, | ||||
|           /*   N            */ n_size, | ||||
|           /*   K            */ IC, | ||||
|           /*   lda          */ IC, | ||||
|           /*   ldb          */ n_size, | ||||
|           /*   ldc          */ BLOCK_N, | ||||
|           /*   brg          */ use_brgemm, | ||||
|           /*   block_size_K */ block_size_K); | ||||
|  | ||||
|       // 2.b copy from C to ic2 in original order | ||||
|       //   and also mul topk_weights in float32 | ||||
|       for (int64_t m = 0; m < m_size; ++m) { | ||||
|         int32_t index = A_ids[m]; | ||||
|         float weight = topk_weights[index]; | ||||
|         copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     if (is_brgemm_used) { | ||||
|       at::native::cpublas::brgemm_release(); | ||||
|     } | ||||
|   }); | ||||
|  | ||||
|   // stage 3: out = intermediate_cache2.sum(dim=1) | ||||
|   //   from [M, topk, K] to [M, K] | ||||
|   at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { | ||||
|     for (int64_t m = begin; m < end; ++m) { | ||||
|       sum_stub(output + m * K, ic2 + m * topk * K, topk, K); | ||||
|     } | ||||
|   }); | ||||
| } | ||||
|  | ||||
| #define INSTANTIATE_MOE_FP8_TEMPLATE(TYPE)             \ | ||||
|   template void fused_experts_fp8_kernel_impl<TYPE>(   \ | ||||
|       TYPE* __restrict__ output,                       \ | ||||
|       TYPE* __restrict__ ic0,                          \ | ||||
|       TYPE* __restrict__ ic1,                          \ | ||||
|       TYPE* __restrict__ ic2,                          \ | ||||
|       TYPE* __restrict__ A_tmp,                        \ | ||||
|       TYPE* __restrict__ B_tmp,                        \ | ||||
|       float* __restrict__ C_tmp,                       \ | ||||
|       const TYPE* __restrict__ input,                  \ | ||||
|       const at::Float8_e4m3fn* __restrict__ packed_w1, \ | ||||
|       const at::Float8_e4m3fn* __restrict__ packed_w2, \ | ||||
|       const float* __restrict__ w1s,                   \ | ||||
|       const float* __restrict__ w2s,                   \ | ||||
|       int64_t block_size_N,                            \ | ||||
|       int64_t block_size_K,                            \ | ||||
|       const float* __restrict__ topk_weights,          \ | ||||
|       const int32_t* __restrict__ sorted_ids,          \ | ||||
|       const int32_t* __restrict__ expert_ids,          \ | ||||
|       const int32_t* __restrict__ offsets,             \ | ||||
|       int64_t M,                                       \ | ||||
|       int64_t N,                                       \ | ||||
|       int64_t K,                                       \ | ||||
|       int64_t E,                                       \ | ||||
|       int64_t topk,                                    \ | ||||
|       int64_t num_tokens_post_pad) | ||||
|  | ||||
| INSTANTIATE_MOE_FP8_TEMPLATE(at::BFloat16); | ||||
| INSTANTIATE_MOE_FP8_TEMPLATE(at::Half); | ||||
|  | ||||
| template <typename scalar_t> | ||||
| void shared_expert_fp8_kernel_impl( | ||||
|     scalar_t* __restrict__ output, | ||||
|     scalar_t* __restrict__ ic0, | ||||
|     scalar_t* __restrict__ ic1, | ||||
|     scalar_t* __restrict__ B_tmp, | ||||
|     float* __restrict__ C_tmp, | ||||
|     const scalar_t* __restrict__ input, | ||||
|     const at::Float8_e4m3fn* __restrict__ packed_w1, | ||||
|     const at::Float8_e4m3fn* __restrict__ packed_w2, | ||||
|     const float* __restrict__ w1s, | ||||
|     const float* __restrict__ w2s, | ||||
|     int64_t block_size_N, | ||||
|     int64_t block_size_K, | ||||
|     const scalar_t* __restrict__ fused_experts_out, | ||||
|     float routed_scaling_factor, | ||||
|     int64_t M, | ||||
|     int64_t N, | ||||
|     int64_t K) { | ||||
|  | ||||
|   constexpr int64_t BLOCK_M = block_size_m(); | ||||
|   constexpr int64_t BLOCK_N = block_size_n(); | ||||
|  | ||||
|   // stage 1: intermediate_cache0 = hidden_states @ w1 | ||||
|   const int64_t MB = div_up(M, BLOCK_M); | ||||
|   const int64_t NB = div_up(2 * N, BLOCK_N); | ||||
|   int64_t scale_size_K = div_up(K, block_size_K); | ||||
|   int64_t blocks_n_per_group = block_size_N / BLOCK_N; | ||||
|  | ||||
|   const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M); | ||||
|  | ||||
|   at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { | ||||
|     int tid = at::get_thread_num(); | ||||
|  | ||||
|     for (int64_t i = begin; i < end; ++i) { | ||||
|       int64_t mb = i / NB; | ||||
|       int64_t nb = i % NB; | ||||
|       int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); | ||||
|       int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N); | ||||
|  | ||||
|       tinygemm_kernel<scalar_t>( | ||||
|           /*   A            */ input + mb * BLOCK_M * K, | ||||
|           /*   B            */ packed_w1 + nb * BLOCK_N * K, | ||||
|           /*   C            */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N, | ||||
|           /*   Btmp         */ B_tmp + tid * BLOCK_N * std::max(K, N), | ||||
|           /*   Ctmp         */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, | ||||
|           /*   scale        */ w1s + (nb / blocks_n_per_group) * scale_size_K, | ||||
|           /*   M            */ m_size, | ||||
|           /*   N            */ n_size, | ||||
|           /*   K            */ K, | ||||
|           /*   lda          */ K, | ||||
|           /*   ldb          */ n_size, | ||||
|           /*   ldc          */ 2 * N, | ||||
|           /*   brg          */ use_brgemm, | ||||
|           /*   block_size_K */ block_size_K); | ||||
|     } | ||||
|  | ||||
|     if (use_brgemm) { | ||||
|       at::native::cpublas::brgemm_release(); | ||||
|     } | ||||
|   }); | ||||
|  | ||||
|   // stage 1.5: intermediate_cache1 = silu(intermediate_cache0[:, 0:N]) * intermediate_cache0[:, N:2N] | ||||
|   at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { | ||||
|     for (int64_t m = begin; m < end; ++m) { | ||||
|       silu_and_mul_stub( | ||||
|           ic1 + m * N, | ||||
|           ic0 + m * 2 * N, | ||||
|           ic0 + m * 2 * N + N, | ||||
|           N); | ||||
|     } | ||||
|   }); | ||||
|  | ||||
|   // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 | ||||
|   //   w2 : [K, N] as [OC, IC] | ||||
|   const int64_t OC = K;  // rename K as OC | ||||
|   const int64_t IC = N;  // rename N as IC | ||||
|   const int64_t MB2 = MB; | ||||
|   const int64_t NB2 = div_up(K, BLOCK_N); | ||||
|   scale_size_K = div_up(N, block_size_K); | ||||
|  | ||||
|   // parallel on [MB2, NB2] | ||||
|   at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { | ||||
|     int tid = at::get_thread_num(); | ||||
|     alignas(64) scalar_t C[BLOCK_M * BLOCK_K]; | ||||
|  | ||||
|     for (int64_t i = begin; i < end; ++i) { | ||||
|       int64_t mb = i / NB2; | ||||
|       int64_t nb = i % NB2; | ||||
|       int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); | ||||
|       int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); | ||||
|  | ||||
|       // 2.a gemm: C = A @ B | ||||
|       tinygemm_kernel<scalar_t>( | ||||
|           /*   A            */ ic1 + mb * BLOCK_M * N, | ||||
|           /*   B            */ packed_w2 + nb * BLOCK_N * N, | ||||
|           /*   C            */ C, | ||||
|           /*   Btmp         */ B_tmp + tid * BLOCK_N * std::max(K, N), | ||||
|           /*   Ctmp         */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, | ||||
|           /*   scale        */ w2s + (nb / blocks_n_per_group) * scale_size_K, | ||||
|           /*   M            */ m_size, | ||||
|           /*   N            */ n_size, | ||||
|           /*   K            */ IC, | ||||
|           /*   lda          */ IC, | ||||
|           /*   ldb          */ n_size, | ||||
|           /*   ldc          */ BLOCK_N, | ||||
|           /*   brg          */ use_brgemm, | ||||
|           /*   block_size_K */ block_size_K); | ||||
|  | ||||
|       // 2.b copy from C to output and add fused_experts_out | ||||
|       scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; | ||||
|       const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N; | ||||
|       for (int64_t m = 0; m < m_size; ++m) { | ||||
|         add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); | ||||
|       } | ||||
|     } | ||||
|   }); | ||||
|  | ||||
|   if (use_brgemm) { | ||||
|     at::native::cpublas::brgemm_release(); | ||||
|   } | ||||
| } | ||||
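|  | ||||
| // [Added note] The add_mul_stub call in stage 2 combines the two paths as | ||||
| //   output = shared_expert(hidden_states) + routed_scaling_factor * fused_experts_out, | ||||
| // i.e. the routed-expert result from fused_experts_fp8_kernel_impl is scaled and added | ||||
| // on top of the shared expert's down projection. | ||||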
|  | ||||
| #define INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(TYPE)   \ | ||||
|   template void shared_expert_fp8_kernel_impl<TYPE>(   \ | ||||
|       TYPE* __restrict__ output,                       \ | ||||
|       TYPE* __restrict__ ic0,                          \ | ||||
|       TYPE* __restrict__ ic1,                          \ | ||||
|       TYPE* __restrict__ B_tmp,                        \ | ||||
|       float* __restrict__ C_tmp,                       \ | ||||
|       const TYPE* __restrict__ input,                  \ | ||||
|       const at::Float8_e4m3fn* __restrict__ packed_w1, \ | ||||
|       const at::Float8_e4m3fn* __restrict__ packed_w2, \ | ||||
|       const float* __restrict__ w1s,                   \ | ||||
|       const float* __restrict__ w2s,                   \ | ||||
|       int64_t block_size_N,                            \ | ||||
|       int64_t block_size_K,                            \ | ||||
|       const TYPE* __restrict__ fused_experts_out,      \ | ||||
|       float routed_scaling_factor,                     \ | ||||
|       int64_t M,                                       \ | ||||
|       int64_t N,                                       \ | ||||
|       int64_t K) | ||||
|  | ||||
| INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::BFloat16); | ||||
| INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::Half); | ||||
							
								
								
									
csrc/cpu/sgl-kernels/moe_int8.cpp (769 lines, new file):
							| @ -0,0 +1,769 @@ | ||||
| // Adapted from | ||||
| // https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu | ||||
|  | ||||
| #include "common.h" | ||||
| #include "vec.h" | ||||
| #include "gemm.h" | ||||
|  | ||||
| // clang-format off | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| template <typename scalar_t> | ||||
| inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) { | ||||
|   using Vec = at::vec::Vectorized<scalar_t>; | ||||
|   // no remainder | ||||
|   #pragma GCC unroll 4 | ||||
|   for (int64_t d = 0; d < size; d += Vec::size()) { | ||||
|     Vec data = Vec::loadu(input + d); | ||||
|     data.store(out + d); | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <> | ||||
| inline void copy_stub<uint8_t>(uint8_t* __restrict__ out, const uint8_t* __restrict__ input, int64_t size) { | ||||
|   // size might be 64x + 32 | ||||
|   std::memcpy(out, input, size * sizeof(uint8_t)); | ||||
| } | ||||
|  | ||||
| template <typename scalar_t> | ||||
| inline void copy_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, float weight, int64_t size) { | ||||
|   using bVec = at::vec::Vectorized<scalar_t>; | ||||
|   using fVec = at::vec::Vectorized<float>; | ||||
|   constexpr int kVecSize = bVec::size(); | ||||
|   const fVec weight_vec = fVec(weight); | ||||
|   int64_t d; | ||||
|   #pragma GCC unroll 4 | ||||
|   for (d = 0; d <= size - kVecSize; d += kVecSize) { | ||||
|     fVec data0 = fVec::loadu(input + d) * weight_vec; | ||||
|     fVec data1 = fVec::loadu(input + d + fVec::size()) * weight_vec; | ||||
|     bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1); | ||||
|     out_vec.store(out + d); | ||||
|   } | ||||
|   for (; d < size; ++d) { | ||||
|     out[d] = static_cast<scalar_t>(input[d] * weight); | ||||
|   } | ||||
| } | ||||
|  | ||||
| // acc from [topk, K] to [K] | ||||
| template <typename scalar_t> | ||||
| inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) { | ||||
|   using bVec = at::vec::Vectorized<scalar_t>; | ||||
|   using fVec = at::vec::Vectorized<float>; | ||||
|   constexpr int kVecSize = bVec::size(); | ||||
|   if (topk == 1) { | ||||
|     // do copy for topk = 1 | ||||
|     copy_stub(out, input, K); | ||||
|   } else { | ||||
|     // do sum for topk != 1 | ||||
|     int64_t d; | ||||
|     #pragma GCC unroll 4 | ||||
|     for (d = 0; d <= K - kVecSize; d += kVecSize) { | ||||
|       fVec sum_fvec0 = fVec(0.f); | ||||
|       fVec sum_fvec1 = fVec(0.f); | ||||
|       for (int t = 0; t < topk; ++t) { | ||||
|         bVec x_bvec = bVec::loadu(input + t * K + d); | ||||
|         fVec x_fvec0, x_fvec1; | ||||
|         std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); | ||||
|  | ||||
|         sum_fvec0 += x_fvec0; | ||||
|         sum_fvec1 += x_fvec1; | ||||
|       } | ||||
|       bVec out_bvec = convert_from_float_ext<scalar_t>(sum_fvec0, sum_fvec1); | ||||
|       out_bvec.store(out + d); | ||||
|     } | ||||
|     for (; d < K; ++d) { | ||||
|       float sum_val = 0.f; | ||||
|       for (int t = 0; t < topk; ++t) { | ||||
|         sum_val += static_cast<float>(input[t * K + d]); | ||||
|       } | ||||
|       out[d] = static_cast<scalar_t>(sum_val); | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| // out = input + input2 * scale | ||||
| template <typename scalar_t> | ||||
| inline void add_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, | ||||
|     const scalar_t* __restrict__ input2, float scale, int64_t size) { | ||||
|  | ||||
|   using bVec = at::vec::Vectorized<scalar_t>; | ||||
|   using fVec = at::vec::Vectorized<float>; | ||||
|   constexpr int kVecSize = bVec::size(); | ||||
|   const fVec s_vec = fVec(scale); | ||||
|   int64_t d; | ||||
|   #pragma GCC unroll 4 | ||||
|   for (d = 0; d <= size - kVecSize; d += kVecSize) { | ||||
|     fVec x0 = fVec::loadu(input + d); | ||||
|     fVec x1 = fVec::loadu(input + d + fVec::size()); | ||||
|  | ||||
|     bVec y_bvec = bVec::loadu(input2 + d); | ||||
|     fVec y0, y1; | ||||
|     std::tie(y0, y1) = at::vec::convert_to_float(y_bvec); | ||||
|  | ||||
|     x0 = x0 + y0 * s_vec; | ||||
|     x1 = x1 + y1 * s_vec; | ||||
|     bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1); | ||||
|     out_vec.store(out + d); | ||||
|   } | ||||
|   for (; d < size; ++d) { | ||||
|     out[d] = static_cast<scalar_t>(input[d] + float(input2[d]) * scale); | ||||
|   } | ||||
| } | ||||
|  | ||||
| /// gemm for w13 | ||||
| template <typename scalar_t, int BLOCK_M, int BLOCK_N> | ||||
| struct tinygemm_kernel_vnni { | ||||
|   static inline void apply( | ||||
|       const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, scalar_t* __restrict__ C, | ||||
|       const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1, | ||||
|       const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1, | ||||
|       int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { | ||||
|     TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| #if defined(CPU_CAPABILITY_AVX512) | ||||
| template <int BLOCK_M, int BLOCK_N> | ||||
| struct tinygemm_kernel_vnni<at::BFloat16, BLOCK_M, BLOCK_N> { | ||||
|   static inline void apply( | ||||
|       const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, at::BFloat16* __restrict__ C, | ||||
|       const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1, | ||||
|       const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1, | ||||
|       int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { | ||||
|  | ||||
|     constexpr int ROWS = BLOCK_M; | ||||
|     constexpr int COLS = BLOCK_N / 16; | ||||
|     static_assert(COLS % 2 == 0); | ||||
|  | ||||
|     __m512i va; | ||||
|     __m512i vb0[COLS]; | ||||
|     __m512i vb1[COLS]; | ||||
|     __m512i vc0[ROWS * COLS]; | ||||
|     __m512i vc1[ROWS * COLS]; | ||||
|     __m512i vcomp0[COLS]; | ||||
|     __m512i vcomp1[COLS]; | ||||
|     __m512  was; | ||||
|     __m512  vbs0[COLS]; | ||||
|     __m512  vbs1[COLS]; | ||||
|  | ||||
|     auto loadc = [&](auto i) { | ||||
|       vc0[i] = _mm512_set1_epi32(0); | ||||
|       vc1[i] = _mm512_set1_epi32(0); | ||||
|     }; | ||||
|     Unroll<ROWS * COLS>{}(loadc); | ||||
|  | ||||
|     const int64_t K4 = K >> 2; | ||||
|     const int64_t lda4 = lda >> 2; | ||||
|     const int64_t ldb4 = ldb; // ldb * 4 >> 2; | ||||
|     const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A); | ||||
|     const int32_t* b0_ptr = reinterpret_cast<const int32_t*>(B0); | ||||
|     const int32_t* b1_ptr = reinterpret_cast<const int32_t*>(B1); | ||||
|  | ||||
|     auto compute = [&](auto i, int64_t k) { | ||||
|       constexpr int row = i / COLS; | ||||
|       constexpr int col = i % COLS; | ||||
|  | ||||
|       if constexpr (col == 0) { | ||||
|         va = _mm512_set1_epi32(a_ptr[row * lda4 + k]); | ||||
|       } | ||||
|       if constexpr (row == 0) { | ||||
|         vb0[col] = _mm512_loadu_si512(b0_ptr + k * ldb4 + col * 16); | ||||
|         vb1[col] = _mm512_loadu_si512(b1_ptr + k * ldb4 + col * 16); | ||||
|       } | ||||
|       vc0[i] = _mm512_dpbusd_epi32(vc0[i], va, vb0[col]); | ||||
|       vc1[i] = _mm512_dpbusd_epi32(vc1[i], va, vb1[col]); | ||||
|     }; | ||||
|     for (int64_t k = 0; k < K4; ++k) { | ||||
|       Unroll<ROWS * COLS>{}(compute, k); | ||||
|     } | ||||
|  | ||||
|     auto scalec = [&](auto i) { | ||||
|       constexpr int row = i / COLS; | ||||
|       constexpr int col = i % COLS; | ||||
|  | ||||
|       // load a scale | ||||
|       if constexpr(col == 0) { | ||||
|         was = _mm512_set1_ps(As[row]); | ||||
|       } | ||||
|       // load b scale and vcomp | ||||
|       if constexpr (row == 0) { | ||||
|         vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16); | ||||
|         vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16); | ||||
|         vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16); | ||||
|         vcomp1[col] = _mm512_loadu_si512(Bcomp1 + col * 16); | ||||
|       } | ||||
|       __m512 c0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc0[i], vcomp0[col])); | ||||
|       __m512 c1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc1[i], vcomp1[col])); | ||||
|       vc0[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c0, was), vbs0[col])); | ||||
|       vc1[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c1, was), vbs1[col])); | ||||
|     }; | ||||
|     Unroll<ROWS * COLS>{}(scalec); | ||||
|  | ||||
|     using Vec = at::vec::Vectorized<float>; | ||||
|     const Vec one = Vec(1.f); | ||||
|     auto storec = [&](auto i) { | ||||
|       constexpr int row = i / COLS; | ||||
|       constexpr int col = i % COLS; | ||||
|       // for COLS = 2, 4 use 512bit store | ||||
|       if constexpr (col % 2 == 0) { | ||||
|         Vec x0 = _mm512_castsi512_ps(vc0[row * COLS + col + 0]); | ||||
|         Vec x1 = _mm512_castsi512_ps(vc0[row * COLS + col + 1]); | ||||
|         Vec y0 = _mm512_castsi512_ps(vc1[row * COLS + col + 0]); | ||||
|         Vec y1 = _mm512_castsi512_ps(vc1[row * COLS + col + 1]); | ||||
|         // silu | ||||
|         x0 = x0 / (one + x0.neg().exp_u20()); | ||||
|         x1 = x1 / (one + x1.neg().exp_u20()); | ||||
|         // mul | ||||
|         x0 = x0 * y0; | ||||
|         x1 = x1 * y1; | ||||
|  | ||||
|         _mm512_storeu_si512( | ||||
|             reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), | ||||
|             (__m512i)(_mm512_cvtne2ps_pbh(__m512(x1), __m512(x0)))); | ||||
|         } | ||||
|     }; | ||||
|     Unroll<ROWS * COLS>{}(storec); | ||||
|   } | ||||
| }; | ||||
| #endif | ||||
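|  | ||||
| // [Added note, derived from the kernel above] In scalar form, one fused w13 output | ||||
| // element is computed as: | ||||
| //   acc0 = sum_k A_u8[m][k] * B0_s8[k][n]          (what _mm512_dpbusd_epi32 accumulates) | ||||
| //   c0   = As[m] * Bs0[n] * float(acc0 - Bcomp0[n]) | ||||
| //   c1   = As[m] * Bs1[n] * float(acc1 - Bcomp1[n]) | ||||
| //   C[m][n] = bf16( c0 / (1 + exp(-c0)) * c1 )      (silu(gate) * up) | ||||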
|  | ||||
| #define LAUNCH_TINYGEMM_KERNEL_VNNI(MB_SIZE, NB_SIZE)                        \ | ||||
|     tinygemm_kernel_vnni<scalar_t, MB_SIZE, NB_SIZE>::apply(                 \ | ||||
|         A + mb_start * lda, B0 + nb_start * 4, B1 + nb_start * 4,            \ | ||||
|         C + mb_start * ldc + nb_start, As + mb_start,                        \ | ||||
|         Bs0 + nb_start, Bs1 + nb_start, Bcomp0 + nb_start, Bcomp1 + nb_start,\ | ||||
|         K, lda, ldb, ldc); | ||||
|  | ||||
| template <typename scalar_t> | ||||
| void tinygemm_kernel( | ||||
|     const uint8_t* __restrict__ A, | ||||
|     const int8_t* __restrict__ B0, | ||||
|     const int8_t* __restrict__ B1, | ||||
|     scalar_t* __restrict__ C, | ||||
|     const float* __restrict__ As, | ||||
|     const float* __restrict__ Bs0, | ||||
|     const float* __restrict__ Bs1, | ||||
|     int64_t M, | ||||
|     int64_t N, | ||||
|     int64_t K, | ||||
|     int64_t lda, | ||||
|     int64_t ldb, | ||||
|     int64_t ldc) { | ||||
|  | ||||
|   const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K); | ||||
|   const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K); | ||||
|  | ||||
|   // pattern: 1-(2+2)-(8+8) | ||||
|   constexpr int64_t BLOCK_M = 4; | ||||
|   constexpr int64_t BLOCK_N = 32; | ||||
|   const int64_t MB = div_up(M, BLOCK_M); | ||||
|   const int64_t NB = div_up(N, BLOCK_N); | ||||
|   for (int mb = 0; mb < MB; ++mb) { | ||||
|     int64_t mb_start = mb * BLOCK_M; | ||||
|     int64_t mb_size = std::min(BLOCK_M, M - mb_start); | ||||
|     for (int64_t nb = 0; nb < NB; ++nb) { | ||||
|       int64_t nb_start = nb * BLOCK_N; | ||||
|       int64_t nb_size = std::min(BLOCK_N, N - nb_start); | ||||
|  | ||||
|       switch(mb_size << 4 | nb_size >> 4) { | ||||
|         case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI(1, 32); break; | ||||
|         case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI(2, 32); break; | ||||
|         case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI(3, 32); break; | ||||
|         case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI(4, 32); break; | ||||
|         default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| } | ||||
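|  | ||||
| // [Added note] The dispatch key above packs both tile sizes into a single switch value, | ||||
| // (mb_size << 4 | nb_size >> 4); e.g. mb_size = 3 rows with nb_size = 32 columns gives | ||||
| // 0x32, which selects tinygemm_kernel_vnni<scalar_t, 3, 32>. | ||||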
|  | ||||
| /// gemm for w2 | ||||
| template <typename scalar_t, int BLOCK_M, int BLOCK_N> | ||||
| struct tinygemm_kernel_vnni2 { | ||||
|   static inline void apply( | ||||
|       const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C, | ||||
|       const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, | ||||
|       int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { | ||||
|     TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| #if defined(CPU_CAPABILITY_AVX512) | ||||
| template <int BLOCK_M, int BLOCK_N> | ||||
| struct tinygemm_kernel_vnni2<at::BFloat16, BLOCK_M, BLOCK_N> { | ||||
|   static inline void apply( | ||||
|       const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C, | ||||
|       const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, | ||||
|       int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { | ||||
|  | ||||
|     constexpr int ROWS = BLOCK_M; | ||||
|     constexpr int COLS = BLOCK_N / 16; | ||||
|     static_assert(COLS % 2 == 0); | ||||
|  | ||||
|     __m512i va; | ||||
|     __m512i vb[COLS]; | ||||
|     __m512i vc[ROWS * COLS]; | ||||
|     __m512i vcomp[COLS]; | ||||
|     __m512  was; | ||||
|     __m512  vbs[COLS]; | ||||
|  | ||||
|     auto loadc = [&](auto i) { | ||||
|       vc[i] = _mm512_set1_epi32(0); | ||||
|     }; | ||||
|     Unroll<ROWS * COLS>{}(loadc); | ||||
|  | ||||
|     const int64_t K4 = K >> 2; | ||||
|     const int64_t lda4 = lda >> 2; | ||||
|     const int64_t ldb4 = ldb; // ldb * 4 >> 2; | ||||
|     const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A); | ||||
|     const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B); | ||||
|  | ||||
|     auto compute = [&](auto i, int64_t k) { | ||||
|       constexpr int row = i / COLS; | ||||
|       constexpr int col = i % COLS; | ||||
|  | ||||
|       if constexpr (col == 0) { | ||||
|         va = _mm512_set1_epi32(a_ptr[row * lda4 + k]); | ||||
|       } | ||||
|       if constexpr (row == 0) { | ||||
|         vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16); | ||||
|       } | ||||
|       vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]); | ||||
|     }; | ||||
|     for (int64_t k = 0; k < K4; ++k) { | ||||
|       Unroll<ROWS * COLS>{}(compute, k); | ||||
|     } | ||||
|  | ||||
|     auto storec = [&](auto i) { | ||||
|       constexpr int row = i / COLS; | ||||
|       constexpr int col = i % COLS; | ||||
|  | ||||
|       // load a scale | ||||
|       if constexpr(col == 0) { | ||||
|         was = _mm512_set1_ps(As[row]); | ||||
|       } | ||||
|       // load b scale and vcomp per 2 vectors | ||||
|       // also load bias if any | ||||
|       if constexpr (row == 0) { | ||||
|         if constexpr (col % 2 == 0) { | ||||
|           vbs[col + 0] = _mm512_loadu_ps(Bs + col * 16); | ||||
|           vbs[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16); | ||||
|           vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16); | ||||
|           vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16); | ||||
|         } | ||||
|       } | ||||
|       __m512 x = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[i], vcomp[col])); | ||||
|       x = _mm512_mul_ps(_mm512_mul_ps(x, was), vbs[col]); | ||||
|       _mm512_storeu_ps(reinterpret_cast<__m512*>(C + row * ldc + col * 16), x); | ||||
|     }; | ||||
|     Unroll<ROWS * COLS>{}(storec); | ||||
|   } | ||||
| }; | ||||
| #endif | ||||
|  | ||||
| #define LAUNCH_TINYGEMM_KERNEL_VNNI2(MB_SIZE, NB_SIZE)                       \ | ||||
|     tinygemm_kernel_vnni2<scalar_t, MB_SIZE, NB_SIZE>::apply(                \ | ||||
|         A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \ | ||||
|         As + mb_start, Bs + nb_start, Bcomp + nb_start,                      \ | ||||
|         K, lda, ldb, ldc); | ||||
|  | ||||
| template <typename scalar_t> | ||||
| void tinygemm_kernel( | ||||
|     const uint8_t* __restrict__ A, | ||||
|     const int8_t* __restrict__ B, | ||||
|     float* __restrict__ C, | ||||
|     const float* __restrict__ As, | ||||
|     const float* __restrict__ Bs, | ||||
|     int64_t M, | ||||
|     int64_t N, | ||||
|     int64_t K, | ||||
|     int64_t lda, | ||||
|     int64_t ldb, | ||||
|     int64_t ldc) { | ||||
|  | ||||
|   // B compensation | ||||
|   const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K); | ||||
|  | ||||
|   // pattern: 1-4-16 | ||||
|   constexpr int64_t BLOCK_M = 4; | ||||
|   constexpr int64_t BLOCK_N = 64; | ||||
|   const int64_t MB = div_up(M, BLOCK_M); | ||||
|   const int64_t NB = div_up(N, BLOCK_N); | ||||
|   for (int64_t mb = 0; mb < MB; ++mb) { | ||||
|     int64_t mb_start = mb * BLOCK_M; | ||||
|     int64_t mb_size = std::min(BLOCK_M, M - mb_start); | ||||
|     for (int64_t nb = 0; nb < NB; ++nb) { | ||||
|       int64_t nb_start = nb * BLOCK_N; | ||||
|       int64_t nb_size = std::min(BLOCK_N, N - nb_start); | ||||
|  | ||||
|       switch(mb_size << 4 | nb_size >> 4) { | ||||
|         case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI2(1, 32); break; | ||||
|         case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI2(2, 32); break; | ||||
|         case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI2(3, 32); break; | ||||
|         case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI2(4, 32); break; | ||||
|         default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| } // anonymous namespace | ||||
|  | ||||
| template <typename scalar_t> | ||||
| void fused_experts_int8_kernel_impl( | ||||
|     scalar_t* __restrict__ output, | ||||
|     scalar_t* __restrict__ ic1, | ||||
|     scalar_t* __restrict__ ic2, | ||||
|     uint8_t* __restrict__ A_tmp, | ||||
|     float* __restrict__ C_tmp, | ||||
|     uint8_t* __restrict__ Aq_tmp, | ||||
|     float* __restrict__ As_tmp, | ||||
|     const scalar_t* __restrict__ input, | ||||
|     const int8_t* __restrict__ packed_w1, | ||||
|     const int8_t* __restrict__ packed_w2, | ||||
|     const float* __restrict__ w1s, | ||||
|     const float* __restrict__ w2s, | ||||
|     const float* __restrict__ topk_weights, | ||||
|     const int32_t* __restrict__ sorted_ids, | ||||
|     const int32_t* __restrict__ expert_ids, | ||||
|     const int32_t* __restrict__ offsets, | ||||
|     int64_t M, | ||||
|     int64_t N, | ||||
|     int64_t K, | ||||
|     int64_t E, | ||||
|     int64_t topk, | ||||
|     int64_t num_tokens_post_pad) { | ||||
|  | ||||
|   // handle 2 tiles per block | ||||
|   constexpr int64_t BLOCK_M = block_size_m(); | ||||
|   constexpr int64_t BLOCK_N = block_size_n(); | ||||
|  | ||||
|   // stage 0: quantize input to uint8, [M, K] | ||||
|   at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { | ||||
|     for (int64_t m = begin; m < end; ++m) { | ||||
|       quantize_row_int8<scalar_t>( | ||||
|           Aq_tmp + m * K, | ||||
|           As_tmp[m], | ||||
|           input + m * K, | ||||
|           K); | ||||
|     } | ||||
|   }); | ||||
|  | ||||
|   // stage 1: intermediate_cache1 = silu(hidden_states @ w1) | ||||
|   const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M); | ||||
|   const int64_t NB = div_up(N, BLOCK_N); | ||||
|  | ||||
|   // strides for w1: [E, 2N, K] | ||||
|   TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not a multiple of ", BLOCK_N); | ||||
|  | ||||
|   // K and N are packed for int8 | ||||
|   const int64_t packed_K = get_row_size<int8_t>(K); | ||||
|   const int64_t packed_N = get_row_size<int8_t>(N); | ||||
|  | ||||
|   const int64_t stride_e = 2 * N * packed_K; | ||||
|   const int64_t stride_n = packed_K; | ||||
|   // here we only parallelize over half of 2N so that silu_and_mul can be fused with the gemm | ||||
|   at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { | ||||
|     // get local pointers | ||||
|     int tid = at::get_thread_num(); | ||||
|     uint8_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; | ||||
|  | ||||
|     alignas(64) float As[BLOCK_M]; | ||||
|  | ||||
|     for (int64_t i = begin; i < end; ++i) { | ||||
|       int64_t mb = i / NB; | ||||
|       int64_t nb = i % NB; | ||||
|  | ||||
|       // nb0 from top half and nb1 from bottom half | ||||
|       int64_t nb0 = nb, nb1 = nb + NB; | ||||
|       int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); | ||||
|  | ||||
|       // B shape [K, n_size] in vnni format | ||||
|       int32_t expert_id = expert_ids[mb]; | ||||
|       const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n; | ||||
|       const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n; | ||||
|       const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb0 * BLOCK_N; | ||||
|       const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb1 * BLOCK_N; | ||||
|  | ||||
|       // 1.a load A | ||||
|       const int32_t* A_ids = sorted_ids + mb * BLOCK_M; | ||||
|       int64_t m_size = offsets[mb + 1] - offsets[mb]; | ||||
|  | ||||
|       for (int64_t m = 0; m < m_size; ++m) { | ||||
|         int32_t index = A_ids[m] / topk; | ||||
|         copy_stub(A + m * K, Aq_tmp + index * K, K); | ||||
|         As[m] = As_tmp[index]; | ||||
|       } | ||||
|  | ||||
|       // fused 1.b: silu_and_mul(A @ B0, A @ B1) | ||||
|       const int64_t offset = offsets[mb]; | ||||
|       tinygemm_kernel( | ||||
|           /* A     */ A, | ||||
|           /* B0    */ B0, | ||||
|           /* B1    */ B1, | ||||
|           /* C     */ ic1 + offset * N + nb * BLOCK_N, | ||||
|           /* As    */ As, | ||||
|           /* Bs0   */ Bs0, | ||||
|           /* Bs1   */ Bs1, | ||||
|           /* M     */ m_size, | ||||
|           /* N     */ n_size, | ||||
|           /* K     */ K, | ||||
|           /* lda   */ K, | ||||
|           /* ldb   */ n_size, | ||||
|           /* ldc   */ N); | ||||
|     } | ||||
|   }); | ||||
|  | ||||
|   // stage 1.5: quantize ic1 to uint8, [M * topk, N] | ||||
|   at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) { | ||||
|     for (int64_t m = begin; m < end; ++m) { | ||||
|       quantize_row_int8<scalar_t>( | ||||
|           Aq_tmp + m * N, | ||||
|           As_tmp[m], | ||||
|           ic1 + m * N, | ||||
|           N); | ||||
|     } | ||||
|   }); | ||||
|  | ||||
|   // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 | ||||
|   //   w2 : [E, K, N] as [E, OC, IC] | ||||
|   const int64_t OC = K;  // rename K as OC | ||||
|   const int64_t IC = N;  // rename N as IC | ||||
|   const int64_t MB2 = MB; | ||||
|   const int64_t NB2 = div_up(OC, BLOCK_N); | ||||
|   const int64_t stride_e2 = OC * packed_N; | ||||
|   const int64_t stride_oc = packed_N; | ||||
|  | ||||
|   // parallel on [MB2, NB2] | ||||
|   at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { | ||||
|     // get local pointers | ||||
|     int tid = at::get_thread_num(); | ||||
|     // we won't be using C1 for gemm2 | ||||
|     float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; | ||||
|  | ||||
|     for (int64_t i = begin; i < end; ++i) { | ||||
|       int64_t mb = i / NB2; | ||||
|       int64_t nb = i % NB2; | ||||
|  | ||||
|       int64_t m_size = offsets[mb + 1] - offsets[mb]; | ||||
|       int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); | ||||
|  | ||||
|       // A ptr from ic1 of [M * topk, N] in sorted order | ||||
|       // so as to avoid copy A to tmp buffer again | ||||
|       const uint8_t* __restrict__ A = Aq_tmp + offsets[mb] * N; | ||||
|       const float* __restrict__ As = As_tmp + offsets[mb]; | ||||
|       const int32_t* A_ids = sorted_ids + mb * BLOCK_M; | ||||
|  | ||||
|       // B shape [IC, n_size] in vnni format | ||||
|       int32_t expert_id = expert_ids[mb]; | ||||
|       const int8_t* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc; | ||||
|       const float* __restrict__ Bs = w2s + expert_id * K + nb * BLOCK_N; | ||||
|  | ||||
|       // 2.a gemm: C = A @ B | ||||
|       tinygemm_kernel<scalar_t>( | ||||
|           /* A     */ A, | ||||
|           /* B     */ B, | ||||
|           /* C     */ C, | ||||
|           /* As    */ As, | ||||
|           /* Bs    */ Bs, | ||||
|           /* M     */ m_size, | ||||
|           /* N     */ n_size, | ||||
|           /* K     */ IC, | ||||
|           /* lda   */ IC, | ||||
|           /* ldb   */ n_size, | ||||
|           /* ldc   */ BLOCK_N); | ||||
|  | ||||
|       // 2.b copy from C to ic2 in original order | ||||
|       //   and also mul topk_weights in float32 | ||||
|       for (int64_t m = 0; m < m_size; ++m) { | ||||
|         int32_t index = A_ids[m]; | ||||
|         float weight = topk_weights[index]; | ||||
|         copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); | ||||
|       } | ||||
|     } | ||||
|   }); | ||||
|  | ||||
|   // stage 3: out = intermediate_cache2.sum(dim=1) | ||||
|   //   from [M, topk, K] to [M, K] | ||||
|   at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { | ||||
|     for (int64_t m = begin; m < end; ++m) { | ||||
|       sum_stub(output + m * K, ic2 + m * topk * K, topk, K); | ||||
|     } | ||||
|   }); | ||||
| } | ||||
|  | ||||
| #define INSTANTIATE_MOE_INT8_TEMPLATE(TYPE)                                                  \ | ||||
|   template void fused_experts_int8_kernel_impl<TYPE> (                                       \ | ||||
|       TYPE* __restrict__ output, TYPE* __restrict__ ic1,                                     \ | ||||
|       TYPE* __restrict__ ic2, uint8_t* __restrict__ A_tmp,                                   \ | ||||
|       float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp,                               \ | ||||
|       float* __restrict__ As_tmp, const TYPE* __restrict__ input,                            \ | ||||
|       const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2,            \ | ||||
|       const float* __restrict__ w1s, const float* __restrict__ w2s,                          \ | ||||
|       const float* __restrict__ topk_weights, const int32_t* __restrict__ sorted_ids,        \ | ||||
|       const int32_t* __restrict__ expert_ids, const int32_t* __restrict__ offsets,           \ | ||||
|       int64_t M, int64_t N, int64_t K, int64_t E, int64_t topk, int64_t num_tokens_post_pad) | ||||
|  | ||||
| INSTANTIATE_MOE_INT8_TEMPLATE(at::BFloat16); | ||||
| INSTANTIATE_MOE_INT8_TEMPLATE(at::Half); | ||||
|  | ||||
| template <typename scalar_t> | ||||
| void shared_expert_int8_kernel_impl( | ||||
|     scalar_t* __restrict__ output, | ||||
|     scalar_t* __restrict__ ic1, | ||||
|     float* __restrict__ C_tmp, | ||||
|     uint8_t* __restrict__ Aq_tmp, | ||||
|     float* __restrict__ As_tmp, | ||||
|     const scalar_t* __restrict__ input, | ||||
|     const int8_t* __restrict__ packed_w1, | ||||
|     const int8_t* __restrict__ packed_w2, | ||||
|     const float* __restrict__ w1s, | ||||
|     const float* __restrict__ w2s, | ||||
|     const scalar_t* __restrict__ fused_experts_out, | ||||
|     float routed_scaling_factor, | ||||
|     int64_t M, | ||||
|     int64_t N, | ||||
|     int64_t K) { | ||||
|  | ||||
|   // handle 2 tiles per block | ||||
|   constexpr int64_t BLOCK_M = block_size_m(); | ||||
|   constexpr int64_t BLOCK_N = block_size_n(); | ||||
|  | ||||
|   // stage 0: quantize input to uint8, [M, K] | ||||
|   at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { | ||||
|     for (int64_t m = begin; m < end; ++m) { | ||||
|       quantize_row_int8<scalar_t>( | ||||
|           Aq_tmp + m * K, | ||||
|           As_tmp[m], | ||||
|           input + m * K, | ||||
|           K); | ||||
|     } | ||||
|   }); | ||||
|  | ||||
|   // stage 1: intermediate_cache1 = silu(hidden_states @ w1) | ||||
|   const int64_t MB = div_up(M, BLOCK_M); | ||||
|   const int64_t NB = div_up(N, BLOCK_N); | ||||
|  | ||||
|   TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not a multiple of ", BLOCK_N); | ||||
|  | ||||
|   // K and N are packed for int8 | ||||
|   const int64_t packed_K = get_row_size<int8_t>(K); | ||||
|   const int64_t packed_N = get_row_size<int8_t>(N); | ||||
|   const int64_t stride_n = packed_K; | ||||
|  | ||||
|   // here we only parallelize over half of 2N so that silu_and_mul can be fused with the gemm | ||||
|   at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { | ||||
|     for (int64_t i = begin; i < end; ++i) { | ||||
|       int64_t mb = i / NB; | ||||
|       int64_t nb = i % NB; | ||||
|  | ||||
|       // nb0 from top half and nb1 from bottom half | ||||
|       int64_t nb0 = nb, nb1 = nb + NB; | ||||
|       int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); | ||||
|       int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); | ||||
|  | ||||
|       // A shape [m_size, K] | ||||
|       const uint8_t* A = Aq_tmp + mb * BLOCK_M * K; | ||||
|       const float* As = As_tmp + mb * BLOCK_M; | ||||
|  | ||||
|       // B shape [K, n_size] in vnni format | ||||
|       const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n; | ||||
|       const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n; | ||||
|       const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N; | ||||
|       const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N; | ||||
|  | ||||
|       // fused 1.b: silu_and_mul(A @ B0, A @ B1) | ||||
|       tinygemm_kernel( | ||||
|           /* A     */ A, | ||||
|           /* B0    */ B0, | ||||
|           /* B1    */ B1, | ||||
|           /* C     */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N, | ||||
|           /* As    */ As, | ||||
|           /* Bs0   */ Bs0, | ||||
|           /* Bs1   */ Bs1, | ||||
|           /* M     */ m_size, | ||||
|           /* N     */ n_size, | ||||
|           /* K     */ K, | ||||
|           /* lda   */ K, | ||||
|           /* ldb   */ n_size, | ||||
|           /* ldc   */ N); | ||||
|     } | ||||
|   }); | ||||
|  | ||||
|   // stage 1.5: quantize ic1 to uint8, [M, N] | ||||
|   at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { | ||||
|     for (int64_t m = begin; m < end; ++m) { | ||||
|       quantize_row_int8<scalar_t>( | ||||
|           Aq_tmp + m * N, | ||||
|           As_tmp[m], | ||||
|           ic1 + m * N, | ||||
|           N); | ||||
|     } | ||||
|   }); | ||||
|  | ||||
|   // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 | ||||
|   //   w2 : [K, N] as [OC, IC] | ||||
|   const int64_t OC = K;  // rename K as OC | ||||
|   const int64_t IC = N;  // rename N as IC | ||||
|   const int64_t MB2 = MB; | ||||
|   const int64_t NB2 = div_up(OC, BLOCK_N); | ||||
|   const int64_t stride_oc = packed_N; | ||||
|  | ||||
|   // parallel on [MB2, NB2] | ||||
|   at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { | ||||
|     // get local pointers | ||||
|     int tid = at::get_thread_num(); | ||||
|     // we won't be using C1 for gemm2 | ||||
|     float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; | ||||
|  | ||||
|     for (int64_t i = begin; i < end; ++i) { | ||||
|       int64_t mb = i / NB2; | ||||
|       int64_t nb = i % NB2; | ||||
|  | ||||
|       int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); | ||||
|       int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); | ||||
|  | ||||
|       // A shape [m_size, IC] | ||||
|       const uint8_t* __restrict__ A = Aq_tmp + mb * BLOCK_M * N; | ||||
|       const float* __restrict__ As = As_tmp + mb * BLOCK_M; | ||||
|  | ||||
|       // B shape [IC, n_size] in vnni format | ||||
|       const int8_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc; | ||||
|       const float* __restrict__ Bs = w2s + nb * BLOCK_N; | ||||
|  | ||||
|       // 2.a gemm: C = A @ B | ||||
|       tinygemm_kernel<scalar_t>( | ||||
|           /* A     */ A, | ||||
|           /* B     */ B, | ||||
|           /* C     */ C, | ||||
|           /* As    */ As, | ||||
|           /* Bs    */ Bs, | ||||
|           /* M     */ m_size, | ||||
|           /* N     */ n_size, | ||||
|           /* K     */ IC, | ||||
|           /* lda   */ IC, | ||||
|           /* ldb   */ n_size, | ||||
|           /* ldc   */ BLOCK_N); | ||||
|  | ||||
|       // 2.b copy from C to output and add fused_experts_out | ||||
|       scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; | ||||
|       const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N; | ||||
|       for (int64_t m = 0; m < m_size; ++m) { | ||||
|         add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); | ||||
|       } | ||||
|     } | ||||
|   }); | ||||
| } | ||||
|  | ||||
| #define INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(TYPE)                                        \ | ||||
|   template void shared_expert_int8_kernel_impl<TYPE> (                                       \ | ||||
|       TYPE* __restrict__ output, TYPE* __restrict__ ic1,                                     \ | ||||
|       float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp,                               \ | ||||
|       float* __restrict__ As_tmp, const TYPE* __restrict__ input,                            \ | ||||
|       const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2,            \ | ||||
|       const float* __restrict__ w1s, const float* __restrict__ w2s,                          \ | ||||
|       const TYPE* __restrict__ fused_experts_out, float routed_scaling_factor,               \ | ||||
|       int64_t M, int64_t N, int64_t K) | ||||
|  | ||||
| INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::BFloat16); | ||||
| INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::Half); | ||||
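For readers following the stage-2 epilogue above: `add_mul_stub` is not shown in this hunk, but from the call site it presumably combines the shared-expert GEMM tile `C` with the routed-experts output scaled by `routed_scaling_factor`. A minimal scalar sketch under that assumption (names mine, not from the diff):

    #include <cstdint>

    // Hypothetical scalar model of the 2.b epilogue: for each output row, the
    // fp32 shared-expert result C (ldc = BLOCK_N) is added to the routed-experts
    // output scaled by routed_scaling_factor, then stored to the bf16/fp16 output.
    template <typename scalar_t>
    void add_mul_ref(scalar_t* out, const float* c, const scalar_t* fused_out,
                     float routed_scaling_factor, int64_t n) {
      for (int64_t j = 0; j < n; ++j) {
        out[j] = static_cast<scalar_t>(
            c[j] + static_cast<float>(fused_out[j]) * routed_scaling_factor);
      }
    }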
							
								
								
									
308  csrc/cpu/sgl-kernels/vec.h  (new file)
							| @ -0,0 +1,308 @@ | ||||
| // Adapted from | ||||
| // https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu | ||||
|  | ||||
| #pragma once | ||||
|  | ||||
| // clang-format off | ||||
|  | ||||
| #if defined(__AVX512F__) && defined(__AVX512BF16__) && defined(__AMX_BF16__) | ||||
| #define CPU_CAPABILITY_AVX512 | ||||
| #endif | ||||
|  | ||||
| #include <ATen/cpu/vec/functional.h> | ||||
| #include <ATen/cpu/vec/vec.h> | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| using namespace at::vec; | ||||
|  | ||||
| template <typename scalar_t, | ||||
|           typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0> | ||||
| inline Vectorized<scalar_t> convert_from_float_ext(const Vectorized<float>& a, const Vectorized<float>& b) { | ||||
|   return at::vec::convert_from_float<scalar_t>(a, b); | ||||
| } | ||||
|  | ||||
| #if defined(CPU_CAPABILITY_AVX512) | ||||
|  | ||||
| // `at::vec::convert_from_float<>` from PyTorch doesn't have avx512-bf16 intrinsics | ||||
| // use the native instruction for float32->bfloat16 conversion | ||||
| template <> | ||||
| inline Vectorized<at::BFloat16> convert_from_float_ext<at::BFloat16>(const Vectorized<float>& a, const Vectorized<float>& b) { | ||||
|   return (__m512i)(_mm512_cvtne2ps_pbh(__m512(b), __m512(a))); | ||||
| } | ||||
|  | ||||
| #define CVT_BF16_TO_FP32(a) \ | ||||
|     _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)) | ||||
|  | ||||
| #define CVT_FP16_TO_FP32(a) \ | ||||
|     _mm512_cvt_roundph_ps(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) | ||||
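As an aside on `CVT_BF16_TO_FP32`: bfloat16 is the upper 16 bits of an IEEE-754 float32, so widening is just a zero-extend plus a 16-bit left shift. A scalar equivalent for a single element (my own sketch, not part of the diff):

    #include <cstdint>
    #include <cstring>

    // Scalar counterpart of CVT_BF16_TO_FP32: place the bf16 bit pattern in the
    // high half of a 32-bit word and reinterpret it as float.
    inline float bf16_bits_to_fp32(uint16_t b) {
      uint32_t u = static_cast<uint32_t>(b) << 16;
      float f;
      std::memcpy(&f, &u, sizeof(f));
      return f;
    }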
|  | ||||
| // This conversion does not handle NaN. | ||||
| inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) { | ||||
|   const __m512i x = _mm512_cvtepu8_epi16(fp8_vec); | ||||
|  | ||||
|   const __m512i mant = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x07)), 4); | ||||
|   const __m512i raw_exp = _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x78)), 3); | ||||
|   const __m512i exp = _mm512_slli_epi16(_mm512_add_epi16(raw_exp, _mm512_set1_epi16(120)), 7); | ||||
|   const __m512i nonsign = _mm512_or_si512(exp, mant); | ||||
|  | ||||
|   const __m512i sign = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x80)), 8); | ||||
|   const __m512i combined = _mm512_or_si512(nonsign, sign); | ||||
|  | ||||
|   const __mmask32 is_nonzero = _mm512_cmpneq_epi16_mask(x, _mm512_setzero_si512()); | ||||
|   return (__m512bh)_mm512_maskz_mov_epi16(is_nonzero, combined); | ||||
| } | ||||
|  | ||||
| inline __m512bh cvt_e4m3_bf16_intrinsic_without_denorm(__m256i fp8_vec) { | ||||
|   // The following conversion ignores denormal values, that is to say, | ||||
|   //   Max subnorm   : S.0000.111 = 0.875 ∗ 2**(−6) | ||||
|   //   Min subnorm   : S.0000.001 = 2**(−9) | ||||
|   // so values in the range 0.0019 ~ 0.0137 cannot be converted correctly. | ||||
|   __m512i x = _mm512_cvtepu8_epi16(fp8_vec); | ||||
|   auto mask = _mm512_cmpneq_epi16_mask( | ||||
|       _mm512_and_si512(x, _mm512_set1_epi16(127)), | ||||
|       _mm512_setzero_si512());  // mask = (x & 0x7f) != 0 | ||||
|   auto mask_nan = _mm512_cmpneq_epi16_mask( | ||||
|       _mm512_and_si512(x, _mm512_set1_epi16(127)), | ||||
|       _mm512_set1_epi16(127));  // mask_nan = (x & 0x7f) != 0x7f | ||||
|   auto mantissa = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4);  // mantissa = (x & 7) << 4 | ||||
|   auto exponent = _mm512_add_epi16( | ||||
|       _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3), | ||||
|       _mm512_set1_epi16(120));  // exponent = (((x >> 3) & 15) + 120) | ||||
|   auto nonsign = _mm512_maskz_mov_epi16(mask, _mm512_or_si512(mantissa, _mm512_slli_epi16(exponent, 7))); | ||||
|   nonsign = _mm512_mask_mov_epi16(_mm512_set1_epi16(0x7fff), mask_nan, nonsign);  // deal with Nan | ||||
|   return (__m512bh)(_mm512_or_si512( | ||||
|       nonsign, | ||||
|       _mm512_slli_epi16( | ||||
|           _mm512_and_si512(x, _mm512_set1_epi16(128)), | ||||
|           8)));  // add sign (x & 128) << 8 | ||||
| } | ||||
|  | ||||
| inline __m512bh cvt_e4m3_bf16_intrinsic_with_denorm(__m256i fp8_vec) { | ||||
|   __m512i x = _mm512_cvtepu8_epi16(fp8_vec); | ||||
|   __m512i lg2mant = _mm512_mask_mov_epi16( | ||||
|       _mm512_mask_mov_epi16( | ||||
|           _mm512_setzero_si512(), _mm512_test_epi16_mask(x, _mm512_set1_epi16(2)), _mm512_set1_epi16(1)), | ||||
|       _mm512_test_epi16_mask(x, _mm512_set1_epi16(4)), | ||||
|       _mm512_set1_epi16(2)); | ||||
|   return (__m512bh)(_mm512_or_si512( | ||||
|       _mm512_maskz_mov_epi16( | ||||
|           _mm512_cmpneq_epi16_mask(_mm512_and_si512(x, _mm512_set1_epi16(127)), _mm512_setzero_si512()), | ||||
|           _mm512_mask_blend_epi16( | ||||
|               _mm512_test_epi16_mask(x, _mm512_set1_epi16(120)), | ||||
|               _mm512_or_si512( | ||||
|                   _mm512_and_si512( | ||||
|                       _mm512_sllv_epi16( | ||||
|                           _mm512_and_si512(x, _mm512_set1_epi16(3)), _mm512_sub_epi16(_mm512_set1_epi16(7), lg2mant)), | ||||
|                       _mm512_set1_epi16(0x007f)), | ||||
|                   _mm512_slli_epi16(_mm512_add_epi16(lg2mant, _mm512_set1_epi16(118)), 7)), | ||||
|               _mm512_or_si512( | ||||
|                   _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4), | ||||
|                   _mm512_slli_epi16( | ||||
|                       _mm512_add_epi16( | ||||
|                           _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3), _mm512_set1_epi16(120)), | ||||
|                       7)))), | ||||
|       _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(128)), 8))); | ||||
| } | ||||
|  | ||||
| inline __m512bh CVT_FP8_TO_BF16(__m256i a) { | ||||
| #ifdef SGLANG_CPU_FP8_CVT_FTZ | ||||
|   return cvt_e4m3_bf16_intrinsic_no_nan(a); | ||||
| #else | ||||
|   return cvt_e4m3_bf16_intrinsic_with_denorm(a); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| #endif | ||||
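The e4m3-to-bf16 mapping above is easier to see on a single byte: e4m3 has a 4-bit exponent with bias 7 and a 3-bit mantissa, bf16 has an 8-bit exponent with bias 127 and a 7-bit mantissa, so normals rebias by 127 − 7 = 120 and widen the mantissa by four bits. A scalar restatement of the no-NaN path (my own sketch, like the fast variant it mishandles subnormals, roughly 0.0019 to 0.0137):

    #include <cstdint>

    // Scalar model of cvt_e4m3_bf16_intrinsic_no_nan for one byte.
    inline uint16_t e4m3_to_bf16_bits(uint8_t v) {
      if (v == 0) return 0;                       // zero is passed through
      const uint16_t x = v;
      const uint16_t mant = (x & 0x07) << 4;      // 3-bit mantissa -> 7-bit field
      const uint16_t raw_exp = (x & 0x78) >> 3;   // biased-7 exponent
      const uint16_t exp = (raw_exp + 120) << 7;  // rebias to 127
      const uint16_t sign = (x & 0x80) << 8;      // sign to bit 15
      return sign | exp | mant;                   // NaN and denormals not handled
    }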
|  | ||||
| // vector to scalar reduction | ||||
| #if defined(CPU_CAPABILITY_AVX512) && 0 | ||||
| inline float vec_reduce_sum(const Vectorized<float>& a) { | ||||
|   return _mm512_reduce_add_ps(__m512(a)); | ||||
| } | ||||
|  | ||||
| inline float vec_reduce_max(const Vectorized<float>& a) { | ||||
|   return _mm512_reduce_max_ps(__m512(a)); | ||||
| } | ||||
| #else | ||||
| inline float vec_reduce_sum(const Vectorized<float>& a) { | ||||
|   return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return x + y; }, a); | ||||
| } | ||||
|  | ||||
| inline float vec_reduce_max(const Vectorized<float>& a) { | ||||
|   return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return maximum(x, y); }, a); | ||||
| } | ||||
| #endif | ||||
|  | ||||
| // https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282 | ||||
| template <typename scalar_t> | ||||
| inline void quantize_row_int8(uint8_t* __restrict__ Aq, float& As, | ||||
|     const scalar_t* __restrict__ A, int64_t K, float eps = 1e-7) { | ||||
|  | ||||
|   float amax = 0.f; // absolute max | ||||
|   for (int64_t k = 0; k < K; ++k) { | ||||
|     const float val = static_cast<float>(A[k]); | ||||
|     amax = std::max(amax, std::abs(val)); | ||||
|   } | ||||
|  | ||||
|   amax = std::max(amax, eps); | ||||
|   const float scale = amax / 127; | ||||
|   const float inv_scale = 127 / amax; | ||||
|  | ||||
|   for (int64_t k = 0; k < K; ++k) { | ||||
|     const float val = static_cast<float>(A[k]) * inv_scale; | ||||
|     Aq[k] = static_cast<uint8_t>(static_cast<int32_t>(std::round(val)) + 128); | ||||
|   } | ||||
|   As = scale; | ||||
| } | ||||
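A hypothetical round-trip check of the scheme above (per-row scale = amax / 127, unsigned storage with a zero point of 128); the values and names are mine, not from the diff:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<float> a = {0.5f, -1.25f, 2.0f, -0.03f};
      float amax = 1e-7f;                          // eps, as in the kernel above
      for (float v : a) amax = std::max(amax, std::fabs(v));

      const float scale = amax / 127.0f;           // per-row scale As
      for (float v : a) {
        int q = static_cast<int>(std::round(v / scale)) + 128;  // uint8 zero point 128
        float back = (q - 128) * scale;            // dequantized value
        std::printf("%+.4f -> %3d -> %+.4f\n", v, q, back);     // |back - v| <= scale/2
      }
      return 0;
    }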
|  | ||||
| #if defined(CPU_CAPABILITY_AVX512) | ||||
| template <> | ||||
| inline void quantize_row_int8<at::BFloat16>(uint8_t* __restrict__ Aq, float& As, | ||||
|     const at::BFloat16* __restrict__ A, int64_t K, float eps) { | ||||
|  | ||||
|   const __m512 signBit = _mm512_set1_ps(-0.0f); | ||||
|   const __m512i off = _mm512_set1_epi32(128); | ||||
|  | ||||
|   // K is 32x, no remainder | ||||
|   float amax = 0.f; | ||||
|   __m512 vamax0 = _mm512_set1_ps(0.f); | ||||
|   __m512 vamax1 = _mm512_set1_ps(0.f); | ||||
|   for (int64_t k = 0; k < K; k += 32) { | ||||
|     __m512i va = _mm512_loadu_si512((void*)(A + k)); | ||||
|     __m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0)); | ||||
|     __m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1)); | ||||
|     vamax0 = _mm512_max_ps(vamax0, _mm512_andnot_ps(signBit, va0)); | ||||
|     vamax1 = _mm512_max_ps(vamax1, _mm512_andnot_ps(signBit, va1)); | ||||
|   } | ||||
|   amax = _mm512_reduce_max_ps(_mm512_max_ps(vamax0, vamax1)); | ||||
|   amax = std::max(amax, eps); | ||||
|   const float scale = amax / 127; | ||||
|   const float inv_scale = 127 / amax; | ||||
|   const __m512 vd = _mm512_set1_ps(inv_scale); | ||||
|  | ||||
|   for (int64_t k = 0; k < K; k += 32) { | ||||
|     __m512i va = _mm512_loadu_si512((void*)(A + k)); | ||||
|     __m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0)); | ||||
|     __m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1)); | ||||
|     va0 = _mm512_mul_ps(va0, vd); | ||||
|     va1 = _mm512_mul_ps(va1, vd); | ||||
|     va0 = _mm512_roundscale_ps(va0, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); | ||||
|     va1 = _mm512_roundscale_ps(va1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); | ||||
|     __m128i i0 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va0), off)); | ||||
|     __m128i i1 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va1), off)); | ||||
|     _mm256_storeu_si256(reinterpret_cast<__m256i*>(Aq + k), _mm256_set_m128i(i1, i0)); | ||||
|   } | ||||
|   As = scale; | ||||
| } | ||||
| #endif | ||||
|  | ||||
| // transpose utils | ||||
| // taken from my PR in ggml: https://github.com/ggml-org/llama.cpp/pull/8998 | ||||
| #if defined(CPU_CAPABILITY_AVX512) | ||||
| inline void transpose_16x16_32bit(__m512i * v) { | ||||
|   __m512i v1[16]; | ||||
|   v1[0] = _mm512_unpacklo_epi32(v[0], v[1]); | ||||
|   v1[1] = _mm512_unpackhi_epi32(v[0], v[1]); | ||||
|   v1[2] = _mm512_unpacklo_epi32(v[2], v[3]); | ||||
|   v1[3] = _mm512_unpackhi_epi32(v[2], v[3]); | ||||
|   v1[4] = _mm512_unpacklo_epi32(v[4], v[5]); | ||||
|   v1[5] = _mm512_unpackhi_epi32(v[4], v[5]); | ||||
|   v1[6] = _mm512_unpacklo_epi32(v[6], v[7]); | ||||
|   v1[7] = _mm512_unpackhi_epi32(v[6], v[7]); | ||||
|   v1[8] = _mm512_unpacklo_epi32(v[8], v[9]); | ||||
|   v1[9] = _mm512_unpackhi_epi32(v[8], v[9]); | ||||
|   v1[10] = _mm512_unpacklo_epi32(v[10], v[11]); | ||||
|   v1[11] = _mm512_unpackhi_epi32(v[10], v[11]); | ||||
|   v1[12] = _mm512_unpacklo_epi32(v[12], v[13]); | ||||
|   v1[13] = _mm512_unpackhi_epi32(v[12], v[13]); | ||||
|   v1[14] = _mm512_unpacklo_epi32(v[14], v[15]); | ||||
|   v1[15] = _mm512_unpackhi_epi32(v[14], v[15]); | ||||
|  | ||||
|   v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]); | ||||
|   v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]); | ||||
|   v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]); | ||||
|   v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]); | ||||
|   v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]); | ||||
|   v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]); | ||||
|   v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]); | ||||
|   v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]); | ||||
|   v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]); | ||||
|   v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]); | ||||
|   v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]); | ||||
|   v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]); | ||||
|   v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]); | ||||
|   v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]); | ||||
|   v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]); | ||||
|   v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]); | ||||
|  | ||||
|   v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88); | ||||
|   v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88); | ||||
|   v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88); | ||||
|   v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88); | ||||
|   v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd); | ||||
|   v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd); | ||||
|   v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd); | ||||
|   v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd); | ||||
|   v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88); | ||||
|   v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88); | ||||
|   v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88); | ||||
|   v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88); | ||||
|   v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd); | ||||
|   v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd); | ||||
|   v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd); | ||||
|   v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd); | ||||
|  | ||||
|   v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88); | ||||
|   v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88); | ||||
|   v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88); | ||||
|   v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88); | ||||
|   v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88); | ||||
|   v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88); | ||||
|   v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88); | ||||
|   v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88); | ||||
|   v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd); | ||||
|   v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd); | ||||
|   v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd); | ||||
|   v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd); | ||||
|   v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd); | ||||
|   v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd); | ||||
|   v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd); | ||||
|   v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd); | ||||
| } | ||||
|  | ||||
| // remove warning : ignoring attributes on template argument ‘__m512i’ [-Wignored-attributes] | ||||
| #pragma GCC diagnostic push | ||||
| #pragma GCC diagnostic ignored "-Wignored-attributes" | ||||
|  | ||||
| // transpose from [2, 32] to [32, 2] | ||||
| inline std::tuple<__m512i, __m512i> transpose_2x32_16bit(__m512i r0, __m512i r1) { | ||||
|   // r0: {a0, a1, ..., a31} | ||||
|   // r1: {b0, b1, ..., b31} | ||||
|   // | ||||
|   // d0: {a0,   b0, ..., a15, b15} | ||||
|   // d1: {a16, b16, ..., a31, b31} | ||||
|   // | ||||
|   __m512i d0 = _mm512_unpacklo_epi16(r0, r1); | ||||
|   __m512i d1 = _mm512_unpackhi_epi16(r0, r1); | ||||
|   r0 = _mm512_shuffle_i32x4(d0, d1, 0x88); | ||||
|   r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd); | ||||
|   d0 = _mm512_shuffle_i32x4(r0, r1, 0x88); | ||||
|   d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd); | ||||
|   return std::make_tuple(d0, d1); | ||||
| } | ||||
| #pragma GCC diagnostic pop | ||||
|  | ||||
| #endif | ||||
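A scalar model of `transpose_2x32_16bit`, following the layout described in its comment: two rows of 32 16-bit values become 32 interleaved (a, b) pairs split across the two outputs. This is my own reference, not code from the repo:

    #include <array>
    #include <cstdint>
    #include <utility>

    // d0 = {a0, b0, ..., a15, b15}, d1 = {a16, b16, ..., a31, b31}
    std::pair<std::array<uint16_t, 32>, std::array<uint16_t, 32>>
    transpose_2x32_16bit_ref(const std::array<uint16_t, 32>& a,
                             const std::array<uint16_t, 32>& b) {
      std::array<uint16_t, 32> d0{}, d1{};
      for (int i = 0; i < 16; ++i) {
        d0[2 * i] = a[i];
        d0[2 * i + 1] = b[i];
        d1[2 * i] = a[16 + i];
        d1[2 * i + 1] = b[16 + i];
      }
      return {d0, d1};
    }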
|  | ||||
| // TODO: debug print, remove me later | ||||
| template<typename scalar_t> | ||||
| void print_array(scalar_t* ptr, int size) { | ||||
|   for (int d = 0; d < size; ++d) { | ||||
|     if (d % 16 == 0) { std::cout << std::endl; } | ||||
|     std::cout << ptr[d] << " "; | ||||
|   } | ||||
|   std::cout << std::endl; | ||||
| } | ||||
|  | ||||
| } // anonymous namespace | ||||
							
								
								
									
178  csrc/cpu/shm.cpp
							| @ -7,9 +7,10 @@ | ||||
|  | ||||
| namespace { | ||||
| #define MAX_SHM_RANK_NUM 8 | ||||
| #define MAX_THREAD_NUM 12 | ||||
| #define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024) | ||||
| #define MIN_THREAD_PROCESS_SIZE (8 * 1024) | ||||
| #define PER_THREAD_SHM_BUFFER_BYTES (2 * 1024 * 1024) | ||||
| static_assert(PER_THREAD_SHM_BUFFER_BYTES % 2 == 0); | ||||
| #define PER_THREAD_SHM_BUFFER_OFFSET (PER_THREAD_SHM_BUFFER_BYTES >> 1) | ||||
| #define MIN_THREAD_PROCESS_SIZE (256) | ||||
| #define MAX_P2P_SEND_TENSOR_NUM 8 | ||||
|  | ||||
| template <typename scalar_t> | ||||
| @ -32,10 +33,10 @@ struct KernelVecType<c10::Half> { | ||||
|   using scalar_vec_t = vec_op::FP16Vec16; | ||||
| }; | ||||
|  | ||||
| enum class ThreadSHMStat : char { THREAD_READY = 0, SHM_DATA_READY, DONE }; | ||||
|  | ||||
| struct ThreadSHMContext { | ||||
|   volatile ThreadSHMStat thread_stats[MAX_SHM_RANK_NUM]; | ||||
|   volatile char _curr_thread_stamp; | ||||
|   volatile char _ready_thread_stamp; | ||||
|   char _padding1[6]; | ||||
|   int thread_id; | ||||
|   int thread_num; | ||||
|   int rank; | ||||
| @ -44,14 +45,19 @@ struct ThreadSHMContext { | ||||
|   int swizzled_ranks[MAX_SHM_RANK_NUM]; | ||||
|   void* thread_shm_ptrs[MAX_SHM_RANK_NUM]; | ||||
|   ThreadSHMContext* shm_contexts[MAX_SHM_RANK_NUM]; | ||||
|   size_t _thread_buffer_mask; | ||||
|   char _padding2[56]; | ||||
|  | ||||
|   ThreadSHMContext(const int thread_id, const int thread_num, const int rank, | ||||
|                    const int group_size, void* thread_shm_ptr) | ||||
|       : thread_id(thread_id), | ||||
|       : _curr_thread_stamp(1), | ||||
|         _ready_thread_stamp(0), | ||||
|         thread_id(thread_id), | ||||
|         thread_num(thread_num), | ||||
|         rank(rank), | ||||
|         group_size(group_size), | ||||
|         _spinning_count(0) { | ||||
|         _spinning_count(0), | ||||
|         _thread_buffer_mask(0) { | ||||
|     static_assert(sizeof(ThreadSHMContext) % 64 == 0); | ||||
|     TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM); | ||||
|     TORCH_CHECK((size_t)this % 64 == 0); | ||||
| @ -60,7 +66,6 @@ struct ThreadSHMContext { | ||||
|       shm_contexts[i] = nullptr; | ||||
|       thread_shm_ptrs[i] = nullptr; | ||||
|       swizzled_ranks[i] = (i + rank) % group_size; | ||||
|       thread_stats[i] = ThreadSHMStat::DONE; | ||||
|     } | ||||
|     set_context(rank, this, thread_shm_ptr); | ||||
|   } | ||||
| @ -77,59 +82,66 @@ struct ThreadSHMContext { | ||||
|  | ||||
|   template <typename T> | ||||
|   T* get_thread_shm_ptr(int rank) { | ||||
|     return reinterpret_cast<T*>(thread_shm_ptrs[rank]); | ||||
|     return reinterpret_cast<T*>( | ||||
|         reinterpret_cast<int8_t*>(thread_shm_ptrs[rank]) + | ||||
|         (PER_THREAD_SHM_BUFFER_OFFSET & _thread_buffer_mask)); | ||||
|   } | ||||
|  | ||||
|   void next_buffer() { _thread_buffer_mask ^= 0xFFFFFFFFFFFFFFFF; } | ||||
|  | ||||
|   char get_curr_stamp() const { return _curr_thread_stamp; } | ||||
|  | ||||
|   char get_ready_stamp() const { return _ready_thread_stamp; } | ||||
|  | ||||
|   void next_stamp() { | ||||
|     _mm_mfence(); | ||||
|     _curr_thread_stamp += 1; | ||||
|   } | ||||
|  | ||||
|   void commit_ready_stamp() { | ||||
|     _mm_mfence(); | ||||
|     _ready_thread_stamp = _curr_thread_stamp; | ||||
|   } | ||||
|  | ||||
|   int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; } | ||||
|  | ||||
|   void wait_for_all(ThreadSHMStat prev_stat) { | ||||
|     for (int idx = 0; idx < group_size; ++idx) { | ||||
|   template <typename Cond> | ||||
|   void wait_for_all(Cond&& cond) { | ||||
|     for (int idx = 1; idx < group_size; ++idx) { | ||||
|       int rank = get_swizzled_rank(idx); | ||||
|       while (thread_stats[rank] == prev_stat) { | ||||
|         ++_spinning_count; | ||||
|         _mm_pause(); | ||||
|       } | ||||
|       wait_for_one(rank, std::forward<Cond>(cond)); | ||||
|     } | ||||
|     vec_op::mem_barrier(); | ||||
|   } | ||||
|  | ||||
|   void wait_for_one(int rank, ThreadSHMStat prev_stat) { | ||||
|     while (thread_stats[rank] == prev_stat) { | ||||
|   template <typename Cond> | ||||
|   void wait_for_one(int rank, Cond&& cond) { | ||||
|     ThreadSHMContext* rank_ctx = shm_contexts[rank]; | ||||
|     for (;;) { | ||||
|       char local_curr_stamp = get_curr_stamp(); | ||||
|       char local_ready_stamp = get_ready_stamp(); | ||||
|       char rank_curr_stamp = rank_ctx->get_curr_stamp(); | ||||
|       char rank_ready_stamp = rank_ctx->get_ready_stamp(); | ||||
|       if (cond(local_curr_stamp, local_ready_stamp, rank_curr_stamp, | ||||
|                rank_ready_stamp)) { | ||||
|         break; | ||||
|       } | ||||
|       ++_spinning_count; | ||||
|       _mm_pause(); | ||||
|     } | ||||
|     vec_op::mem_barrier(); | ||||
|   } | ||||
|  | ||||
|   void set_thread_stat(ThreadSHMStat stat) { | ||||
|     for (int idx = 0; idx < group_size; ++idx) { | ||||
|       int rank = get_swizzled_rank(idx); | ||||
|       shm_contexts[rank]->thread_stats[this->rank] = stat; | ||||
|     } | ||||
|   static bool check_no_buffer_conflict(char local_curr_stamp, | ||||
|                                        char local_ready_stamp, | ||||
|                                        char rank_curr_stamp, | ||||
|                                        char rank_ready_stamp) { | ||||
|     char temp = rank_curr_stamp + 2; | ||||
|     return local_curr_stamp != temp; | ||||
|   } | ||||
|  | ||||
|   void set_thread_stat(int target_rank, ThreadSHMStat stat) { | ||||
|     for (int idx = 0; idx < group_size; ++idx) { | ||||
|       int rank = get_swizzled_rank(idx); | ||||
|       shm_contexts[rank]->thread_stats[target_rank] = stat; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   // barrier for all ranks in the group, used for all2all ops | ||||
|   // DONE -> THREAD_READY -> SHM_DATA_READY -> DONE -> ... | ||||
|   void barrier(ThreadSHMStat next_stat) { | ||||
|     if (next_stat == ThreadSHMStat::THREAD_READY) { | ||||
|       set_thread_stat(ThreadSHMStat::THREAD_READY); | ||||
|       wait_for_all(ThreadSHMStat::DONE); | ||||
|     } else if (next_stat == ThreadSHMStat::SHM_DATA_READY) { | ||||
|       set_thread_stat(ThreadSHMStat::SHM_DATA_READY); | ||||
|       wait_for_all(ThreadSHMStat::THREAD_READY); | ||||
|     } else if (next_stat == ThreadSHMStat::DONE) { | ||||
|       set_thread_stat(ThreadSHMStat::DONE); | ||||
|       wait_for_all(ThreadSHMStat::SHM_DATA_READY); | ||||
|     } else { | ||||
|       TORCH_CHECK(false, "Invalid next_stat to barrier."); | ||||
|     } | ||||
|   static bool check_stamp_ready(char local_curr_stamp, char local_ready_stamp, | ||||
|                                 char rank_curr_stamp, char rank_ready_stamp) { | ||||
|     char temp = local_curr_stamp + 1; | ||||
|     return (local_curr_stamp == rank_ready_stamp) || (temp == rank_ready_stamp); | ||||
|   } | ||||
|  | ||||
|   std::string to_string() const { | ||||
| @ -164,7 +176,7 @@ class SHMManager { | ||||
|                       const int group_size) | ||||
|       : _rank(rank), | ||||
|         _group_size(group_size), | ||||
|         _thread_num(std::min(torch::get_num_threads(), MAX_THREAD_NUM)), | ||||
|         _thread_num(torch::get_num_threads()), | ||||
|         _shm_names({""}), | ||||
|         _shared_mem_ptrs({nullptr}), | ||||
|         _shm_ctx(nullptr) { | ||||
| @ -326,7 +338,8 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) { | ||||
|       (total_units_num + thread_num - 1) / thread_num; | ||||
|   int64_t per_unit_elem_num = MIN_THREAD_PROCESS_SIZE / sizeof(scalar_t); | ||||
|   int64_t max_per_thread_iteration_elem_num = | ||||
|       PER_THREAD_SHM_BUFFER_BYTES / sizeof(scalar_t); | ||||
|       (PER_THREAD_SHM_BUFFER_BYTES >> 1) / | ||||
|       sizeof(scalar_t);  // Note: double buffer | ||||
|   int64_t per_thread_elem_num = per_unit_elem_num * per_thread_units_num; | ||||
|  | ||||
| #pragma omp parallel for schedule(static, 1) | ||||
| @ -336,10 +349,13 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) { | ||||
|     int64_t curr_elem_num = | ||||
|         std::min(max_per_thread_iteration_elem_num, end - offset); | ||||
|     ThreadSHMContext* thread_ctx = ctx + i; | ||||
|     bool fast_mode = ((end - offset) <= max_per_thread_iteration_elem_num); | ||||
|  | ||||
|     while (curr_elem_num > 0) { | ||||
|       inner_func(thread_ctx, offset, curr_elem_num); | ||||
|       inner_func(thread_ctx, offset, curr_elem_num, fast_mode); | ||||
|  | ||||
|       thread_ctx->next_stamp(); | ||||
|       thread_ctx->next_buffer(); | ||||
|       offset += max_per_thread_iteration_elem_num; | ||||
|       curr_elem_num = std::min(max_per_thread_iteration_elem_num, end - offset); | ||||
|     } | ||||
| @ -397,7 +413,7 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data, | ||||
|   shm_cc_ops::shm_cc_loop<scalar_t>( | ||||
|       ctx, elem_num, | ||||
|       [&](ThreadSHMContext* thread_ctx, int64_t data_offset, | ||||
|           int64_t data_elem_num) { | ||||
|           int64_t data_elem_num, bool fast_mode) { | ||||
|         int rank = thread_ctx->rank; | ||||
|         scalar_t* thread_shm_ptr = | ||||
|             thread_ctx->get_thread_shm_ptr<scalar_t>(rank); | ||||
| @ -410,16 +426,17 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data, | ||||
|               thread_ctx->get_swizzled_rank(idx + 1)); | ||||
|         }); | ||||
|  | ||||
|         thread_ctx->barrier(ThreadSHMStat::THREAD_READY); | ||||
|         if (!fast_mode) { | ||||
|           thread_ctx->wait_for_all(ThreadSHMContext::check_no_buffer_conflict); | ||||
|         } | ||||
|  | ||||
|         shm_cc_ops::memcpy_to_shm(thread_shm_ptr, thread_data_ptr, | ||||
|                                   thread_data_elem_num); | ||||
|  | ||||
|         thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY); | ||||
|  | ||||
|         thread_ctx->commit_ready_stamp(); | ||||
|         int64_t aligned_data_elem_num = | ||||
|             (data_elem_num / vec_elem_num) * vec_elem_num; | ||||
|         int64_t i = 0; | ||||
|         thread_ctx->wait_for_all(ThreadSHMContext::check_stamp_ready); | ||||
| #pragma GCC unroll 4 | ||||
|         for (; i < aligned_data_elem_num; i += vec_elem_num) { | ||||
|           vec_t local_data(thread_data_ptr + i);  // load from cache | ||||
| @ -447,8 +464,6 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data, | ||||
|           reduced_data.save(thread_data_ptr + i, | ||||
|                             data_elem_num - aligned_data_elem_num); | ||||
|         } | ||||
|  | ||||
|         thread_ctx->barrier(ThreadSHMStat::DONE); | ||||
|       }); | ||||
|  | ||||
|   return; | ||||
| @ -488,18 +503,18 @@ void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num, | ||||
|   shm_cc_ops::shm_cc_loop<scalar_t>( | ||||
|       ctx, elem_num, | ||||
|       [&](ThreadSHMContext* thread_ctx, int64_t data_offset, | ||||
|           int64_t data_elem_num) { | ||||
|           int64_t data_elem_num, bool fast_mode) { | ||||
|         int rank = thread_ctx->rank; | ||||
|         scalar_t* thread_shm_ptr = | ||||
|             thread_ctx->get_thread_shm_ptr<scalar_t>(rank); | ||||
|  | ||||
|         thread_ctx->barrier(ThreadSHMStat::THREAD_READY); | ||||
|  | ||||
|         shm_cc_ops::memcpy_to_shm(thread_shm_ptr, data + data_offset, | ||||
|                                   data_elem_num * sizeof(scalar_t)); | ||||
|  | ||||
|         thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY); | ||||
|         if (!fast_mode) { | ||||
|           thread_ctx->wait_for_all(ThreadSHMContext::check_no_buffer_conflict); | ||||
|         } | ||||
|  | ||||
|         shm_cc_ops::memcpy(thread_shm_ptr, data + data_offset, | ||||
|                            data_elem_num * sizeof(scalar_t)); | ||||
|         thread_ctx->commit_ready_stamp(); | ||||
|         if (rank == dst) { | ||||
|           shm_cc_ops::memcpy(outputs[rank] + data_offset, data + data_offset, | ||||
|                              data_elem_num * sizeof(scalar_t)); | ||||
| @ -508,12 +523,12 @@ void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num, | ||||
|             scalar_t* src_ptr = | ||||
|                 thread_ctx->get_thread_shm_ptr<scalar_t>(src_rank);  // shm | ||||
|             scalar_t* dst_ptr = outputs[src_rank] + data_offset; | ||||
|             shm_cc_ops::memcpy_from_shm(dst_ptr, src_ptr, | ||||
|                                         data_elem_num * sizeof(scalar_t)); | ||||
|             thread_ctx->wait_for_one(src_rank, | ||||
|                                      ThreadSHMContext::check_stamp_ready); | ||||
|             shm_cc_ops::memcpy(dst_ptr, src_ptr, | ||||
|                                data_elem_num * sizeof(scalar_t)); | ||||
|           } | ||||
|         } | ||||
|  | ||||
|         thread_ctx->barrier(ThreadSHMStat::DONE); | ||||
|       }); | ||||
|  | ||||
|   return; | ||||
| @ -599,7 +614,7 @@ struct TensorListMeta { | ||||
|   int8_t _padding[40]; | ||||
| }; | ||||
|  | ||||
| void shm_send_tensor_list_impl(ThreadSHMContext* ctx, | ||||
| void shm_send_tensor_list_impl(ThreadSHMContext* ctx, int64_t dst, | ||||
|                                const std::vector<torch::Tensor>& tensor_list) { | ||||
|   CPU_KERNEL_GUARD_IN(shm_send_tensor_list_impl) | ||||
|   std::vector<torch::Tensor> tensor_list_with_metadata; | ||||
| @ -620,12 +635,11 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx, | ||||
|   shm_cc_ops::shm_cc_loop<int8_t>( | ||||
|       ctx, metadata->total_bytes, | ||||
|       [&](ThreadSHMContext* thread_ctx, int64_t data_offset, | ||||
|           int64_t data_elem_num) { | ||||
|           int64_t data_elem_num, bool fast_mode) { | ||||
|         int rank = thread_ctx->rank; | ||||
|         // Wait until the receiver set the stat to DONE | ||||
|         thread_ctx->wait_for_one(rank, ThreadSHMStat::SHM_DATA_READY); | ||||
|  | ||||
|         int64_t curr_shm_offset = 0; | ||||
|         thread_ctx->wait_for_one(dst, | ||||
|                                  ThreadSHMContext::check_no_buffer_conflict); | ||||
|         while (curr_shm_offset < data_elem_num) { | ||||
|           MemPiece frag = metadata->get_data(data_offset + curr_shm_offset); | ||||
|           frag.size = std::min(frag.size, data_elem_num - curr_shm_offset); | ||||
| @ -634,8 +648,7 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx, | ||||
|               frag.ptr, frag.size); | ||||
|           curr_shm_offset += frag.size; | ||||
|         } | ||||
|  | ||||
|         thread_ctx->set_thread_stat(rank, ThreadSHMStat::SHM_DATA_READY); | ||||
|         thread_ctx->commit_ready_stamp(); | ||||
|       }); | ||||
| } | ||||
|  | ||||
| @ -646,8 +659,7 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx, | ||||
|   torch::Tensor metadata_tensor = | ||||
|       torch::empty({sizeof(TensorListMeta)}, options); | ||||
|  | ||||
|   // Wait until the sender set the stat of the thread 0 to SHM_DATA_READY | ||||
|   ctx->wait_for_one(src, ThreadSHMStat::DONE); | ||||
|   ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready); | ||||
|   shm_cc_ops::memcpy(metadata_tensor.data_ptr(), | ||||
|                      ctx->get_thread_shm_ptr<void>(src), | ||||
|                      sizeof(TensorListMeta)); | ||||
| @ -664,9 +676,8 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx, | ||||
|   shm_cc_ops::shm_cc_loop<int8_t>( | ||||
|       ctx, metadata.total_bytes, | ||||
|       [&](ThreadSHMContext* thread_ctx, int64_t data_offset, | ||||
|           int64_t data_elem_num) { | ||||
|         // Wait until the sender set the stat to SHM_DATA_READY | ||||
|         thread_ctx->wait_for_one(src, ThreadSHMStat::DONE); | ||||
|           int64_t data_elem_num, bool fast_mode) { | ||||
|         ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready); | ||||
|         int64_t curr_shm_offset = 0; | ||||
|         while (curr_shm_offset < data_elem_num) { | ||||
|           MemPiece frag = metadata.get_data(data_offset + curr_shm_offset); | ||||
| @ -677,8 +688,6 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx, | ||||
|               frag.size); | ||||
|           curr_shm_offset += frag.size; | ||||
|         } | ||||
|  | ||||
|         thread_ctx->set_thread_stat(src, ThreadSHMStat::DONE); | ||||
|       }); | ||||
|  | ||||
|   std::vector<torch::Tensor> tensor_list; | ||||
| @ -756,7 +765,8 @@ void shm_send_tensor_list(int64_t handle, | ||||
|                           int64_t dst) { | ||||
|   CPU_KERNEL_GUARD_IN(shm_send_tensor_list) | ||||
|   shm_send_tensor_list_impl( | ||||
|       SHMManager::get_singleton_instance(handle)->get_shm_ctx(), tensor_list); | ||||
|       SHMManager::get_singleton_instance(handle)->get_shm_ctx(), dst, | ||||
|       tensor_list); | ||||
|   CPU_KERNEL_GUARD_OUT(shm_send_tensor_list) | ||||
| } | ||||
|  | ||||
| @ -778,4 +788,4 @@ std::string join_shm_manager(int64_t handle, const std::string& name) { | ||||
|   TORCH_CHECK(shm_manager); | ||||
|   shm_manager->join(name); | ||||
|   return shm_manager->get_shm_ctx()->to_string(); | ||||
| } | ||||
| } | ||||
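The shm.cpp changes replace the three-state barrier with per-thread stamps plus a double-buffered shared region: `PER_THREAD_SHM_BUFFER_BYTES` is split into two halves and `_thread_buffer_mask` selects which half the current iteration uses. A minimal model of the buffer selection (my reading of the diff; names hypothetical):

    #include <cstddef>

    constexpr std::size_t kPerThreadShmBytes = 2 * 1024 * 1024;
    constexpr std::size_t kHalfOffset = kPerThreadShmBytes >> 1;  // second slot

    struct BufferSelector {
      std::size_t mask = 0;                                       // 0 or all-ones
      std::size_t offset() const { return kHalfOffset & mask; }   // 0 or kHalfOffset
      void next_buffer() { mask ^= ~std::size_t(0); }             // toggle slots
    };

    // Iteration 0 writes at offset 0, iteration 1 at kHalfOffset, iteration 2 at
    // offset 0 again, and so on; the curr/ready stamps keep a writer from reusing
    // a slot before every peer has consumed the previous one.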
|  | ||||
| @ -50,6 +50,27 @@ void shm_send_tensor_list(int64_t handle, | ||||
|  | ||||
| std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src); | ||||
|  | ||||
| at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, | ||||
|                                 const std::optional<at::Tensor>& bias, | ||||
|                                 bool is_vnni); | ||||
|  | ||||
| at::Tensor convert_weight_packed(at::Tensor& weight); | ||||
|  | ||||
| at::Tensor fused_experts_cpu( | ||||
|     at::Tensor& hidden_states, at::Tensor& w1, at::Tensor& w2, | ||||
|     at::Tensor& topk_weights, at::Tensor& topk_ids, bool inplace, | ||||
|     bool use_int8_w8a8, bool use_fp8_w8a16, | ||||
|     const std::optional<at::Tensor>& w1_scale, | ||||
|     const std::optional<at::Tensor>& w2_scale, | ||||
|     const std::optional<std::vector<int64_t>> block_size, | ||||
|     const std::optional<at::Tensor>& a1_scale, | ||||
|     const std::optional<at::Tensor>& a2_scale, bool is_vnni); | ||||
|  | ||||
| at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, | ||||
|                                      at::Tensor& scales2, | ||||
|                                      const std::optional<at::Tensor>& bias, | ||||
|                                      at::ScalarType out_dtype, bool is_vnni); | ||||
|  | ||||
| TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { | ||||
|   // vLLM custom ops | ||||
|  | ||||
| @ -131,16 +152,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { | ||||
|  | ||||
|   // Quantization | ||||
| #ifdef __AVX512F__ | ||||
|   at::Tag stride_tag = at::Tag::needs_fixed_stride_order; | ||||
|   // Compute int8 quantized tensor for given scaling factor. | ||||
|   ops.def( | ||||
|       "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," | ||||
|       "Tensor? azp) -> ()"); | ||||
|       "Tensor? azp) -> ()", | ||||
|       {stride_tag}); | ||||
|   ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); | ||||
|  | ||||
|   // Compute int8 quantized tensor and scaling factor | ||||
|   ops.def( | ||||
|       "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " | ||||
|       "Tensor!? azp) -> ()"); | ||||
|       "Tensor!? azp) -> ()", | ||||
|       {stride_tag}); | ||||
|   ops.impl("dynamic_scaled_int8_quant", torch::kCPU, | ||||
|            &dynamic_scaled_int8_quant); | ||||
|   // W8A8 GEMM, supporting symmetric per-tensor or per-row/column | ||||
| @ -148,7 +172,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { | ||||
|   ops.def( | ||||
|       "cutlass_scaled_mm(Tensor! out, Tensor a," | ||||
|       "                  Tensor b, Tensor a_scales," | ||||
|       "                  Tensor b_scales, Tensor? bias) -> ()"); | ||||
|       "                  Tensor b_scales, Tensor? bias) -> ()", | ||||
|       {stride_tag}); | ||||
|   ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); | ||||
|   // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column | ||||
|   // quantization. | ||||
| @ -156,7 +181,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { | ||||
|       "cutlass_scaled_mm_azp(Tensor! out, Tensor a," | ||||
|       "                  Tensor b, Tensor a_scales," | ||||
|       "                  Tensor b_scales, Tensor azp_adj," | ||||
|       "                  Tensor? azp, Tensor? bias) -> ()"); | ||||
|       "                  Tensor? azp, Tensor? bias) -> ()", | ||||
|       {stride_tag}); | ||||
|   ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); | ||||
| #elif defined(__powerpc64__) | ||||
|   // Compute int8 quantized tensor for given scaling factor. | ||||
| @ -209,6 +235,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { | ||||
|   ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)", | ||||
|           &shm_recv_tensor_list); | ||||
| #endif | ||||
|  | ||||
|   // sgl-kernels | ||||
| #if defined(__AVX512BF16__) && defined(__AVX512F__) && defined(__AVX512VNNI__) | ||||
|   ops.def( | ||||
|       "weight_packed_linear(Tensor(a0!) mat1, Tensor(a1!) mat2, Tensor(a2!)? " | ||||
|       "bias, bool is_vnni) -> Tensor"); | ||||
|   ops.impl("weight_packed_linear", torch::kCPU, &weight_packed_linear); | ||||
|   ops.def("convert_weight_packed(Tensor! weight) -> Tensor"); | ||||
|   ops.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed); | ||||
|   ops.def( | ||||
|       "fused_experts_cpu(Tensor! hidden_states, Tensor w1, Tensor w2, Tensor " | ||||
|       "topk_weights, Tensor topk_ids, bool inplace, bool use_int8_w8a8, bool " | ||||
|       "use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, SymInt[]? " | ||||
|       "block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> " | ||||
|       "Tensor"); | ||||
|   ops.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu); | ||||
|   ops.def( | ||||
|       "int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, " | ||||
|       "Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor"); | ||||
|   ops.impl("int8_scaled_mm_with_quant", torch::kCPU, | ||||
|            &int8_scaled_mm_with_quant); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { | ||||
|  | ||||
| @ -54,8 +54,7 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) { | ||||
|     *(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp); | ||||
|     int page_num = numa_migrate_pages(pid, src_mask, mask); | ||||
|     if (page_num == -1) { | ||||
|       TORCH_CHECK(false, | ||||
|                   "numa_migrate_pages failed. errno: " + std::to_string(errno)); | ||||
|       TORCH_WARN("numa_migrate_pages failed. errno: " + std::to_string(errno)); | ||||
|     } | ||||
|  | ||||
|     // restrict memory allocation node. | ||||
| @ -105,4 +104,4 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) { | ||||
|  | ||||
|   return ss.str(); | ||||
| } | ||||
| #endif | ||||
| #endif | ||||
|  | ||||
							
								
								
									
114  csrc/custom_quickreduce.cu  (new file)
							| @ -0,0 +1,114 @@ | ||||
| #include <ATen/cuda/Exceptions.h> | ||||
| #include <c10/cuda/CUDAGuard.h> | ||||
| #include <c10/cuda/CUDAStream.h> | ||||
| #include <torch/all.h> | ||||
|  | ||||
| #ifdef USE_ROCM | ||||
|  | ||||
|   #include "quickreduce/quick_reduce.h" | ||||
|  | ||||
| quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size, | ||||
|                                    std::optional<int64_t> qr_max_size) { | ||||
|   if (world_size > 8) | ||||
|     throw std::invalid_argument("world size > 8 is not supported"); | ||||
|   if (world_size == 6) | ||||
|     throw std::invalid_argument("world size == 6 is not supported"); | ||||
|   if (world_size % 2 != 0) | ||||
|     throw std::invalid_argument("Odd num gpus is not supported for now"); | ||||
|   if (rank < 0 || rank >= world_size) | ||||
|     throw std::invalid_argument("invalid rank passed in"); | ||||
|   quickreduce::DeviceComms* fptr = new quickreduce::DeviceComms(); | ||||
|   fptr->init(world_size, rank, qr_max_size); | ||||
|   return (quickreduce::fptr_t)fptr; | ||||
| } | ||||
|  | ||||
| void qr_destroy(quickreduce::fptr_t _fa) { | ||||
|   if (_fa) { | ||||
|     auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa); | ||||
|     fa->destroy(); | ||||
|     delete fa; | ||||
|   } | ||||
| } | ||||
|  | ||||
| torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) { | ||||
|   auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa); | ||||
|   hipIpcMemHandle_t handle = fa->get_handle(); | ||||
|   auto options = | ||||
|       torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); | ||||
|   auto data_handle = | ||||
|       torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, options); | ||||
|   std::memcpy(data_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t)); | ||||
|   return data_handle; | ||||
| } | ||||
|  | ||||
| void qr_open_handles(quickreduce::fptr_t _fa, | ||||
|                      const std::vector<torch::Tensor>& handles) { | ||||
|   auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa); | ||||
|   std::vector<hipIpcMemHandle_t> ipc_handles; | ||||
|   ipc_handles.reserve(handles.size()); | ||||
|   for (auto& handle : handles) { | ||||
|     // Ensure the tensor is on the same device as the current device. | ||||
|     hipIpcMemHandle_t ipc_handle; | ||||
|     std::memcpy(&ipc_handle, handle.data_ptr(), sizeof(hipIpcMemHandle_t)); | ||||
|     ipc_handles.push_back(ipc_handle); | ||||
|   } | ||||
|   fa->open_ipc_handles(ipc_handles); | ||||
| } | ||||
|  | ||||
| void qr_all_reduce(quickreduce::fptr_t _fa, torch::Tensor& inp, | ||||
|                    torch::Tensor& out, int64_t quant_level, bool cast_bf2half) { | ||||
|   auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa); | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); | ||||
|   auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA(); | ||||
|  | ||||
|   TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); | ||||
|   TORCH_CHECK_EQ(inp.numel(), out.numel()); | ||||
|   TORCH_CHECK_LE(out.numel(), fa->kMaxProblemSize); | ||||
|   if (out.scalar_type() == at::ScalarType::Half) { | ||||
|     fa->allreduce<half, false>(reinterpret_cast<half*>(inp.data_ptr()), | ||||
|                                reinterpret_cast<half*>(out.data_ptr()), | ||||
|                                out.numel(), quant_level, stream); | ||||
|   } else if (out.scalar_type() == at::ScalarType::BFloat16) { | ||||
|     if (cast_bf2half) { | ||||
|       fa->allreduce<half, true>(reinterpret_cast<half*>(inp.data_ptr()), | ||||
|                                 reinterpret_cast<half*>(out.data_ptr()), | ||||
|                                 out.numel(), quant_level, stream); | ||||
|     } else { | ||||
|       fa->allreduce<quickreduce::nv_bfloat16, false>( | ||||
|           reinterpret_cast<quickreduce::nv_bfloat16*>(inp.data_ptr()), | ||||
|           reinterpret_cast<quickreduce::nv_bfloat16*>(out.data_ptr()), | ||||
|           out.numel(), quant_level, stream); | ||||
|     } | ||||
|   } else { | ||||
|     throw std::runtime_error( | ||||
|         "quick allreduce only supports float16 and bfloat16"); | ||||
|   } | ||||
| } | ||||
|  | ||||
| int64_t qr_max_size() { | ||||
|   // The default is 2GB (2,147,483,648 bytes) | ||||
|   return static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1; | ||||
| } | ||||
|  | ||||
|   #define INSTANTIATE_FOR_WORLDSIZE(T, Codec, cast_bf2half)       \ | ||||
|     template struct quickreduce::AllReduceTwoshot<T, Codec<T, 2>, \ | ||||
|                                                   cast_bf2half>;  \ | ||||
|     template struct quickreduce::AllReduceTwoshot<T, Codec<T, 4>, \ | ||||
|                                                   cast_bf2half>;  \ | ||||
|     template struct quickreduce::AllReduceTwoshot<T, Codec<T, 8>, cast_bf2half>; | ||||
|  | ||||
| INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, false) | ||||
| INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, false) | ||||
| INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, false) | ||||
| INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, false) | ||||
| INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, true) | ||||
| INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, true) | ||||
| INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, true) | ||||
| INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, true) | ||||
|  | ||||
| INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecFP, false) | ||||
| INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ4, false) | ||||
| INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ6, false) | ||||
| INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ8, false) | ||||
|  | ||||
| #endif  // USE_ROCM | ||||
| @ -185,9 +185,7 @@ void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, | ||||
|         params.conv_states_ptr = nullptr; | ||||
|     } | ||||
|  | ||||
|     // Otherwise the kernel will be launched from cuda:0 device | ||||
|     // Cast to char to avoid compiler warning about narrowing | ||||
|     at::cuda::CUDAGuard device_guard{(char)x.get_device()}; | ||||
|     const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); | ||||
|     auto stream = at::cuda::getCurrentCUDAStream().stream(); | ||||
|     DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { | ||||
|             causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream); | ||||
| @ -278,9 +276,7 @@ void causal_conv1d_update(const at::Tensor &x, | ||||
|         params.conv_state_indices_ptr = nullptr; | ||||
|     } | ||||
|  | ||||
|     // Otherwise the kernel will be launched from cuda:0 device | ||||
|     // Cast to char to avoid compiler warning about narrowing | ||||
|     at::cuda::CUDAGuard device_guard{(char)x.get_device()}; | ||||
|     const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); | ||||
|     auto stream = at::cuda::getCurrentCUDAStream().stream(); | ||||
|     DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { | ||||
|             causal_conv1d_update_cuda<input_t, weight_t>(params, stream); | ||||
|  | ||||
| @ -647,9 +647,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, | ||||
|                        ); | ||||
|  | ||||
|      | ||||
|     // Otherwise the kernel will be launched from cuda:0 device | ||||
|     // Cast to char to avoid compiler warning about narrowing | ||||
|     at::cuda::CUDAGuard device_guard{(char)u.get_device()}; | ||||
|     const at::cuda::OptionalCUDAGuard device_guard(device_of(u)); | ||||
|     auto stream = at::cuda::getCurrentCUDAStream().stream(); | ||||
|     DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { | ||||
|         selective_scan_fwd_cuda<input_t, weight_t>(params, stream); | ||||
|  | ||||
| @ -1255,8 +1255,6 @@ __global__ void Marlin( | ||||
|     if constexpr (has_zp && !is_zp_float) { | ||||
|       if (is_new_zp) { | ||||
|         if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; | ||||
|         FragB frag_zp_0; | ||||
|         FragB frag_zp_1; | ||||
|         int zp_quant_0, zp_quant_1; | ||||
|  | ||||
|         if constexpr (w_type.size_bits() == 4) { | ||||
|  | ||||
| @ -13,232 +13,45 @@ | ||||
| namespace vllm { | ||||
| namespace moe { | ||||
|  | ||||
| namespace { | ||||
| __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, | ||||
|                                          int32_t col) { | ||||
|   // don't worry about overflow because num_experts is relatively small | ||||
|   return row * total_col + col; | ||||
| } | ||||
| }  // namespace | ||||
|  | ||||
| template <typename scalar_t, typename token_cnts_t> | ||||
| __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, | ||||
|                                             int32_t* sorted_token_ids, | ||||
|                                             int32_t* expert_ids, | ||||
|                                             int32_t* total_tokens_post_pad, | ||||
|                                             int32_t num_experts, | ||||
|                                             int32_t block_size, size_t numel) { | ||||
|   const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); | ||||
|   const size_t start_idx = threadIdx.x * tokens_per_thread; | ||||
|  | ||||
|   extern __shared__ int32_t shared_mem[]; | ||||
|   int32_t* cumsum = shared_mem;  // 1d tensor with shape (num_experts + 1) | ||||
|   token_cnts_t* tokens_cnts = | ||||
|       (token_cnts_t*)(shared_mem + num_experts + | ||||
|                       1);  // 2d tensor with shape (blockDim.x + 1, num_experts) | ||||
|  | ||||
|   for (int i = 0; i < num_experts; ++i) { | ||||
|     tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * In the first step we compute token_cnts[thread_index + 1][expert_index], | ||||
|    * which counts how many tokens in the token shard of thread_index are | ||||
|    * assigned to expert expert_index. | ||||
|    */ | ||||
|   for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { | ||||
|     ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; | ||||
|   } | ||||
|  | ||||
|   __syncthreads(); | ||||
|  | ||||
|   // For each expert we accumulate the token counts from the different threads. | ||||
|   if (threadIdx.x < num_experts) { | ||||
|     tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; | ||||
|     for (int i = 1; i <= blockDim.x; ++i) { | ||||
|       tokens_cnts[index(num_experts, i, threadIdx.x)] += | ||||
|           tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   __syncthreads(); | ||||
|  | ||||
|   // We accumulate the token counts of all experts in thread 0. | ||||
|   if (threadIdx.x == 0) { | ||||
|     cumsum[0] = 0; | ||||
|     for (int i = 1; i <= num_experts; ++i) { | ||||
|       cumsum[i] = cumsum[i - 1] + | ||||
|                   CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], | ||||
|                           block_size) * | ||||
|                       block_size; | ||||
|     } | ||||
|     *total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]); | ||||
|   } | ||||
|  | ||||
|   __syncthreads(); | ||||
|  | ||||
|   /** | ||||
|    * For each expert, each thread processes the tokens of the corresponding | ||||
|    * blocks and stores the corresponding expert_id for each block. | ||||
|    */ | ||||
|   if (threadIdx.x < num_experts) { | ||||
|     for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; | ||||
|          i += block_size) { | ||||
|       expert_ids[i / block_size] = threadIdx.x; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Each thread processes a token shard, calculating the index of each token | ||||
|    * after sorting by expert number. Given the example topk_ids = | ||||
|    * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, | ||||
|    * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a | ||||
|    * padding value(preset in python). | ||||
|    */ | ||||
|   for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { | ||||
|     int32_t expert_id = topk_ids[i]; | ||||
|     /** The cumsum[expert_id] stores the starting index of the tokens that the | ||||
|      * expert with expert_id needs to process, and | ||||
|      * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens | ||||
|      * processed by the expert with expert_id within the current thread's token | ||||
|      * shard. | ||||
|      */ | ||||
|     int32_t rank_post_pad = | ||||
|         tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + | ||||
|         cumsum[expert_id]; | ||||
|     sorted_token_ids[rank_post_pad] = i; | ||||
|     ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; | ||||
|   } | ||||
| } | ||||
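The worked example in the comment above (topk_ids = [0,1,2,1,2,3,0,3,4], block_size = 4) is easier to follow against a single-threaded reference of what the kernel computes: tokens are grouped by expert and each group is padded up to a multiple of block_size. This is my own sketch of the behaviour described in the comments, assuming `sorted_token_ids` is pre-filled with the padding value and the output buffers are pre-sized by the caller:

    #include <cstdint>
    #include <vector>

    void moe_align_block_size_ref(const std::vector<int32_t>& topk_ids,
                                  int32_t num_experts, int32_t block_size,
                                  std::vector<int32_t>& sorted_token_ids,  // pre-filled with pad
                                  std::vector<int32_t>& expert_ids,
                                  int32_t& total_tokens_post_pad) {
      std::vector<int32_t> counts(num_experts, 0);
      for (int32_t id : topk_ids) ++counts[id];

      // cumsum[e] = start offset of expert e after padding its count to block_size
      std::vector<int32_t> cumsum(num_experts + 1, 0);
      for (int32_t e = 0; e < num_experts; ++e) {
        int32_t padded = (counts[e] + block_size - 1) / block_size * block_size;
        cumsum[e + 1] = cumsum[e] + padded;
        for (int32_t i = cumsum[e]; i < cumsum[e + 1]; i += block_size)
          expert_ids[i / block_size] = e;               // one expert id per block
      }
      total_tokens_post_pad = cumsum[num_experts];

      // scatter token indices into their expert's padded slot range
      std::vector<int32_t> fill(num_experts, 0);
      for (int32_t i = 0; i < static_cast<int32_t>(topk_ids.size()); ++i) {
        int32_t e = topk_ids[i];
        sorted_token_ids[cumsum[e] + fill[e]++] = i;
      }
    }

Running this on the example reproduces [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *] with total_tokens_post_pad = 20.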
|  | ||||
| // TODO(simon): this is temporarily adapted from | ||||
| // https://github.com/sgl-project/sglang/commit/31548116a8dc8c6df7e146e0587335a59fc5b9d7 | ||||
| // we did this to unblock Deepseek V3 but there should be a better | ||||
| // implementation to manage shared memory. | ||||
| template <typename scalar_t> | ||||
| __global__ void moe_align_block_size_global_mem_kernel( | ||||
|     scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, | ||||
|     int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, | ||||
|     int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) { | ||||
|   const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); | ||||
|   const size_t start_idx = threadIdx.x * tokens_per_thread; | ||||
| __global__ void moe_align_block_size_kernel( | ||||
|     const scalar_t* __restrict__ topk_ids, | ||||
|     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, | ||||
|     int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, | ||||
|     int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size, | ||||
|     size_t numel, int32_t* __restrict__ cumsum) { | ||||
|   extern __shared__ int32_t shared_counts[]; | ||||
|  | ||||
|   for (int i = 0; i < num_experts; ++i) { | ||||
|     tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * In the first step we compute token_cnts[thread_index + 1][expert_index], | ||||
|    * which counts how many tokens in the token shard of thread_index are | ||||
|    * assigned to expert expert_index. | ||||
|    */ | ||||
|   for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { | ||||
|     ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; | ||||
|   } | ||||
|  | ||||
|   __syncthreads(); | ||||
|  | ||||
|   // For each expert we accumulate the token counts from the different threads. | ||||
|   if (threadIdx.x < num_experts) { | ||||
|     tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; | ||||
|     for (int i = 1; i <= blockDim.x; ++i) { | ||||
|       tokens_cnts[index(num_experts, i, threadIdx.x)] += | ||||
|           tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   __syncthreads(); | ||||
|  | ||||
|   // We accumulate the token counts of all experts in thread 0. | ||||
|   if (threadIdx.x == 0) { | ||||
|     cumsum[0] = 0; | ||||
|     for (int i = 1; i <= num_experts; ++i) { | ||||
|       cumsum[i] = cumsum[i - 1] + | ||||
|                   CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], | ||||
|                           block_size) * | ||||
|                       block_size; | ||||
|     } | ||||
|     *total_tokens_post_pad = cumsum[num_experts]; | ||||
|   } | ||||
|  | ||||
|   __syncthreads(); | ||||
|  | ||||
|   /** | ||||
|    * For each expert, each thread processes the tokens of the corresponding | ||||
|    * blocks and stores the corresponding expert_id for each block. | ||||
|    */ | ||||
|   if (threadIdx.x < num_experts) { | ||||
|     for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; | ||||
|          i += block_size) { | ||||
|       expert_ids[i / block_size] = threadIdx.x; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|    * Each thread processes a token shard, calculating the index of each token | ||||
|    * after sorting by expert number. Given the example topk_ids = | ||||
|    * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, | ||||
|    * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a | ||||
|    * padding value (preset in Python). | ||||
|    */ | ||||
|   for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { | ||||
|     int32_t expert_id = topk_ids[i]; | ||||
|     /** The cumsum[expert_id] stores the starting index of the tokens that the | ||||
|      * expert with expert_id needs to process, and | ||||
|      * tokens_cnts[threadIdx.x][expert_id] stores how many tokens the preceding | ||||
|      * threads' shards have already assigned to expert_id, i.e. this thread's | ||||
|      * running write offset within that expert's segment. | ||||
|      */ | ||||
|     int32_t rank_post_pad = | ||||
|         tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + | ||||
|         cumsum[expert_id]; | ||||
|     sorted_token_ids[rank_post_pad] = i; | ||||
|     ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; | ||||
|   } | ||||
| } | ||||
|  | ||||
| // taken from | ||||
| // https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957 | ||||
| template <typename scalar_t> | ||||
| __global__ void sgl_moe_align_block_size_kernel( | ||||
|     scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, | ||||
|     int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, | ||||
|     int32_t block_size, size_t numel, int32_t* cumsum) { | ||||
|   __shared__ int32_t shared_counts[32][8]; | ||||
|  | ||||
|   const int warp_id = threadIdx.x / 32; | ||||
|   const int experts_per_warp = 8; | ||||
|   const int warp_id = threadIdx.x / WARP_SIZE; | ||||
|   const int my_expert_start = warp_id * experts_per_warp; | ||||
|  | ||||
|   // Initialize shared_counts for this warp's experts | ||||
|   for (int i = 0; i < experts_per_warp; ++i) { | ||||
|     if (my_expert_start + i < num_experts) { | ||||
|       shared_counts[warp_id][i] = 0; | ||||
|     if (my_expert_start + i < padded_num_experts) { | ||||
|       shared_counts[warp_id * experts_per_warp + i] = 0; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   __syncthreads(); | ||||
|  | ||||
|   const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); | ||||
|   const size_t start_idx = threadIdx.x * tokens_per_thread; | ||||
|   const size_t tid = threadIdx.x; | ||||
|   const size_t stride = blockDim.x; | ||||
|  | ||||
|   for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { | ||||
|   for (size_t i = tid; i < numel; i += stride) { | ||||
|     int expert_id = topk_ids[i]; | ||||
|     int warp_idx = expert_id / experts_per_warp; | ||||
|     int expert_offset = expert_id % experts_per_warp; | ||||
|     atomicAdd(&shared_counts[warp_idx][expert_offset], 1); | ||||
|     atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1); | ||||
|   } | ||||
|  | ||||
|   __syncthreads(); | ||||
|  | ||||
|   // Single thread computes cumulative sum and total tokens | ||||
|   if (threadIdx.x == 0) { | ||||
|     cumsum[0] = 0; | ||||
|     for (int i = 1; i <= num_experts; ++i) { | ||||
|       int expert_count = 0; | ||||
|       int warp_idx = (i - 1) / experts_per_warp; | ||||
|       int expert_offset = (i - 1) % experts_per_warp; | ||||
|       expert_count = shared_counts[warp_idx][expert_offset]; | ||||
|       expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset]; | ||||
|  | ||||
|       cumsum[i] = | ||||
|           cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size; | ||||
| @ -248,7 +61,6 @@ __global__ void sgl_moe_align_block_size_kernel( | ||||
|  | ||||
|   __syncthreads(); | ||||
|  | ||||
|   // Assign expert IDs to blocks | ||||
|   if (threadIdx.x < num_experts) { | ||||
|     for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; | ||||
|          i += block_size) { | ||||
| @ -257,13 +69,11 @@ __global__ void sgl_moe_align_block_size_kernel( | ||||
|   } | ||||
| } | ||||
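
A note on the flattened shared counter indexing used above: because warp_idx and expert_offset are derived from the expert id by division and modulo with the same experts_per_warp, the flat index warp_idx * experts_per_warp + expert_offset is algebraically just expert_id, so the layout keeps one int32 counter per (padded) expert. A minimal sketch of that mapping (the helper name is hypothetical, not from the source):

    // Sketch (not from the diff): expert id -> slot in the flattened shared_counts.
    inline int flat_count_index(int expert_id, int experts_per_warp) {
      int warp_idx = expert_id / experts_per_warp;          // warp that owns this expert
      int expert_offset = expert_id % experts_per_warp;     // slot within that warp's range
      return warp_idx * experts_per_warp + expert_offset;   // algebraically == expert_id
    }
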
|  | ||||
| // taken from | ||||
| // https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957 | ||||
| template <typename scalar_t> | ||||
| __global__ void sgl_moe_token_sort_kernel(scalar_t* __restrict__ topk_ids, | ||||
|                                           int32_t* sorted_token_ids, | ||||
|                                           int32_t* cumsum_buffer, | ||||
|                                           size_t numel) { | ||||
| __global__ void count_and_sort_expert_tokens_kernel( | ||||
|     const scalar_t* __restrict__ topk_ids, | ||||
|     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, | ||||
|     size_t numel) { | ||||
|   const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
|   const size_t stride = blockDim.x * gridDim.x; | ||||
|  | ||||
| @ -290,132 +100,138 @@ __global__ void moe_sum_kernel( | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename scalar_t> | ||||
| __global__ void moe_align_block_size_small_batch_expert_kernel( | ||||
|     const scalar_t* __restrict__ topk_ids, | ||||
|     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, | ||||
|     int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, | ||||
|     int32_t block_size, size_t numel) { | ||||
|   const size_t tid = threadIdx.x; | ||||
|   const size_t stride = blockDim.x; | ||||
|  | ||||
|   extern __shared__ int32_t shared_mem[]; | ||||
|   int32_t* cumsum = shared_mem; | ||||
|   int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1); | ||||
|  | ||||
|   for (int i = 0; i < num_experts; ++i) { | ||||
|     tokens_cnts[(threadIdx.x + 1) * num_experts + i] = 0; | ||||
|   } | ||||
|  | ||||
|   for (size_t i = tid; i < numel; i += stride) { | ||||
|     ++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]]; | ||||
|   } | ||||
|  | ||||
|   __syncthreads(); | ||||
|  | ||||
|   if (threadIdx.x < num_experts) { | ||||
|     tokens_cnts[threadIdx.x] = 0; | ||||
|     for (int i = 1; i <= blockDim.x; ++i) { | ||||
|       tokens_cnts[i * num_experts + threadIdx.x] += | ||||
|           tokens_cnts[(i - 1) * num_experts + threadIdx.x]; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   __syncthreads(); | ||||
|  | ||||
|   if (threadIdx.x == 0) { | ||||
|     cumsum[0] = 0; | ||||
|     for (int i = 1; i <= num_experts; ++i) { | ||||
|       cumsum[i] = | ||||
|           cumsum[i - 1] + | ||||
|           CEILDIV(tokens_cnts[blockDim.x * num_experts + i - 1], block_size) * | ||||
|               block_size; | ||||
|     } | ||||
|     *total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]); | ||||
|   } | ||||
|  | ||||
|   __syncthreads(); | ||||
|  | ||||
|   if (threadIdx.x < num_experts) { | ||||
|     for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; | ||||
|          i += block_size) { | ||||
|       expert_ids[i / block_size] = threadIdx.x; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   for (size_t i = tid; i < numel; i += stride) { | ||||
|     int32_t expert_id = topk_ids[i]; | ||||
|     int32_t rank_post_pad = | ||||
|         tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id]; | ||||
|     sorted_token_ids[rank_post_pad] = i; | ||||
|     ++tokens_cnts[threadIdx.x * num_experts + expert_id]; | ||||
|   } | ||||
| } | ||||
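
For reference, the dynamic shared memory this small-batch kernel expects is a cumsum array of num_experts + 1 int32_t values followed by a (blockDim.x + 1) x num_experts int32_t tokens_cnts matrix; the host launch further down sizes it the same way. A small sketch of that size computation (the helper name is an assumption, not in the source):

    #include <cstddef>
    #include <cstdint>

    // Bytes of dynamic shared memory for the small-batch kernel, given the launch
    // block size `threads` (>= num_experts) and the expert count.
    size_t small_batch_shared_mem_bytes(int threads, int num_experts) {
      size_t cumsum_ints = num_experts + 1;                            // int32_t cumsum[num_experts + 1]
      size_t tokens_cnts_ints = (size_t)(threads + 1) * num_experts;   // int32_t tokens_cnts[threads + 1][num_experts]
      return (cumsum_ints + tokens_cnts_ints) * sizeof(int32_t);
    }
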
|  | ||||
| }  // namespace moe | ||||
| }  // namespace vllm | ||||
|  | ||||
| // taken from | ||||
| // https://github.com/sgl-project/sglang/blob/8b5f83ed3b7d2a49ad5c5cd5aa61c5d502f47dbc | ||||
| void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, | ||||
|                           int64_t block_size, torch::Tensor sorted_token_ids, | ||||
|                           torch::Tensor experts_ids, | ||||
|                           torch::Tensor num_tokens_post_pad) { | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|  | ||||
|   int device_max_shared_mem; | ||||
|   auto dev = topk_ids.get_device(); | ||||
|   cudaDeviceGetAttribute(&device_max_shared_mem, | ||||
|                          cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); | ||||
|  | ||||
|   const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); | ||||
|   const int32_t shared_mem_i32 = | ||||
|       ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); | ||||
|   const int32_t shared_mem_i16 = | ||||
|       ((num_thread + 1) * num_experts) * sizeof(uint16_t) + | ||||
|       (num_experts + 1) * sizeof(int32_t); | ||||
|  | ||||
|   bool use_global_memory = false; | ||||
|   bool use_i16 = false;  // Use uint16_t for shared memory token counts | ||||
|   if (shared_mem_i32 < device_max_shared_mem) { | ||||
|     // Do nothing in this case. We're all set to use int32_t token counts | ||||
|   } else if (shared_mem_i16 < device_max_shared_mem && | ||||
|              topk_ids.numel() <= 65535) { | ||||
|     // When topk_ids has fewer than 65535 elements (the max value of uint16), | ||||
|     // each element of token_cnts is also smaller than 65535, | ||||
|     // so we can use uint16 as the dtype of token_cnts. | ||||
|     use_i16 = true; | ||||
|   } else { | ||||
|     use_global_memory = true; | ||||
|   } | ||||
|  | ||||
|   if (use_global_memory) { | ||||
|     VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( | ||||
|         topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] { | ||||
|           // allocate global-memory buffers for the `tokens_cnts` and | ||||
|           // `cumsum` tensors | ||||
|           const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); | ||||
|  | ||||
|           auto options_int = torch::TensorOptions() | ||||
|                                  .dtype(torch::kInt) | ||||
|                                  .device(topk_ids.device()); | ||||
|           torch::Tensor token_cnts_buffer = | ||||
|               torch::empty({(num_experts + 1) * num_experts}, options_int); | ||||
|           torch::Tensor cumsum_buffer = | ||||
|               torch::empty({num_experts + 1}, options_int); | ||||
|  | ||||
|           auto kernel = | ||||
|               vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>; | ||||
|           kernel<<<1, num_thread, 0, stream>>>( | ||||
|               topk_ids.data_ptr<scalar_t>(), | ||||
|               sorted_token_ids.data_ptr<int32_t>(), | ||||
|               experts_ids.data_ptr<int32_t>(), | ||||
|               num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size, | ||||
|               topk_ids.numel(), token_cnts_buffer.data_ptr<int32_t>(), | ||||
|               cumsum_buffer.data_ptr<int32_t>()); | ||||
|         }); | ||||
|   } else if (use_i16) { | ||||
|     VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( | ||||
|         topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { | ||||
|           // set dynamic shared mem | ||||
|           auto kernel = | ||||
|               vllm::moe::moe_align_block_size_kernel<scalar_t, uint16_t>; | ||||
|           AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( | ||||
|               (void*)kernel, shared_mem_i16)); | ||||
|           kernel<<<1, num_thread, shared_mem_i16, stream>>>( | ||||
|               topk_ids.data_ptr<scalar_t>(), | ||||
|               sorted_token_ids.data_ptr<int32_t>(), | ||||
|               experts_ids.data_ptr<int32_t>(), | ||||
|               num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size, | ||||
|               topk_ids.numel()); | ||||
|         }); | ||||
|   } else { | ||||
|     VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( | ||||
|         topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { | ||||
|           auto kernel = | ||||
|               vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>; | ||||
|           AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( | ||||
|               (void*)kernel, shared_mem_i32)); | ||||
|           kernel<<<1, num_thread, shared_mem_i32, stream>>>( | ||||
|               topk_ids.data_ptr<scalar_t>(), | ||||
|               sorted_token_ids.data_ptr<int32_t>(), | ||||
|               experts_ids.data_ptr<int32_t>(), | ||||
|               num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size, | ||||
|               topk_ids.numel()); | ||||
|         }); | ||||
|   } | ||||
| } | ||||
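
To see why the fallbacks above exist (the DeepSeek V3 case mentioned in the TODO earlier), here is the shared-memory budget worked out for num_experts = 256 with WARP_SIZE = 32. This is illustrative arithmetic only; the per-block opt-in shared-memory limits quoted in the comment are assumptions about typical hardware, not values taken from the diff.

    #include <cstddef>
    #include <cstdint>

    // Sketch mirroring the budget check in moe_align_block_size.
    constexpr int num_experts = 256;
    constexpr int warp_size = 32;
    constexpr int num_thread = num_experts > warp_size ? num_experts : warp_size;   // 256
    constexpr size_t shared_mem_i32 =
        ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);     // 264,196 B (~258 KiB)
    constexpr size_t shared_mem_i16 =
        (num_thread + 1) * num_experts * sizeof(uint16_t) +
        (num_experts + 1) * sizeof(int32_t);                                        // 132,612 B (~130 KiB)
    // ~258 KiB exceeds the opt-in per-block limit of current GPUs (roughly
    // 96-228 KB), so with 256 experts the launcher either switches to uint16
    // counters (only valid while numel <= 65535) or falls back to the
    // global-memory kernel.
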
|  | ||||
| void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, | ||||
|                               int64_t block_size, | ||||
|                               torch::Tensor sorted_token_ids, | ||||
|                               torch::Tensor experts_ids, | ||||
|                               torch::Tensor num_tokens_post_pad) { | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|   TORCH_CHECK(num_experts == 256, | ||||
|               "sgl_moe_align_block_size kernel only supports deepseek v3."); | ||||
|   int64_t padded_num_experts = | ||||
|       ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; | ||||
|   int experts_per_warp = WARP_SIZE; | ||||
|   int threads = 1024; | ||||
|   threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; | ||||
|  | ||||
|   VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( | ||||
|       topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] { | ||||
|       topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { | ||||
|         // allocate a global-memory buffer for the per-expert `cumsum` | ||||
|         auto options_int = | ||||
|             torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); | ||||
|         torch::Tensor cumsum_buffer = | ||||
|             torch::zeros({num_experts + 1}, options_int); | ||||
|         bool small_batch_expert_mode = | ||||
|             (topk_ids.numel() < 1024) && (num_experts <= 64); | ||||
|  | ||||
|         auto align_kernel = | ||||
|             vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>; | ||||
|         align_kernel<<<1, 1024, 0, stream>>>( | ||||
|             topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(), | ||||
|             experts_ids.data_ptr<int32_t>(), | ||||
|             num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size, | ||||
|             topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>()); | ||||
|         if (small_batch_expert_mode) { | ||||
|           const int32_t threads = max((int32_t)num_experts, WARP_SIZE); | ||||
|           const int32_t shared_mem_size = | ||||
|               ((threads + 1) * num_experts + (num_experts + 1)) * | ||||
|               sizeof(int32_t); | ||||
|  | ||||
|         const int block_threads = 256; | ||||
|         const int num_blocks = | ||||
|             (topk_ids.numel() + block_threads - 1) / block_threads; | ||||
|         const int max_blocks = 65535; | ||||
|         const int actual_blocks = std::min(num_blocks, max_blocks); | ||||
|         auto sort_kernel = vllm::moe::sgl_moe_token_sort_kernel<scalar_t>; | ||||
|         sort_kernel<<<actual_blocks, block_threads, 0, stream>>>( | ||||
|             topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(), | ||||
|             cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel()); | ||||
|           auto small_batch_expert_kernel = | ||||
|               vllm::moe::moe_align_block_size_small_batch_expert_kernel< | ||||
|                   scalar_t>; | ||||
|           small_batch_expert_kernel<<<1, threads, shared_mem_size, stream>>>( | ||||
|               topk_ids.data_ptr<scalar_t>(), | ||||
|               sorted_token_ids.data_ptr<int32_t>(), | ||||
|               experts_ids.data_ptr<int32_t>(), | ||||
|               num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size, | ||||
|               topk_ids.numel()); | ||||
|         } else { | ||||
|           auto align_kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>; | ||||
|  | ||||
|           size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp); | ||||
|           size_t shared_mem_size = | ||||
|               num_warps * experts_per_warp * sizeof(int32_t); | ||||
|  | ||||
|           align_kernel<<<1, threads, shared_mem_size, stream>>>( | ||||
|               topk_ids.data_ptr<scalar_t>(), | ||||
|               sorted_token_ids.data_ptr<int32_t>(), | ||||
|               experts_ids.data_ptr<int32_t>(), | ||||
|               num_tokens_post_pad.data_ptr<int32_t>(), num_experts, | ||||
|               padded_num_experts, experts_per_warp, block_size, | ||||
|               topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>()); | ||||
|  | ||||
|           const int block_threads = std::min(256, (int)threads); | ||||
|           const int num_blocks = | ||||
|               (topk_ids.numel() + block_threads - 1) / block_threads; | ||||
|           const int max_blocks = 65535; | ||||
|           const int actual_blocks = std::min(num_blocks, max_blocks); | ||||
|  | ||||
|           auto sort_kernel = | ||||
|               vllm::moe::count_and_sort_expert_tokens_kernel<scalar_t>; | ||||
|           sort_kernel<<<actual_blocks, block_threads, 0, stream>>>( | ||||
|               topk_ids.data_ptr<scalar_t>(), | ||||
|               sorted_token_ids.data_ptr<int32_t>(), | ||||
|               cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel()); | ||||
|         } | ||||
|       }); | ||||
| } | ||||
|  | ||||
| @ -423,7 +239,7 @@ void moe_sum(torch::Tensor& input,   // [num_tokens, topk, hidden_size] | ||||
|              torch::Tensor& output)  // [num_tokens, hidden_size] | ||||
| { | ||||
|   const int hidden_size = input.size(-1); | ||||
|   const int num_tokens = output.numel() / hidden_size; | ||||
|   const auto num_tokens = output.numel() / hidden_size; | ||||
|   const int topk = input.size(1); | ||||
|  | ||||
|   dim3 grid(num_tokens); | ||||
|  | ||||
| @ -12,12 +12,6 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, | ||||
|                           int64_t block_size, torch::Tensor sorted_token_ids, | ||||
|                           torch::Tensor experts_ids, | ||||
|                           torch::Tensor num_tokens_post_pad); | ||||
|  | ||||
| void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, | ||||
|                               int64_t block_size, | ||||
|                               torch::Tensor sorted_token_ids, | ||||
|                               torch::Tensor experts_ids, | ||||
|                               torch::Tensor num_tokens_post_pad); | ||||
| #ifndef USE_ROCM | ||||
| torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, | ||||
|                              torch::Tensor b_qweight, torch::Tensor b_scales, | ||||
|  | ||||
| @ -12,7 +12,7 @@ void moe_permute( | ||||
|     const torch::Tensor& input,                      // [n_token, hidden] | ||||
|     const torch::Tensor& topk_weights,               //[n_token, topk] | ||||
|     torch::Tensor& topk_ids,                         // [n_token, topk] | ||||
|     const torch::Tensor& token_expert_indicies,      // [n_token, topk] | ||||
|     const torch::Tensor& token_expert_indices,       // [n_token, topk] | ||||
|     const std::optional<torch::Tensor>& expert_map,  // [n_expert] | ||||
|     int64_t n_expert, int64_t n_local_expert, int64_t topk, | ||||
|     const std::optional<int64_t>& align_block_size, | ||||
| @ -27,15 +27,15 @@ void moe_permute( | ||||
|               "expert_first_token_offset must be int64"); | ||||
|   TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, | ||||
|               "topk_ids must be int32"); | ||||
|   TORCH_CHECK(token_expert_indicies.scalar_type() == at::ScalarType::Int, | ||||
|               "token_expert_indicies must be int32"); | ||||
|   TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int, | ||||
|               "token_expert_indices must be int32"); | ||||
|   TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int, | ||||
|               "src_row_id2dst_row_id_map must be int32"); | ||||
|   TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1, | ||||
|               "expert_first_token_offset shape != n_local_expert+1") | ||||
|   TORCH_CHECK( | ||||
|       src_row_id2dst_row_id_map.sizes() == token_expert_indicies.sizes(), | ||||
|       "token_expert_indicies shape must be same as src_row_id2dst_row_id_map"); | ||||
|       src_row_id2dst_row_id_map.sizes() == token_expert_indices.sizes(), | ||||
|       "token_expert_indices shape must be same as src_row_id2dst_row_id_map"); | ||||
|   auto n_token = input.sizes()[0]; | ||||
|   auto n_hidden = input.sizes()[1]; | ||||
|   auto align_block_size_value = | ||||
| @ -71,7 +71,7 @@ void moe_permute( | ||||
|                              expert_map_ptr, n_expert, stream); | ||||
|   } | ||||
|   // expert sort topk expert id and scan expert id get expert_first_token_offset | ||||
|   sortAndScanExpert(get_ptr<int>(topk_ids), get_ptr<int>(token_expert_indicies), | ||||
|   sortAndScanExpert(get_ptr<int>(topk_ids), get_ptr<int>(token_expert_indices), | ||||
|                     get_ptr<int>(permuted_experts_id), | ||||
|                     get_ptr<int>(dst_row_id2src_row_id_map), | ||||
|                     get_ptr<int64_t>(expert_first_token_offset), n_token, | ||||
| @ -190,7 +190,7 @@ void shuffle_rows(const torch::Tensor& input_tensor, | ||||
|  | ||||
| void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights, | ||||
|                  torch::Tensor& topk_ids, | ||||
|                  const torch::Tensor& token_expert_indicies, | ||||
|                  const torch::Tensor& token_expert_indices, | ||||
|                  const std::optional<torch::Tensor>& expert_map, | ||||
|                  int64_t n_expert, int64_t n_local_expert, int64_t topk, | ||||
|                  const std::optional<int64_t>& align_block_size, | ||||
| @ -203,7 +203,7 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights, | ||||
|  | ||||
| void moe_unpermute(const torch::Tensor& input, | ||||
|                    const torch::Tensor& topk_weights, torch::Tensor& topk_ids, | ||||
|                    const torch::Tensor& token_expert_indicies, | ||||
|                    const torch::Tensor& token_expert_indices, | ||||
|                    const std::optional<torch::Tensor>& expert_map, | ||||
|                    int64_t n_expert, int64_t n_local_expert, int64_t topk, | ||||
|                    const std::optional<int64_t>& align_block_size, | ||||
|  | ||||
| @ -20,7 +20,6 @@ __global__ void expandInputRowsKernel( | ||||
|   int expert_id = sorted_experts[expanded_dest_row]; | ||||
|  | ||||
|   extern __shared__ int64_t smem_expert_first_token_offset[]; | ||||
|   int64_t align_expanded_row_accumulate = 0; | ||||
|   if constexpr (ALIGN_BLOCK_SIZE) { | ||||
|     // load g2s | ||||
|     for (int idx = threadIdx.x; idx < num_local_experts + 1; | ||||
| @ -63,7 +62,6 @@ __global__ void expandInputRowsKernel( | ||||
|     using DataElem = cutlass::Array<T, ELEM_PER_THREAD>; | ||||
|  | ||||
|     // Duplicate and permute rows | ||||
|     int64_t const source_k_rank = expanded_source_row / num_rows; | ||||
|     int64_t const source_row = expanded_source_row % num_rows; | ||||
|  | ||||
|     auto const* source_row_ptr = | ||||
| @ -160,7 +158,6 @@ __global__ void finalizeMoeRoutingKernel( | ||||
|        elem_index += stride) { | ||||
|     ComputeElem thread_output; | ||||
|     thread_output.fill(0); | ||||
|     float row_rescale{0.f}; | ||||
|     for (int k_idx = 0; k_idx < k; ++k_idx) { | ||||
|       int64_t const expanded_original_row = original_row + k_idx * num_rows; | ||||
|       int64_t const expanded_permuted_row = | ||||
| @ -177,8 +174,6 @@ __global__ void finalizeMoeRoutingKernel( | ||||
|       auto const* expanded_permuted_rows_row_ptr = | ||||
|           expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col; | ||||
|  | ||||
|       int64_t const expert_idx = expert_for_source_row[k_offset]; | ||||
|  | ||||
|       ComputeElem expert_result = arrayConvert<InputElem, ComputeElem>( | ||||
|           expanded_permuted_rows_row_ptr[elem_index]); | ||||
|       thread_output = thread_output + row_scale * (expert_result); | ||||
|  | ||||
| @ -425,7 +425,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f | ||||
|  | ||||
| #define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB)                       \ | ||||
|     topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>(         \ | ||||
|         gating_output, nullptr, topk_weights, topk_indicies,            \ | ||||
|         gating_output, nullptr, topk_weights, topk_indices,            \ | ||||
|         token_expert_indices, num_tokens, topk, 0, num_experts,         \ | ||||
|         stream); | ||||
|  | ||||
| @ -433,7 +433,7 @@ template <typename IndType> | ||||
| void topkGatingSoftmaxKernelLauncher( | ||||
|     const float* gating_output, | ||||
|     float* topk_weights, | ||||
|     IndType* topk_indicies, | ||||
|     IndType* topk_indices, | ||||
|     int* token_expert_indices, | ||||
|     float* softmax_workspace, | ||||
|     const int num_tokens, | ||||
| @ -476,7 +476,7 @@ void topkGatingSoftmaxKernelLauncher( | ||||
|             moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>( | ||||
|                 gating_output, nullptr, softmax_workspace, num_experts); | ||||
|             moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>( | ||||
|                 softmax_workspace, nullptr, topk_weights, topk_indicies, token_expert_indices, | ||||
|                 softmax_workspace, nullptr, topk_weights, topk_indices, token_expert_indices, | ||||
|                 num_experts, topk, 0, num_experts); | ||||
|         } | ||||
|     } | ||||
| @ -492,7 +492,7 @@ void topk_softmax( | ||||
|     torch::Tensor& gating_output)               // [num_tokens, num_experts] | ||||
| { | ||||
|     const int num_experts = gating_output.size(-1); | ||||
|     const int num_tokens = gating_output.numel() / num_experts; | ||||
|     const auto num_tokens = gating_output.numel() / num_experts; | ||||
|     const int topk = topk_weights.size(-1); | ||||
|  | ||||
|     const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); | ||||
|  | ||||
| @ -22,15 +22,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { | ||||
|       "                     Tensor! num_tokens_post_pad) -> ()"); | ||||
|   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); | ||||
|  | ||||
|   // temporarily adapted from | ||||
|   // https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a | ||||
|   m.def( | ||||
|       "sgl_moe_align_block_size(Tensor topk_ids, int num_experts," | ||||
|       "                         int block_size, Tensor! sorted_token_ids," | ||||
|       "                         Tensor! experts_ids," | ||||
|       "                         Tensor! num_tokens_post_pad) -> ()"); | ||||
|   m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size); | ||||
|  | ||||
| #ifndef USE_ROCM | ||||
|   m.def( | ||||
|       "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, " | ||||
| @ -66,7 +57,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { | ||||
|  | ||||
|   m.def( | ||||
|       "moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids," | ||||
|       "Tensor token_expert_indicies, Tensor? expert_map, int n_expert," | ||||
|       "Tensor token_expert_indices, Tensor? expert_map, int n_expert," | ||||
|       "int n_local_expert," | ||||
|       "int topk, int? align_block_size,Tensor! permuted_input, Tensor! " | ||||
|       "expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! " | ||||
|  | ||||
							
								
								
									
csrc/ops.h (11 changed lines)
							| @ -360,3 +360,14 @@ std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle( | ||||
|     int64_t size); | ||||
| int64_t open_mem_handle(torch::Tensor& mem_handle); | ||||
| void free_shared_buffer(int64_t buffer); | ||||
|  | ||||
| #ifdef USE_ROCM | ||||
| fptr_t init_custom_qr(int64_t rank, int64_t world_size, | ||||
|                       std::optional<int64_t> qr_max_size = std::nullopt); | ||||
| void qr_destroy(fptr_t _fa); | ||||
| torch::Tensor qr_get_handle(fptr_t _fa); | ||||
| void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles); | ||||
| void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, | ||||
|                    int64_t quant_level, bool cast_bf2half = false); | ||||
| int64_t qr_max_size(); | ||||
| #endif | ||||
| @ -274,7 +274,6 @@ void advance_step_flashinfer( | ||||
|   cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); | ||||
|   cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev); | ||||
|  | ||||
|   [[maybe_unused]] int block_tables_stride = block_tables.stride(0); | ||||
|   TORCH_CHECK((blocks * threads > num_queries), | ||||
|               "multi-step: not enough threads to map to num_queries = ", | ||||
|               num_queries, " block_tables.stride(0) = ", block_tables.stride(0), | ||||
|  | ||||
| @ -1,15 +1,17 @@ | ||||
| #include <ATen/cuda/CUDAContext.h> | ||||
| #include <torch/all.h> | ||||
|  | ||||
| #include <cmath> | ||||
|  | ||||
| #include "../../dispatch_utils.h" | ||||
| #include "../vectorization_utils.cuh" | ||||
|  | ||||
| #ifndef USE_ROCM | ||||
|   #include <cub/util_type.cuh> | ||||
|   #include <cub/cub.cuh> | ||||
|   #include <cub/util_type.cuh> | ||||
| #else | ||||
|   #include <hipcub/util_type.hpp> | ||||
|   #include <hipcub/hipcub.hpp> | ||||
|   #include <hipcub/util_type.hpp> | ||||
| #endif | ||||
|  | ||||
| static inline __device__ int8_t float_to_int8_rn(float x) { | ||||
| @ -103,134 +105,170 @@ static inline __device__ int8_t int32_to_int8(int32_t x) { | ||||
|  | ||||
| namespace vllm { | ||||
|  | ||||
| template <typename scalar_t, typename scale_type> | ||||
| template <typename scalar_t, typename scale_t> | ||||
| __global__ void static_scaled_int8_quant_kernel( | ||||
|     scalar_t const* __restrict__ input, int8_t* __restrict__ out, | ||||
|     scale_type const* scale_ptr, const int hidden_size) { | ||||
|   int const tid = threadIdx.x; | ||||
|   int64_t const token_idx = blockIdx.x; | ||||
|   scale_type const scale = *scale_ptr; | ||||
|     const scalar_t* __restrict__ input, int8_t* __restrict__ output, | ||||
|     const scale_t* scale_ptr, const int hidden_size) { | ||||
|   const int tid = threadIdx.x; | ||||
|   const int stride = blockDim.x; | ||||
|   const int64_t token_idx = blockIdx.x; | ||||
|   const float scale = *scale_ptr; | ||||
|  | ||||
|   // Must be performed using 64-bit math to avoid integer overflow. | ||||
|   out += token_idx * hidden_size; | ||||
|   input += token_idx * hidden_size; | ||||
|   const scalar_t* row_in = input + token_idx * hidden_size; | ||||
|   int8_t* row_out = output + token_idx * hidden_size; | ||||
|  | ||||
|   for (int i = tid; i < hidden_size; i += blockDim.x) { | ||||
|     out[i] = float_to_int8_rn(static_cast<float>(input[i]) / scale); | ||||
|   } | ||||
|   vectorize_with_alignment<16>( | ||||
|       row_in, row_out, hidden_size, tid, stride, | ||||
|       [=] __device__(int8_t& dst, const scalar_t& src) { | ||||
|         dst = float_to_int8_rn(static_cast<float>(src) / scale); | ||||
|       }); | ||||
| } | ||||
|  | ||||
| template <typename scalar_t, typename scale_type, typename azp_type> | ||||
| template <typename scalar_t, typename scale_t, typename azp_t> | ||||
| __global__ void static_scaled_int8_azp_quant_kernel( | ||||
|     scalar_t const* __restrict__ input, int8_t* __restrict__ out, | ||||
|     scale_type const* scale_ptr, azp_type const* azp_ptr, | ||||
|     const int hidden_size) { | ||||
|   int const tid = threadIdx.x; | ||||
|   int64_t const token_idx = blockIdx.x; | ||||
|   scale_type const scale = *scale_ptr; | ||||
|   azp_type const azp = *azp_ptr; | ||||
|     const scalar_t* __restrict__ input, int8_t* __restrict__ output, | ||||
|     const scale_t* scale_ptr, const azp_t* azp_ptr, const int hidden_size) { | ||||
|   const int tid = threadIdx.x; | ||||
|   const int stride = blockDim.x; | ||||
|   const int64_t token_idx = blockIdx.x; | ||||
|   const float scale = *scale_ptr; | ||||
|   const azp_t azp = *azp_ptr; | ||||
|   const float inv_s = 1.0f / scale; | ||||
|  | ||||
|   // Must be performed using 64-bit math to avoid integer overflow. | ||||
|   out += token_idx * hidden_size; | ||||
|   input += token_idx * hidden_size; | ||||
|   const scalar_t* row_in = input + token_idx * hidden_size; | ||||
|   int8_t* row_out = output + token_idx * hidden_size; | ||||
|  | ||||
|   for (int i = tid; i < hidden_size; i += blockDim.x) { | ||||
|     auto const val = static_cast<float>(input[i]); | ||||
|     auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp); | ||||
|     out[i] = quant_val; | ||||
|   } | ||||
|   vectorize_with_alignment<16>( | ||||
|       row_in, row_out, hidden_size, tid, stride, | ||||
|       [=] __device__(int8_t& dst, const scalar_t& src) { | ||||
|         const auto v = static_cast<float>(src) * inv_s; | ||||
|         dst = int32_to_int8(float_to_int32_rn(v) + azp); | ||||
|       }); | ||||
| } | ||||
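
As a plain-C++ reference for the two static kernels above (a sketch, not the source: float_to_int8_rn and int32_to_int8 are modeled here as round-to-nearest followed by saturation, which is an assumption about their exact behavior):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Symmetric static quantization: q = sat_i8(round(x / scale))
    int8_t static_quant_ref(float x, float scale) {
      int r = (int)std::nearbyintf(x / scale);
      return (int8_t)std::clamp(r, -128, 127);
    }

    // Asymmetric static quantization: q = sat_i8(round(x / scale) + azp)
    int8_t static_azp_quant_ref(float x, float scale, int32_t azp) {
      int r = (int)std::nearbyintf(x / scale) + azp;
      return (int8_t)std::clamp(r, -128, 127);
    }
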
|  | ||||
| template <typename scalar_t, typename scale_type> | ||||
| template <typename scalar_t, typename scale_t> | ||||
| __global__ void dynamic_scaled_int8_quant_kernel( | ||||
|     scalar_t const* __restrict__ input, int8_t* __restrict__ out, | ||||
|     scale_type* scale, const int hidden_size) { | ||||
|   int const tid = threadIdx.x; | ||||
|   int64_t const token_idx = blockIdx.x; | ||||
|   float absmax_val = 0.0f; | ||||
|   float const zero = 0.0f; | ||||
|     const scalar_t* __restrict__ input, int8_t* __restrict__ output, | ||||
|     scale_t* scale_out, const int hidden_size) { | ||||
|   const int tid = threadIdx.x; | ||||
|   const int stride = blockDim.x; | ||||
|   const int64_t token_idx = blockIdx.x; | ||||
|  | ||||
|   // Must be performed using 64-bit math to avoid integer overflow. | ||||
|   out += token_idx * hidden_size; | ||||
|   input += token_idx * hidden_size; | ||||
|   const scalar_t* row_in = input + token_idx * hidden_size; | ||||
|   int8_t* row_out = output + token_idx * hidden_size; | ||||
|  | ||||
|   for (int i = tid; i < hidden_size; i += blockDim.x) { | ||||
|     float val = static_cast<float>(input[i]); | ||||
|     val = val > zero ? val : -val; | ||||
|     absmax_val = val > absmax_val ? val : absmax_val; | ||||
|   // 1. compute the per-token absmax | ||||
|   float thread_max = 0.f; | ||||
|   for (int i = tid; i < hidden_size; i += stride) { | ||||
|     const auto v = fabsf(static_cast<float>(row_in[i])); | ||||
|     thread_max = fmaxf(thread_max, v); | ||||
|   } | ||||
|  | ||||
|   using BlockReduce = cub::BlockReduce<float, 1024>; | ||||
|   __shared__ typename BlockReduce::TempStorage reduceStorage; | ||||
|   float const block_absmax_val_maybe = | ||||
|       BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x); | ||||
|   __shared__ float block_absmax_val; | ||||
|   using BlockReduce = cub::BlockReduce<float, 256>; | ||||
|   __shared__ typename BlockReduce::TempStorage tmp; | ||||
|   float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x); | ||||
|   __shared__ float absmax; | ||||
|   if (tid == 0) { | ||||
|     block_absmax_val = block_absmax_val_maybe; | ||||
|     scale[token_idx] = block_absmax_val / 127.0f; | ||||
|     absmax = block_max; | ||||
|     scale_out[blockIdx.x] = absmax / 127.f; | ||||
|   } | ||||
|   __syncthreads(); | ||||
|  | ||||
|   float const tmp_scale = 127.0f / block_absmax_val; | ||||
|   for (int i = tid; i < hidden_size; i += blockDim.x) { | ||||
|     out[i] = float_to_int8_rn(static_cast<float>(input[i]) * tmp_scale); | ||||
|   } | ||||
|   float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax; | ||||
|  | ||||
|   // 2. quantize | ||||
|   vectorize_with_alignment<16>( | ||||
|       row_in, row_out, hidden_size, tid, stride, | ||||
|       [=] __device__(int8_t& dst, const scalar_t& src) { | ||||
|         dst = float_to_int8_rn(static_cast<float>(src) * inv_s); | ||||
|       }); | ||||
| } | ||||
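
As a quick numeric check of the dynamic (symmetric) path above, with illustrative values: for a row [0.5, -2.0, 1.0], the absmax is 2.0, so the stored per-token scale is 2.0 / 127 ≈ 0.01575 and inv_s = 63.5; the quantized row is [round(31.75), round(-127.0), round(63.5)] = [32, -127, 64] (63.5 rounds to 64, assuming the "rn" rounding is half-to-even). The absmax == 0 guard simply yields an all-zero row with scale 0.
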
|  | ||||
| template <typename scalar_t, typename scale_type, typename azp_type> | ||||
| // MinMax structure to hold min and max values in one go | ||||
| struct MinMax { | ||||
|   float min, max; | ||||
|  | ||||
|   __host__ __device__ MinMax() | ||||
|       : min(std::numeric_limits<float>::max()), | ||||
|         max(std::numeric_limits<float>::lowest()) {} | ||||
|  | ||||
|   __host__ __device__ explicit MinMax(float v) : min(v), max(v) {} | ||||
|  | ||||
|   // add a value to the MinMax | ||||
|   __host__ __device__ MinMax& operator+=(float v) { | ||||
|     min = fminf(min, v); | ||||
|     max = fmaxf(max, v); | ||||
|     return *this; | ||||
|   } | ||||
|  | ||||
|   // merge two MinMax objects | ||||
|   __host__ __device__ MinMax& operator&=(const MinMax& other) { | ||||
|     min = fminf(min, other.min); | ||||
|     max = fmaxf(max, other.max); | ||||
|     return *this; | ||||
|   } | ||||
| }; | ||||
|  | ||||
| __host__ __device__ inline MinMax operator+(MinMax a, float v) { | ||||
|   return a += v; | ||||
| } | ||||
| __host__ __device__ inline MinMax operator&(MinMax a, const MinMax& b) { | ||||
|   return a &= b; | ||||
| } | ||||
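
A small host-side usage sketch of the MinMax helper above (it reuses the struct exactly as defined in this file; the function name is hypothetical):

    #include <vector>

    // Fold a row into a single min/max pass, as each thread does before the
    // block-wide reduction in the kernel below.
    float example_range(const std::vector<float>& row) {
      MinMax mm;                    // starts at (+FLT_MAX, lowest)
      for (float v : row) mm += v;  // operator+= folds one value
      return mm.max - mm.min;       // the span that the azp kernel divides by 255
    }
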
|  | ||||
| template <typename scalar_t, typename scale_t, typename azp_t> | ||||
| __global__ void dynamic_scaled_int8_azp_quant_kernel( | ||||
|     scalar_t const* __restrict__ input, int8_t* __restrict__ out, | ||||
|     scale_type* scale, azp_type* azp, const int hidden_size) { | ||||
|   int64_t const token_idx = blockIdx.x; | ||||
|     const scalar_t* __restrict__ input, int8_t* __restrict__ output, | ||||
|     scale_t* scale_out, azp_t* azp_out, const int hidden_size) { | ||||
|   const int tid = threadIdx.x; | ||||
|   const int stride = blockDim.x; | ||||
|   const int64_t token_idx = blockIdx.x; | ||||
|  | ||||
|   // Must be performed using 64-bit math to avoid integer overflow. | ||||
|   out += token_idx * hidden_size; | ||||
|   input += token_idx * hidden_size; | ||||
|   const scalar_t* row_in = input + token_idx * hidden_size; | ||||
|   int8_t* row_out = output + token_idx * hidden_size; | ||||
|  | ||||
|   // Scan for the min and max value for this token | ||||
|   float max_val = std::numeric_limits<float>::min(); | ||||
|   float min_val = std::numeric_limits<float>::max(); | ||||
|   for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { | ||||
|     auto val = static_cast<float>(input[i]); | ||||
|     max_val = std::max(max_val, val); | ||||
|     min_val = std::min(min_val, val); | ||||
|   // 1. calculate min & max | ||||
|   MinMax thread_mm; | ||||
|   for (int i = tid; i < hidden_size; i += stride) { | ||||
|     thread_mm += static_cast<float>(row_in[i]); | ||||
|   } | ||||
|  | ||||
|   // Reduce the max and min values across the block | ||||
|   using BlockReduce = cub::BlockReduce<float, 1024>; | ||||
|   __shared__ typename BlockReduce::TempStorage reduceStorage; | ||||
|   max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x); | ||||
|   __syncthreads();  // Make sure min doesn't mess with max shared memory | ||||
|   min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x); | ||||
|   using BlockReduce = cub::BlockReduce<MinMax, 256>; | ||||
|   __shared__ typename BlockReduce::TempStorage tmp; | ||||
|  | ||||
|   __shared__ scale_type scale_sh; | ||||
|   __shared__ azp_type azp_sh; | ||||
|   MinMax mm = BlockReduce(tmp).Reduce( | ||||
|       thread_mm, | ||||
|       [] __device__(MinMax a, const MinMax& b) { | ||||
|         a &= b; | ||||
|         return a; | ||||
|       }, | ||||
|       blockDim.x); | ||||
|  | ||||
|   // Compute the scale and zero point and store them, only on the first thread | ||||
|   if (threadIdx.x == 0) { | ||||
|     float const scale_val = (max_val - min_val) / 255.0f; | ||||
|     // Use rounding to even (same as torch.round) | ||||
|     auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val); | ||||
|     auto const azp_val = static_cast<azp_type>(azp_float); | ||||
|  | ||||
|     // Store the scale and azp into shared and global | ||||
|     scale[token_idx] = scale_sh = scale_val; | ||||
|     azp[token_idx] = azp_sh = azp_val; | ||||
|   __shared__ float scale_sh; | ||||
|   __shared__ azp_t azp_sh; | ||||
|   if (tid == 0) { | ||||
|     float s = (mm.max - mm.min) / 255.f; | ||||
|     float zp = nearbyintf(-128.f - mm.min / s);  // round-to-even | ||||
|     scale_sh = s; | ||||
|     azp_sh = azp_t(zp); | ||||
|     scale_out[blockIdx.x] = s; | ||||
|     azp_out[blockIdx.x] = azp_sh; | ||||
|   } | ||||
|  | ||||
|   // Wait for the scale and azp to be computed | ||||
|   __syncthreads(); | ||||
|  | ||||
|   float const scale_val = scale_sh; | ||||
|   azp_type const azp_val = azp_sh; | ||||
|   const float inv_s = 1.f / scale_sh; | ||||
|   const azp_t azp = azp_sh; | ||||
|  | ||||
|   // Quantize the values | ||||
|   for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { | ||||
|     auto const val = static_cast<float>(input[i]); | ||||
|     auto const quant_val = | ||||
|         int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val); | ||||
|     out[i] = quant_val; | ||||
|   } | ||||
|   // 2. quantize | ||||
|   vectorize_with_alignment<16>( | ||||
|       row_in, row_out, hidden_size, tid, stride, | ||||
|       [=] __device__(int8_t& dst, const scalar_t& src) { | ||||
|         const auto v = static_cast<float>(src) * inv_s; | ||||
|         dst = int32_to_int8(float_to_int32_rn(v) + azp); | ||||
|       }); | ||||
| } | ||||
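
A worked example of the scale / zero-point math above, with illustrative numbers: if a token's values span [min, max] = [-1.0, 3.0], then scale = 4.0 / 255 ≈ 0.015686 and azp = round(-128 - (-1.0) / 0.015686) = round(-64.25) = -64; quantizing the endpoints gives round(-63.75) + (-64) = -128 and round(191.25) + (-64) = 127, so the full range maps exactly onto [-128, 127].
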
|  | ||||
| }  // namespace vllm | ||||
| @ -247,7 +285,7 @@ void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size] | ||||
|   int const hidden_size = input.size(-1); | ||||
|   int const num_tokens = input.numel() / hidden_size; | ||||
|   dim3 const grid(num_tokens); | ||||
|   dim3 const block(std::min(hidden_size, 1024)); | ||||
|   dim3 const block(std::min(hidden_size, 256)); | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|   VLLM_DISPATCH_FLOATING_TYPES( | ||||
|       input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { | ||||
| @ -278,7 +316,7 @@ void dynamic_scaled_int8_quant( | ||||
|   int const hidden_size = input.size(-1); | ||||
|   int const num_tokens = input.numel() / hidden_size; | ||||
|   dim3 const grid(num_tokens); | ||||
|   dim3 const block(std::min(hidden_size, 1024)); | ||||
|   dim3 const block(std::min(hidden_size, 256)); | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
|   VLLM_DISPATCH_FLOATING_TYPES( | ||||
|       input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { | ||||
|  | ||||
| @ -144,4 +144,65 @@ struct cutlass_3x_gemm_sm100 { | ||||
|       Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>; | ||||
| }; | ||||
|  | ||||
| template <typename ElementAB_, typename ElementD_, | ||||
|           template <typename, typename, typename> typename Epilogue_, | ||||
|           typename TileShape, typename ClusterShape, typename KernelSchedule, | ||||
|           typename EpilogueSchedule> | ||||
| struct cutlass_3x_gemm_sm120 { | ||||
|   using ElementAB = ElementAB_; | ||||
|   using LayoutA = cutlass::layout::RowMajor; | ||||
|   static constexpr int AlignmentA = | ||||
|       128 / cutlass::sizeof_bits<ElementAB>::value; | ||||
|  | ||||
|   using LayoutB = cutlass::layout::ColumnMajor; | ||||
|   static constexpr int AlignmentB = | ||||
|       128 / cutlass::sizeof_bits<ElementAB>::value; | ||||
|  | ||||
|   using ElementC = void; | ||||
|   using LayoutC = cutlass::layout::RowMajor; | ||||
|   static constexpr int AlignmentC = | ||||
|       128 / cutlass::sizeof_bits<ElementD_>::value; | ||||
|  | ||||
|   using ElementD = ElementD_; | ||||
|   using LayoutD = cutlass::layout::RowMajor; | ||||
|   static constexpr int AlignmentD = AlignmentC; | ||||
|  | ||||
|   using ElementAcc = | ||||
|       typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t, | ||||
|                                 float>::type; | ||||
|   using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>; | ||||
|  | ||||
|   // MMA type | ||||
|   using ElementAccumulator = float; | ||||
|  | ||||
|   // Epilogue types | ||||
|   using ElementBias = cutlass::half_t; | ||||
|   using ElementCompute = float; | ||||
|   using ElementAux = ElementD; | ||||
|   using LayoutAux = LayoutD; | ||||
|   using ElementAmax = float; | ||||
|  | ||||
|   using EVTCompute = typename Epilogue::EVTCompute; | ||||
|  | ||||
|   using CollectiveEpilogue = | ||||
|       typename cutlass::epilogue::collective::CollectiveBuilder< | ||||
|           cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, TileShape, | ||||
|           ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, | ||||
|           ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, | ||||
|           ElementD, LayoutD, AlignmentD, EpilogueSchedule, | ||||
|           EVTCompute>::CollectiveOp; | ||||
|  | ||||
|   using CollectiveMainloop = | ||||
|       typename cutlass::gemm::collective::CollectiveBuilder< | ||||
|           cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, ElementAB, | ||||
|           LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB, | ||||
|           ElementAccumulator, TileShape, ClusterShape, | ||||
|           cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>( | ||||
|               sizeof(typename CollectiveEpilogue::SharedStorage))>, | ||||
|           KernelSchedule>::CollectiveOp; | ||||
|  | ||||
|   using GemmKernel = cutlass::gemm::kernel::GemmUniversal< | ||||
|       Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>; | ||||
| }; | ||||
|  | ||||
| }  // namespace vllm | ||||
|  | ||||
| @ -36,6 +36,12 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, | ||||
|                                  torch::Tensor const& b_scales, | ||||
|                                  std::optional<torch::Tensor> const& bias); | ||||
|  | ||||
| void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a, | ||||
|                                  torch::Tensor const& b, | ||||
|                                  torch::Tensor const& a_scales, | ||||
|                                  torch::Tensor const& b_scales, | ||||
|                                  std::optional<torch::Tensor> const& bias); | ||||
|  | ||||
| void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out, | ||||
|                                            torch::Tensor const& a, | ||||
|                                            torch::Tensor const& b, | ||||
|  | ||||
| @ -15,11 +15,11 @@ using c3x::cutlass_gemm_caller; | ||||
| template <typename InType, typename OutType, | ||||
|           template <typename, typename, typename> typename Epilogue> | ||||
| struct sm100_fp8_config_default { | ||||
|   // M in (128, inf) | ||||
|   // M in (256, inf) | ||||
|   static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); | ||||
|   using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; | ||||
|   using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; | ||||
|   using TileShape = Shape<_256, _128, _64>; | ||||
|   using TileShape = Shape<_256, _128, _128>; | ||||
|   using ClusterShape = Shape<_2, _2, _1>; | ||||
|   using Cutlass3xGemm = | ||||
|       cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape, | ||||
| @ -28,13 +28,13 @@ struct sm100_fp8_config_default { | ||||
|  | ||||
| template <typename InType, typename OutType, | ||||
|           template <typename, typename, typename> typename Epilogue> | ||||
| struct sm100_fp8_config_M128 { | ||||
|   // M in (64, 128] | ||||
| struct sm100_fp8_config_M256 { | ||||
|   // M in (64, 256] | ||||
|   static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); | ||||
|   using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; | ||||
|   using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; | ||||
|   using TileShape = Shape<_128, _128, _64>; | ||||
|   using ClusterShape = Shape<_2, _2, _1>; | ||||
|   using TileShape = Shape<_128, _128, _128>; | ||||
|   using ClusterShape = Shape<_2, _1, _1>; | ||||
|   using Cutlass3xGemm = | ||||
|       cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape, | ||||
|                             KernelSchedule, EpilogueSchedule>; | ||||
| @ -43,12 +43,26 @@ struct sm100_fp8_config_M128 { | ||||
| template <typename InType, typename OutType, | ||||
|           template <typename, typename, typename> typename Epilogue> | ||||
| struct sm100_fp8_config_M64 { | ||||
|   // M in [1, 64] | ||||
|   // M in (16, 64] | ||||
|   static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); | ||||
|   using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; | ||||
|   using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; | ||||
|   using TileShape = Shape<_64, _64, _256>; | ||||
|   using ClusterShape = Shape<_1, _8, _1>; | ||||
|   using TileShape = Shape<_64, _64, _128>; | ||||
|   using ClusterShape = Shape<_1, _1, _1>; | ||||
|   using Cutlass3xGemm = | ||||
|       cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape, | ||||
|                             KernelSchedule, EpilogueSchedule>; | ||||
| }; | ||||
|  | ||||
| template <typename InType, typename OutType, | ||||
|           template <typename, typename, typename> typename Epilogue> | ||||
| struct sm100_fp8_config_M16 { | ||||
|   // M in [1, 16] | ||||
|   static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); | ||||
|   using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; | ||||
|   using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; | ||||
|   using TileShape = Shape<_64, _64, _128>; | ||||
|   using ClusterShape = Shape<_1, _4, _1>; | ||||
|   using Cutlass3xGemm = | ||||
|       cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape, | ||||
|                             KernelSchedule, EpilogueSchedule>; | ||||
| @ -68,25 +82,31 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, | ||||
|   using Cutlass3xGemmDefault = | ||||
|       typename sm100_fp8_config_default<InType, OutType, | ||||
|                                         Epilogue>::Cutlass3xGemm; | ||||
|   using Cutlass3xGemmM16 = | ||||
|       typename sm100_fp8_config_M16<InType, OutType, Epilogue>::Cutlass3xGemm; | ||||
|   using Cutlass3xGemmM64 = | ||||
|       typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm; | ||||
|   using Cutlass3xGemmM128 = | ||||
|       typename sm100_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm; | ||||
|   using Cutlass3xGemmM256 = | ||||
|       typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm; | ||||
|  | ||||
|   uint32_t const m = a.size(0); | ||||
|   uint32_t const mp2 = | ||||
|       std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2 | ||||
|       std::max(static_cast<uint32_t>(16), next_pow_2(m));  // next power of 2 | ||||
|  | ||||
|   if (mp2 <= 64) { | ||||
|     // m in [1, 64] | ||||
|   if (mp2 <= 16) { | ||||
|     // m in [1, 16] | ||||
|     return cutlass_gemm_caller<Cutlass3xGemmM16>( | ||||
|         out, a, b, std::forward<EpilogueArgs>(args)...); | ||||
|   } else if (mp2 <= 64) { | ||||
|     // m in (16, 64] | ||||
|     return cutlass_gemm_caller<Cutlass3xGemmM64>( | ||||
|         out, a, b, std::forward<EpilogueArgs>(args)...); | ||||
|   } else if (mp2 <= 128) { | ||||
|     // m in (64, 128] | ||||
|     return cutlass_gemm_caller<Cutlass3xGemmM128>( | ||||
|   } else if (mp2 <= 256) { | ||||
|     // m in (64, 256] | ||||
|     return cutlass_gemm_caller<Cutlass3xGemmM256>( | ||||
|         out, a, b, std::forward<EpilogueArgs>(args)...); | ||||
|   } else { | ||||
|     // m in (128, inf) | ||||
|     // m in (256, inf) | ||||
|     return cutlass_gemm_caller<Cutlass3xGemmDefault>( | ||||
|         out, a, b, std::forward<EpilogueArgs>(args)...); | ||||
|   } | ||||
|  | ||||
							
								
								
									
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu (new file, 24 lines)
							| @ -0,0 +1,24 @@ | ||||
| #include "scaled_mm_kernels.hpp" | ||||
| #include "scaled_mm_sm120_fp8_dispatch.cuh" | ||||
| #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" | ||||
|  | ||||
| namespace vllm { | ||||
|  | ||||
| void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a, | ||||
|                                  torch::Tensor const& b, | ||||
|                                  torch::Tensor const& a_scales, | ||||
|                                  torch::Tensor const& b_scales, | ||||
|                                  std::optional<torch::Tensor> const& bias) { | ||||
|   TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); | ||||
|   if (bias) { | ||||
|     TORCH_CHECK(bias->dtype() == out.dtype(), | ||||
|                 "currently bias dtype must match output dtype ", out.dtype()); | ||||
|     return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogueBias>( | ||||
|         out, a, b, a_scales, b_scales, *bias); | ||||
|   } else { | ||||
|     return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogue>( | ||||
|         out, a, b, a_scales, b_scales); | ||||
|   } | ||||
| } | ||||
|  | ||||
| }  // namespace vllm | ||||
| @ -0,0 +1,67 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include "scaled_mm.cuh" | ||||
| #include "cutlass_gemm_caller.cuh" | ||||
|  | ||||
| /** | ||||
|  * This file defines Gemm kernel configurations for SM120 (fp8) based on the | ||||
|  * Gemm shape. | ||||
|  */ | ||||
|  | ||||
| namespace vllm { | ||||
|  | ||||
| using c3x::cutlass_gemm_caller; | ||||
|  | ||||
| template <typename InType, typename OutType, | ||||
|           template <typename, typename, typename> typename Epilogue> | ||||
| struct sm120_fp8_config_default { | ||||
|   static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); | ||||
|   using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; | ||||
|   using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; | ||||
|   using TileShape = Shape<_128, _128, _128>; | ||||
|   using ClusterShape = Shape<_1, _1, _1>;  // Only works with Shape<_1, _1, _1> | ||||
|   using Cutlass3xGemm = | ||||
|       cutlass_3x_gemm_sm120<InType, OutType, Epilogue, TileShape, ClusterShape, | ||||
|                             KernelSchedule, EpilogueSchedule>; | ||||
| }; | ||||
|  | ||||
| template <typename InType, typename OutType, | ||||
|           template <typename, typename, typename> typename Epilogue, | ||||
|           typename... EpilogueArgs> | ||||
| inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out, | ||||
|                                             torch::Tensor const& a, | ||||
|                                             torch::Tensor const& b, | ||||
|                                             EpilogueArgs&&... args) { | ||||
|   static_assert(std::is_same<InType, cutlass::float_e4m3_t>()); | ||||
|   TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); | ||||
|   TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); | ||||
|  | ||||
|   using Cutlass3xGemmDefault = | ||||
|       typename sm120_fp8_config_default<InType, OutType, | ||||
|                                         Epilogue>::Cutlass3xGemm; | ||||
|   return cutlass_gemm_caller<Cutlass3xGemmDefault>( | ||||
|       out, a, b, std::forward<EpilogueArgs>(args)...); | ||||
| } | ||||
|  | ||||
| template <template <typename, typename, typename> typename Epilogue, | ||||
|           typename... EpilogueArgs> | ||||
| void cutlass_scaled_mm_sm120_fp8_epilogue(torch::Tensor& out, | ||||
|                                           torch::Tensor const& a, | ||||
|                                           torch::Tensor const& b, | ||||
|                                           EpilogueArgs&&... epilogue_args) { | ||||
|   TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); | ||||
|   TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); | ||||
|  | ||||
|   if (out.dtype() == torch::kBFloat16) { | ||||
|     return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t, | ||||
|                                            cutlass::bfloat16_t, Epilogue>( | ||||
|         out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); | ||||
|   } else { | ||||
|     TORCH_CHECK(out.dtype() == torch::kFloat16); | ||||
|     return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t, | ||||
|                                            cutlass::half_t, Epilogue>( | ||||
|         out, a, b, std::forward<EpilogueArgs>(epilogue_args)...); | ||||
|   } | ||||
| } | ||||
|  | ||||
| }  // namespace vllm | ||||
							
								
								
									
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu (new file, 34 lines)
							| @ -0,0 +1,34 @@ | ||||
| #include <cudaTypedefs.h> | ||||
| #include "c3x/scaled_mm_kernels.hpp" | ||||
|  | ||||
| #include "cuda_utils.h" | ||||
|  | ||||
| /* | ||||
|    This file defines quantized GEMM operations using the CUTLASS 3.x API, for | ||||
|    NVIDIA GPUs with sm120 (Blackwell GeForce). | ||||
| */ | ||||
|  | ||||
| #if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120 | ||||
|  | ||||
| void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a, | ||||
|                              torch::Tensor const& b, | ||||
|                              torch::Tensor const& a_scales, | ||||
|                              torch::Tensor const& b_scales, | ||||
|                              std::optional<torch::Tensor> const& bias) { | ||||
|   TORCH_CHECK(a_scales.dtype() == torch::kFloat32); | ||||
|   TORCH_CHECK(b_scales.dtype() == torch::kFloat32); | ||||
|  | ||||
|   int M = a.size(0), N = b.size(1), K = a.size(1); | ||||
|   TORCH_CHECK( | ||||
|       (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) && | ||||
|           (b_scales.numel() == 1 || b_scales.numel() == b.size(1)), | ||||
|       "Currently, block scaled fp8 gemm is not implemented for Blackwell"); | ||||
|  | ||||
|   // Standard per-tensor/per-token/per-channel scaling | ||||
|   TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); | ||||
|   TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn, | ||||
|               "Currently, only fp8 gemm is implemented for Blackwell"); | ||||
|   vllm::cutlass_scaled_mm_sm120_fp8(c, a, b, a_scales, b_scales, bias); | ||||
| } | ||||
|  | ||||
| #endif | ||||
| @ -41,6 +41,14 @@ void cutlass_moe_mm_sm90( | ||||
|  | ||||
| #endif | ||||
|  | ||||
| #if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120 | ||||
| void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a, | ||||
|                              torch::Tensor const& b, | ||||
|                              torch::Tensor const& a_scales, | ||||
|                              torch::Tensor const& b_scales, | ||||
|                              std::optional<torch::Tensor> const& bias); | ||||
| #endif | ||||
|  | ||||
| #if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100 | ||||
| void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a, | ||||
|                              torch::Tensor const& b, | ||||
| @ -168,8 +176,15 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, | ||||
|   at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); | ||||
|   int32_t version_num = get_sm_version_num(); | ||||
|  | ||||
| #if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120 | ||||
|   if (version_num >= 120) { | ||||
|     cutlass_scaled_mm_sm120(c, a, b, a_scales, b_scales, bias); | ||||
|     return; | ||||
|   } | ||||
| #endif | ||||
|  | ||||
| #if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100 | ||||
|   if (version_num >= 100) { | ||||
|   if (version_num >= 100 && version_num < 120) { | ||||
|     cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias); | ||||
|     return; | ||||
|   } | ||||
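For orientation, the gate above now routes SM 12.x devices to the new sm120 path before the sm100 branch, whose condition gains the `version_num < 120` upper bound. A small sketch of the selection, assuming get_sm_version_num() encodes the capability as major * 10 + minor (e.g. 120 for SM 12.0):

    #include <cstdint>

    // Illustrative only: which kernel family the gate selects after this change.
    const char* pick_scaled_mm_family(int32_t version_num) {
      if (version_num >= 120) return "sm120";  // new Blackwell GeForce path
      if (version_num >= 100) return "sm100";  // now explicitly excludes >= 120
      return "sm90_or_older";                  // earlier branches, outside this hunk
    }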
| @ -241,7 +256,7 @@ void get_cutlass_moe_mm_data( | ||||
|   // mm to run it for. | ||||
|   int32_t version_num = get_sm_version_num(); | ||||
| #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ | ||||
|     (defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90) | ||||
|     (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) | ||||
|   get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, | ||||
|                                  problem_sizes2, input_permutation, | ||||
|                                  output_permutation, num_experts, n, k, | ||||
| @ -252,7 +267,7 @@ void get_cutlass_moe_mm_data( | ||||
|       false, | ||||
|       "No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for " | ||||
|       "CUDA device capability: ", | ||||
|       version_num, ". Required capability: 90"); | ||||
|       version_num, ". Required capability: 90 or 100"); | ||||
| } | ||||
|  | ||||
| void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, | ||||
| @ -265,7 +280,8 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, | ||||
|   // This function currently gets compiled only if we have a valid cutlass moe | ||||
|   // mm to run it for. | ||||
|   int32_t version_num = get_sm_version_num(); | ||||
| #if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 | ||||
| #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ | ||||
|     (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) | ||||
|   get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1, | ||||
|                                       problem_sizes2, expert_num_tokens, | ||||
|                                       num_local_experts, padded_m, n, k); | ||||
| @ -275,7 +291,7 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, | ||||
|       false, | ||||
|       "No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel " | ||||
|       "for CUDA device capability: ", | ||||
|       version_num, ". Required capability: 90"); | ||||
|       version_num, ". Required capability: 90 or 100"); | ||||
| } | ||||
|  | ||||
| void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, | ||||
|  | ||||
| @ -335,8 +335,10 @@ void run_fp4_blockwise_scaled_group_mm( | ||||
|   TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); | ||||
| } | ||||
|  | ||||
| #if defined ENABLE_NVFP4 && ENABLE_NVFP4 | ||||
| constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; | ||||
| constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn; | ||||
| #endif | ||||
|  | ||||
| #define CHECK_TYPE(x, st, m) \ | ||||
|   TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m) | ||||
|  | ||||
| @ -231,12 +231,115 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, | ||||
| } | ||||
|  | ||||
| // Use UE4M3 by default. | ||||
| template <class Type, bool UE8M0_SF = false> | ||||
| template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false> | ||||
| __global__ void | ||||
| #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) | ||||
| __launch_bounds__(512, 4) cvt_fp16_to_fp4( | ||||
| #else | ||||
| cvt_fp16_to_fp4( | ||||
| #endif | ||||
|     int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, | ||||
|     uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts, | ||||
|     uint32_t* output_scale_offset_by_experts, int n_experts, bool low_latency) { | ||||
| #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) | ||||
|   using PackedVec = PackedVec<Type>; | ||||
|   static constexpr int CVT_FP4_NUM_THREADS_PER_SF = | ||||
|       (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); | ||||
|   static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, | ||||
|                 "Vec size is not matched."); | ||||
|  | ||||
|   int tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
|   int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD; | ||||
|  | ||||
|   // Each global thread processes one element | ||||
|   for (int globalIdx = tid; globalIdx < numRows * colsPerRow; | ||||
|        globalIdx += gridDim.x * blockDim.x) { | ||||
|     // Calculate which row and column this global thread should process | ||||
|     int rowIdx = globalIdx / colsPerRow; | ||||
|     int colIdx = globalIdx % colsPerRow; | ||||
|  | ||||
|     int64_t inOffset = rowIdx * colsPerRow + colIdx; | ||||
|     PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset]; | ||||
|     // Get the output tensor offset. | ||||
|     // Same as inOffset because 8 elements are packed into one uint32_t. | ||||
|     int64_t outOffset = inOffset; | ||||
|     auto& out_pos = out[outOffset]; | ||||
|  | ||||
|     // Find index within the experts using different strategies based on expert | ||||
|     // count | ||||
|     int rowIdx_in_expert = 0; | ||||
|     int expert_idx = 0; | ||||
|  | ||||
|     if constexpr (SMALL_NUM_EXPERTS) { | ||||
|       for (int i = 0; i < n_experts; i++) { | ||||
|         uint32_t current_offset = __ldca(&input_offset_by_experts[i]); | ||||
|         uint32_t next_offset = __ldca(&input_offset_by_experts[i + 1]); | ||||
|         if (rowIdx >= current_offset && rowIdx < next_offset) { | ||||
|           rowIdx_in_expert = rowIdx - current_offset; | ||||
|           expert_idx = i; | ||||
|           break; | ||||
|         } | ||||
|       } | ||||
|     } else { | ||||
|       // Load input offsets into registers first, then do the computation. | ||||
|       // Local array size set to 17 because of register limit. | ||||
|       uint32_t local_offsets[17]; | ||||
|       for (int chunk_start = 0; chunk_start < n_experts; chunk_start += 16) { | ||||
|         *reinterpret_cast<int4*>(local_offsets) = | ||||
|             __ldca(reinterpret_cast<const int4*>( | ||||
|                 &input_offset_by_experts[chunk_start])); | ||||
|         *reinterpret_cast<int4*>(local_offsets + 4) = | ||||
|             __ldca(reinterpret_cast<const int4*>( | ||||
|                 &input_offset_by_experts[chunk_start + 4])); | ||||
|         *reinterpret_cast<int4*>(local_offsets + 8) = | ||||
|             __ldca(reinterpret_cast<const int4*>( | ||||
|                 &input_offset_by_experts[chunk_start + 8])); | ||||
|         *reinterpret_cast<int4*>(local_offsets + 12) = | ||||
|             __ldca(reinterpret_cast<const int4*>( | ||||
|                 &input_offset_by_experts[chunk_start + 12])); | ||||
|         local_offsets[16] = __ldca(&input_offset_by_experts[chunk_start + 16]); | ||||
|  | ||||
|   // Check against the 16 loaded offsets | ||||
|   #pragma unroll | ||||
|         for (int i = 0; i < 16; i++) { | ||||
|           if (rowIdx >= local_offsets[i] && rowIdx < local_offsets[i + 1]) { | ||||
|             rowIdx_in_expert = rowIdx - local_offsets[i]; | ||||
|             expert_idx = chunk_start + i; | ||||
|             break; | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     // Get the global scaling factor, which will be applied to the SF. | ||||
|     // Note SFScale is the same as next GEMM's alpha, which is | ||||
|     // (448.f / (Alpha_A / 6.f)). | ||||
|     float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx]; | ||||
|  | ||||
|     int factor = CVT_FP4_SF_VEC_SIZE * 4; | ||||
|     // The actual output_scales dim is computed from the padded numCols. | ||||
|     int32_t numCols_padded = (numCols + factor - 1) / factor * factor; | ||||
|     int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4; | ||||
|     uint32_t* SFout_in_expert = | ||||
|         SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout; | ||||
|  | ||||
|     auto sf_out = | ||||
|         cvt_quant_to_fp4_get_sf_out_offset<uint32_t, | ||||
|                                            CVT_FP4_NUM_THREADS_PER_SF>( | ||||
|             rowIdx_in_expert, colIdx, numCols, SFout_in_expert); | ||||
|  | ||||
|     out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out); | ||||
|   } | ||||
| #endif | ||||
| } | ||||
|  | ||||
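In the kernel above, input_offset_by_experts holds n_experts + 1 non-decreasing prefix offsets, so expert i owns rows [offsets[i], offsets[i + 1]); the SMALL_NUM_EXPERTS path simply scans those buckets for each row. A standalone host-side sketch of the same lookup (names are illustrative, not kernel code):

    #include <cstdint>

    // Expert i owns rows [offsets[i], offsets[i + 1]); offsets has n_experts + 1 entries.
    struct ExpertHit {
      int expert_idx;
      int row_in_expert;
    };

    ExpertHit find_expert_linear(const uint32_t* offsets, int n_experts, uint32_t row) {
      for (int i = 0; i < n_experts; ++i) {
        if (row >= offsets[i] && row < offsets[i + 1]) {
          return {i, static_cast<int>(row - offsets[i])};
        }
      }
      return {-1, -1};  // only reachable if row lies outside every expert's range
    }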
| // Kernel for LARGE_M_TOPK = true (large m_topk optimized version) | ||||
| template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false> | ||||
| __global__ void | ||||
| #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) | ||||
| __launch_bounds__(1024, 4) cvt_fp16_to_fp4( | ||||
| #else | ||||
| cvt_fp16_to_fp4( | ||||
| #endif | ||||
|     int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, | ||||
|     uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts, | ||||
| @ -247,50 +350,80 @@ cvt_fp16_to_fp4( | ||||
|       (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); | ||||
|   static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, | ||||
|                 "Vec size is not matched."); | ||||
|   extern __shared__ uint32_t shared_input_offsets[]; | ||||
|  | ||||
|   // Input tensor row/col loops. | ||||
|   for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { | ||||
|     for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; | ||||
|          colIdx += blockDim.x) { | ||||
|       int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; | ||||
|       PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset]; | ||||
|       // Get the output tensor offset. | ||||
|       // Same as inOffset because 8 elements are packed into one uint32_t. | ||||
|       int64_t outOffset = inOffset; | ||||
|       auto& out_pos = out[outOffset]; | ||||
|  | ||||
|       // Find index within the experts. | ||||
|       int rowIdx_in_expert = 0; | ||||
|       int expert_idx = 0; | ||||
|       for (int i = 0; i < n_experts; i++) { | ||||
|         if (rowIdx >= input_offset_by_experts[i] && | ||||
|             rowIdx < input_offset_by_experts[i + 1]) { | ||||
|           rowIdx_in_expert = rowIdx - input_offset_by_experts[i]; | ||||
|           expert_idx = i; | ||||
|           break; | ||||
|         } | ||||
|       } | ||||
|  | ||||
|       // Get the global scaling factor, which will be applied to the SF. | ||||
|       // Note SFScale is the same as next GEMM's alpha, which is | ||||
|       // (448.f / (Alpha_A / 6.f)). | ||||
|       float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx]; | ||||
|  | ||||
|       int factor = CVT_FP4_SF_VEC_SIZE * 4; | ||||
|       // The actual output_scales dim is computed from the padded numCols. | ||||
|       int32_t numCols_padded = (numCols + factor - 1) / factor * factor; | ||||
|       int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4; | ||||
|       uint32_t* SFout_in_expert = | ||||
|           SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout; | ||||
|  | ||||
|       auto sf_out = | ||||
|           cvt_quant_to_fp4_get_sf_out_offset<uint32_t, | ||||
|                                              CVT_FP4_NUM_THREADS_PER_SF>( | ||||
|               rowIdx_in_expert, colIdx, numCols, SFout_in_expert); | ||||
|  | ||||
|       out_pos = | ||||
|           cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out); | ||||
|   // Load the input offsets into shared memory. | ||||
|   // If n_experts is larger than 4, use vectorized int4 loads to save instructions; | ||||
|   // otherwise read the offsets directly. | ||||
|   if constexpr (SMALL_NUM_EXPERTS) { | ||||
|     for (int i = threadIdx.x; i < n_experts + 1; i += blockDim.x) { | ||||
|       shared_input_offsets[i] = input_offset_by_experts[i]; | ||||
|     } | ||||
|   } else { | ||||
|     for (int i = threadIdx.x * 4; i < n_experts; i += blockDim.x * 4) { | ||||
|       *reinterpret_cast<int4*>(&shared_input_offsets[i]) = | ||||
|           *reinterpret_cast<const int4*>(&input_offset_by_experts[i]); | ||||
|     } | ||||
|     if (threadIdx.x == 0) { | ||||
|       shared_input_offsets[n_experts] = input_offset_by_experts[n_experts]; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   __syncthreads(); | ||||
|  | ||||
|   int tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
|   int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD; | ||||
|  | ||||
|   // Each global thread processes one element | ||||
|   for (int globalIdx = tid; globalIdx < numRows * colsPerRow; | ||||
|        globalIdx += gridDim.x * blockDim.x) { | ||||
|     // Calculate which row and column this global thread should process | ||||
|     int rowIdx = globalIdx / colsPerRow; | ||||
|     int colIdx = globalIdx % colsPerRow; | ||||
|  | ||||
|     int64_t inOffset = rowIdx * colsPerRow + colIdx; | ||||
|     PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset]; | ||||
|     int64_t outOffset = inOffset; | ||||
|     auto& out_pos = out[outOffset]; | ||||
|  | ||||
|     // Find expert using binary search for better performance with large m_topk | ||||
|     int rowIdx_in_expert = 0; | ||||
|     int expert_idx = 0; | ||||
|  | ||||
|     // Binary search through experts using shared memory | ||||
|     int left = 0, right = n_experts - 1; | ||||
|     while (left <= right) { | ||||
|       int mid = (left + right) / 2; | ||||
|       // Get offsets: shared_input_offsets[i] corresponds to | ||||
|       // input_offset_by_experts[i] | ||||
|       uint32_t mid_offset = shared_input_offsets[mid]; | ||||
|       uint32_t next_offset = shared_input_offsets[mid + 1]; | ||||
|  | ||||
|       if (rowIdx >= mid_offset && rowIdx < next_offset) { | ||||
|         rowIdx_in_expert = rowIdx - mid_offset; | ||||
|         expert_idx = mid; | ||||
|         break; | ||||
|       } else if (rowIdx < mid_offset) { | ||||
|         right = mid - 1; | ||||
|       } else { | ||||
|         left = mid + 1; | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx]; | ||||
|  | ||||
|     int factor = CVT_FP4_SF_VEC_SIZE * 4; | ||||
|     int32_t numCols_padded = (numCols + factor - 1) / factor * factor; | ||||
|     int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4; | ||||
|     uint32_t* SFout_in_expert = | ||||
|         SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout; | ||||
|  | ||||
|     auto sf_out = | ||||
|         cvt_quant_to_fp4_get_sf_out_offset<uint32_t, | ||||
|                                            CVT_FP4_NUM_THREADS_PER_SF>( | ||||
|             rowIdx_in_expert, colIdx, numCols, SFout_in_expert); | ||||
|  | ||||
|     out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out); | ||||
|   } | ||||
| #endif | ||||
| } | ||||
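The large-m_topk kernel stages the same offsets in shared memory and replaces the linear scan with a binary search, dropping the per-row lookup from O(n_experts) to O(log n_experts). The search is equivalent to this standalone sketch (illustrative, not the kernel code itself):

    #include <cstdint>

    // Return i such that offsets[i] <= row < offsets[i + 1], or -1 if no bucket matches;
    // offsets has n_experts + 1 non-decreasing entries.
    int find_expert_binary(const uint32_t* offsets, int n_experts, uint32_t row) {
      int left = 0, right = n_experts - 1;
      while (left <= right) {
        int mid = (left + right) / 2;
        if (row >= offsets[mid] && row < offsets[mid + 1]) return mid;
        if (row < offsets[mid]) {
          right = mid - 1;
        } else {
          left = mid + 1;
        }
      }
      return -1;
    }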
| @ -309,18 +442,63 @@ void quant_impl(void* output, void* output_scale, void* input, | ||||
|  | ||||
|   // Grid, Block size. | ||||
|   // Each thread converts 8 values. | ||||
|   dim3 block(std::min(int(k / ELTS_PER_THREAD), 512)); | ||||
|   int const workSizePerRow = k / ELTS_PER_THREAD; | ||||
|   int const totalWorkSize = m_topk * workSizePerRow; | ||||
|   dim3 block(std::min(workSizePerRow, 512)); | ||||
|   // Get number of blocks per SM (assume we can fully utilize the SM). | ||||
|   int const numBlocksPerSM = 2048 / block.x; | ||||
|   dim3 grid(std::min(int(m_topk), multiProcessorCount * numBlocksPerSM)); | ||||
|   dim3 grid(std::min(static_cast<int>((totalWorkSize + block.x - 1) / block.x), | ||||
|                      multiProcessorCount * numBlocksPerSM)); | ||||
|   while (grid.x <= multiProcessorCount && block.x > 64) { | ||||
|     grid.x *= 2; | ||||
|     block.x = (block.x + 1) / 2; | ||||
|   } | ||||
|  | ||||
|   cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>( | ||||
|       m_topk, k, reinterpret_cast<T*>(input), | ||||
|       reinterpret_cast<float*>(input_global_scale), | ||||
|       reinterpret_cast<uint32_t*>(output), | ||||
|       reinterpret_cast<uint32_t*>(output_scale), | ||||
|       reinterpret_cast<uint32_t*>(input_offset_by_experts), | ||||
|       reinterpret_cast<uint32_t*>(output_scale_offset_by_experts), n_experts); | ||||
|   int const blockRepeat = | ||||
|       (totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x); | ||||
|   if (blockRepeat > 1) { | ||||
|     size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t); | ||||
|     if (n_experts >= 4) { | ||||
|       cvt_fp16_to_fp4<T, false, false> | ||||
|           <<<grid, block, shared_mem_size, stream>>>( | ||||
|               m_topk, k, reinterpret_cast<T*>(input), | ||||
|               reinterpret_cast<float*>(input_global_scale), | ||||
|               reinterpret_cast<uint32_t*>(output), | ||||
|               reinterpret_cast<uint32_t*>(output_scale), | ||||
|               reinterpret_cast<uint32_t*>(input_offset_by_experts), | ||||
|               reinterpret_cast<uint32_t*>(output_scale_offset_by_experts), | ||||
|               n_experts); | ||||
|     } else { | ||||
|       cvt_fp16_to_fp4<T, false, true><<<grid, block, shared_mem_size, stream>>>( | ||||
|           m_topk, k, reinterpret_cast<T*>(input), | ||||
|           reinterpret_cast<float*>(input_global_scale), | ||||
|           reinterpret_cast<uint32_t*>(output), | ||||
|           reinterpret_cast<uint32_t*>(output_scale), | ||||
|           reinterpret_cast<uint32_t*>(input_offset_by_experts), | ||||
|           reinterpret_cast<uint32_t*>(output_scale_offset_by_experts), | ||||
|           n_experts); | ||||
|     } | ||||
|   } else { | ||||
|     if (n_experts >= 16) { | ||||
|       cvt_fp16_to_fp4<T, false, false><<<grid, block, 0, stream>>>( | ||||
|           m_topk, k, reinterpret_cast<T*>(input), | ||||
|           reinterpret_cast<float*>(input_global_scale), | ||||
|           reinterpret_cast<uint32_t*>(output), | ||||
|           reinterpret_cast<uint32_t*>(output_scale), | ||||
|           reinterpret_cast<uint32_t*>(input_offset_by_experts), | ||||
|           reinterpret_cast<uint32_t*>(output_scale_offset_by_experts), | ||||
|           n_experts, /* bool low_latency */ true); | ||||
|     } else { | ||||
|       cvt_fp16_to_fp4<T, false, true><<<grid, block, 0, stream>>>( | ||||
|           m_topk, k, reinterpret_cast<T*>(input), | ||||
|           reinterpret_cast<float*>(input_global_scale), | ||||
|           reinterpret_cast<uint32_t*>(output), | ||||
|           reinterpret_cast<uint32_t*>(output_scale), | ||||
|           reinterpret_cast<uint32_t*>(input_offset_by_experts), | ||||
|           reinterpret_cast<uint32_t*>(output_scale_offset_by_experts), | ||||
|           n_experts, /* bool low_latency */ true); | ||||
|     } | ||||
|   } | ||||
| } | ||||
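The new launch-size logic sizes the grid by total elements of work rather than by rows, widens the grid (and shrinks the block) on large GPUs, and falls back to the shared-memory kernel only when a block would have to loop more than once. A hypothetical walk-through with assumed numbers (k = 4096, m_topk = 64, 132 SMs, 8 elements converted per thread), not a measured configuration:

    #include <algorithm>

    void example_launch_sizing() {
      int workSizePerRow = 4096 / 8;                      // 512
      int totalWorkSize = 64 * workSizePerRow;            // 32768
      int block = std::min(workSizePerRow, 512);          // 512
      int numBlocksPerSM = 2048 / block;                  // 4
      int grid = std::min((totalWorkSize + block - 1) / block,
                          132 * numBlocksPerSM);          // min(64, 528) = 64
      while (grid <= 132 && block > 64) {                 // widen the grid on large GPUs
        grid *= 2;
        block = (block + 1) / 2;
      }                                                   // grid = 256, block = 128
      int blockRepeat =
          (totalWorkSize + block * grid - 1) / (block * grid);  // 1 -> low_latency kernels
      (void)blockRepeat;
    }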
|  | ||||
| /*Quantization entry for fp4 experts quantization*/ | ||||
| @ -383,7 +561,7 @@ void scaled_fp4_experts_quant_sm100a( | ||||
|   TORCH_CHECK(output_scale.size(1) * 4 == padded_k); | ||||
|  | ||||
|   auto in_dtype = input.dtype(); | ||||
|   at::cuda::CUDAGuard device_guard{(char)input.get_device()}; | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); | ||||
|   const cudaStream_t stream = | ||||
|       at::cuda::getCurrentCUDAStream(input.get_device()); | ||||
|   if (in_dtype == at::ScalarType::Half) { | ||||
| @ -401,4 +579,4 @@ void scaled_fp4_experts_quant_sm100a( | ||||
|   } else { | ||||
|     TORCH_CHECK(false, "Expected input data type to be half or bfloat16"); | ||||
|   } | ||||
| } | ||||
| } | ||||
|  | ||||
| @ -347,7 +347,7 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output, | ||||
|   auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr()); | ||||
|   auto sf_out = static_cast<int32_t*>(output_sf.data_ptr()); | ||||
|   auto output_ptr = static_cast<int64_t*>(output.data_ptr()); | ||||
|   at::cuda::CUDAGuard device_guard{(char)input.get_device()}; | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); | ||||
|   auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); | ||||
|  | ||||
|   // We don't support e8m0 scales at this moment. | ||||
|  | ||||
| @ -267,7 +267,7 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A, | ||||
|               B_sf.sizes()[1], ")"); | ||||
|  | ||||
|   auto out_dtype = D.dtype(); | ||||
|   at::cuda::CUDAGuard device_guard{(char)A.get_device()}; | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device()); | ||||
|  | ||||
|   if (out_dtype == at::ScalarType::Half) { | ||||
|  | ||||
| @ -446,8 +446,6 @@ scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) { | ||||
| template <> | ||||
| __inline__ __device__ uint32_t | ||||
| scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) { | ||||
|   [[maybe_unused]] __half2_raw h2r = | ||||
|       __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret); | ||||
|   union { | ||||
|     __half2_raw h2r; | ||||
|     uint32_t ui32; | ||||
|  | ||||
| @ -92,111 +92,112 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W,  // quant weight | ||||
|                                   torch::Tensor X,  // input | ||||
|                                   int64_t type, int64_t row) { | ||||
|   int col = X.sizes()[1]; | ||||
|   int vecs = X.sizes()[0]; | ||||
|   const int padded = (col + 512 - 1) / 512 * 512; | ||||
|   const at::cuda::OptionalCUDAGuard device_guard(device_of(X)); | ||||
|   auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device()); | ||||
|   at::Tensor Y = torch::empty({1, row}, options); | ||||
|   at::Tensor Y = torch::empty({vecs, row}, options); | ||||
|   cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); | ||||
|   options = torch::TensorOptions().dtype(torch::kInt32).device(W.device()); | ||||
|   at::Tensor quant_X = torch::empty({1, padded / 32 * 9}, options); | ||||
|   at::Tensor quant_X = torch::empty({vecs, padded / 32 * 9}, options); | ||||
|   VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_mul_mat_vec_a8", [&] { | ||||
|     quantize_row_q8_1_cuda<scalar_t>((scalar_t*)X.data_ptr(), | ||||
|                                      (void*)quant_X.data_ptr(), col, 1, stream); | ||||
|     quantize_row_q8_1_cuda<scalar_t>( | ||||
|         (scalar_t*)X.data_ptr(), (void*)quant_X.data_ptr(), col, vecs, stream); | ||||
|     switch (type) { | ||||
|       case 2: | ||||
|         mul_mat_vec_q4_0_q8_1_cuda<scalar_t>( | ||||
|             (void*)W.data_ptr(), (void*)quant_X.data_ptr(), | ||||
|             (scalar_t*)Y.data_ptr(), col, row, stream); | ||||
|             (scalar_t*)Y.data_ptr(), col, row, vecs, stream); | ||||
|         break; | ||||
|       case 3: | ||||
|         mul_mat_vec_q4_1_q8_1_cuda<scalar_t>( | ||||
|             (void*)W.data_ptr(), (void*)quant_X.data_ptr(), | ||||
|             (scalar_t*)Y.data_ptr(), col, row, stream); | ||||
|             (scalar_t*)Y.data_ptr(), col, row, vecs, stream); | ||||
|         break; | ||||
|       case 6: | ||||
|         mul_mat_vec_q5_0_q8_1_cuda<scalar_t>( | ||||
|             (void*)W.data_ptr(), (void*)quant_X.data_ptr(), | ||||
|             (scalar_t*)Y.data_ptr(), col, row, stream); | ||||
|             (scalar_t*)Y.data_ptr(), col, row, vecs, stream); | ||||
|         break; | ||||
|       case 7: | ||||
|         mul_mat_vec_q5_1_q8_1_cuda<scalar_t>( | ||||
|             (void*)W.data_ptr(), (void*)quant_X.data_ptr(), | ||||
|             (scalar_t*)Y.data_ptr(), col, row, stream); | ||||
|             (scalar_t*)Y.data_ptr(), col, row, vecs, stream); | ||||
|         break; | ||||
|       case 8: | ||||
|         mul_mat_vec_q8_0_q8_1_cuda<scalar_t>( | ||||
|             (void*)W.data_ptr(), (void*)quant_X.data_ptr(), | ||||
|             (scalar_t*)Y.data_ptr(), col, row, stream); | ||||
|             (scalar_t*)Y.data_ptr(), col, row, vecs, stream); | ||||
|         break; | ||||
|       case 10: | ||||
|         mul_mat_vec_q2_K_q8_1_cuda<scalar_t>( | ||||
|             (void*)W.data_ptr(), (void*)quant_X.data_ptr(), | ||||
|             (scalar_t*)Y.data_ptr(), col, row, stream); | ||||
|             (scalar_t*)Y.data_ptr(), col, row, vecs, stream); | ||||
|         break; | ||||
|       case 11: | ||||
|         mul_mat_vec_q3_K_q8_1_cuda<scalar_t>( | ||||
|             (void*)W.data_ptr(), (void*)quant_X.data_ptr(), | ||||
|             (scalar_t*)Y.data_ptr(), col, row, stream); | ||||
|             (scalar_t*)Y.data_ptr(), col, row, vecs, stream); | ||||
|         break; | ||||
|       case 12: | ||||
|         mul_mat_vec_q4_K_q8_1_cuda<scalar_t>( | ||||
|             (void*)W.data_ptr(), (void*)quant_X.data_ptr(), | ||||
|             (scalar_t*)Y.data_ptr(), col, row, stream); | ||||
|             (scalar_t*)Y.data_ptr(), col, row, vecs, stream); | ||||
|         break; | ||||
|       case 13: | ||||
|         mul_mat_vec_q5_K_q8_1_cuda<scalar_t>( | ||||
|             (void*)W.data_ptr(), (void*)quant_X.data_ptr(), | ||||
|             (scalar_t*)Y.data_ptr(), col, row, stream); | ||||
|             (scalar_t*)Y.data_ptr(), col, row, vecs, stream); | ||||
|         break; | ||||
|       case 14: | ||||
|         mul_mat_vec_q6_K_q8_1_cuda<scalar_t>( | ||||
|             (void*)W.data_ptr(), (void*)quant_X.data_ptr(), | ||||
|             (scalar_t*)Y.data_ptr(), col, row, stream); | ||||
|             (scalar_t*)Y.data_ptr(), col, row, vecs, stream); | ||||
|         break; | ||||
|       case 16: | ||||
|         mul_mat_vec_iq2_xxs_q8_1_cuda<scalar_t>( | ||||
|             (void*)W.data_ptr(), (void*)quant_X.data_ptr(), | ||||
|             (scalar_t*)Y.data_ptr(), col, row, stream); | ||||
|             (scalar_t*)Y.data_ptr(), col, row, vecs, stream); | ||||
|         break; | ||||
|       case 17: | ||||
|         mul_mat_vec_iq2_xs_q8_1_cuda<scalar_t>( | ||||
|             (void*)W.data_ptr(), (void*)quant_X.data_ptr(), | ||||
|             (scalar_t*)Y.data_ptr(), col, row, stream); | ||||
|             (scalar_t*)Y.data_ptr(), col, row, vecs, stream); | ||||
|         break; | ||||
|       case 18: | ||||
|         mul_mat_vec_iq3_xxs_q8_1_cuda<scalar_t>( | ||||
|             (void*)W.data_ptr(), (void*)quant_X.data_ptr(), | ||||
|             (scalar_t*)Y.data_ptr(), col, row, stream); | ||||
|             (scalar_t*)Y.data_ptr(), col, row, vecs, stream); | ||||
|         break; | ||||
|       case 19: | ||||
|         mul_mat_vec_iq1_s_q8_1_cuda<scalar_t>( | ||||
|             (void*)W.data_ptr(), (void*)quant_X.data_ptr(), | ||||
|             (scalar_t*)Y.data_ptr(), col, row, stream); | ||||
|             (scalar_t*)Y.data_ptr(), col, row, vecs, stream); | ||||
|         break; | ||||
|       case 20: | ||||
|         mul_mat_vec_iq4_nl_q8_1_cuda<scalar_t>( | ||||
|             (void*)W.data_ptr(), (void*)quant_X.data_ptr(), | ||||
|             (scalar_t*)Y.data_ptr(), col, row, stream); | ||||
|             (scalar_t*)Y.data_ptr(), col, row, vecs, stream); | ||||
|         break; | ||||
|       case 21: | ||||
|         mul_mat_vec_iq3_s_q8_1_cuda<scalar_t>( | ||||
|             (void*)W.data_ptr(), (void*)quant_X.data_ptr(), | ||||
|             (scalar_t*)Y.data_ptr(), col, row, stream); | ||||
|             (scalar_t*)Y.data_ptr(), col, row, vecs, stream); | ||||
|         break; | ||||
|       case 22: | ||||
|         mul_mat_vec_iq2_s_q8_1_cuda<scalar_t>( | ||||
|             (void*)W.data_ptr(), (void*)quant_X.data_ptr(), | ||||
|             (scalar_t*)Y.data_ptr(), col, row, stream); | ||||
|             (scalar_t*)Y.data_ptr(), col, row, vecs, stream); | ||||
|         break; | ||||
|       case 23: | ||||
|         mul_mat_vec_iq4_xs_q8_1_cuda<scalar_t>( | ||||
|             (void*)W.data_ptr(), (void*)quant_X.data_ptr(), | ||||
|             (scalar_t*)Y.data_ptr(), col, row, stream); | ||||
|             (scalar_t*)Y.data_ptr(), col, row, vecs, stream); | ||||
|         break; | ||||
|       case 29: | ||||
|         mul_mat_vec_iq1_m_q8_1_cuda<scalar_t>( | ||||
|             (void*)W.data_ptr(), (void*)quant_X.data_ptr(), | ||||
|             (scalar_t*)Y.data_ptr(), col, row, stream); | ||||
|             (scalar_t*)Y.data_ptr(), col, row, vecs, stream); | ||||
|         break; | ||||
|     } | ||||
|   }); | ||||
|  | ||||
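With this change the GGML GEMV entry point accepts a batch of input vectors: X is [vecs, col], the q8_1 staging buffer becomes [vecs, padded / 32 * 9], and the result is [vecs, row]. A hypothetical C++ caller, shown only to illustrate the new shapes (in vLLM this op is normally reached through the registered torch binding rather than a direct call):

    #include <torch/torch.h>

    // Declared in the file above; X may now be [vecs, col] and the result is [vecs, row].
    torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, int64_t type,
                                      int64_t row);

    // Hypothetical caller: one q4_0 weight matrix, a small batch of input vectors.
    torch::Tensor batched_gemv_example(torch::Tensor W_q4_0, int64_t rows,
                                       torch::Tensor X_batch) {
      TORCH_CHECK(X_batch.dim() == 2, "expected [vecs, col] input");
      // type 2 selects the q4_0 branch of the switch above.
      torch::Tensor Y = ggml_mul_mat_vec_a8(W_q4_0, X_batch, /*type=*/2, rows);
      return Y;  // shape [vecs, rows]
    }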
| @ -1,16 +1,19 @@ | ||||
| // copied and adapted from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu | ||||
| template <typename scalar_t, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda> | ||||
| static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst, const int ncols, const int nrows) { | ||||
| static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, scalar_t * __restrict__ dst, const int ncols, const int nrows, const int nvecs) { | ||||
|     const auto row = blockIdx.x*blockDim.y + threadIdx.y; | ||||
|     const auto vec = blockIdx.y; | ||||
|  | ||||
|     if (row >= nrows) { | ||||
|     if (row >= nrows || vec >= nvecs) { | ||||
|         return; | ||||
|     } | ||||
|  | ||||
|     const int blocks_per_row = ncols / qk; | ||||
|     const int blocks_per_warp = vdr * WARP_SIZE / qi; | ||||
|     const int nrows_y = (ncols + 512 - 1) / 512 * 512; | ||||
|  | ||||
| // partial sum for each thread | ||||
|  | ||||
|     // partial sum for each thread | ||||
|     float tmp = 0.0f; | ||||
|  | ||||
|     const block_q_t  * x = (const block_q_t  *) vx; | ||||
| @ -19,7 +22,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * | ||||
|     for (auto i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) { | ||||
|         const int ibx = row*blocks_per_row + i; // x block index | ||||
|  | ||||
|         const int iby = i * (qk/QK8_1); // y block index that aligns with ibx | ||||
|         const int iby = vec*(nrows_y/QK8_1) + i * (qk/QK8_1); // y block index that aligns with ibx | ||||
|  | ||||
|         const int iqs  = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int | ||||
|  | ||||
| @ -33,177 +36,177 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * | ||||
|     } | ||||
|  | ||||
|     if (threadIdx.x == 0) { | ||||
|         dst[row] = tmp; | ||||
|         dst[vec*nrows + row] = tmp; | ||||
|     } | ||||
| } | ||||
|  | ||||
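The kernel change boils down to two index adjustments: each input vector's q8_1 blocks start at vec * (nrows_y / QK8_1), where nrows_y pads ncols up to a multiple of 512 to match the host-side [vecs, padded / 32 * 9] buffer, and the result is written at dst[vec * nrows + row]. A small worked example with assumed sizes, taking QK8_1 = 32 (its usual GGML value):

    // Worked example: ncols = 4096, nrows = 11008, qk = QK8_1 = 32, vec = 2, i = 5.
    // nrows_y          = (4096 + 511) / 512 * 512 = 4096
    // blocks per vec   = nrows_y / QK8_1          = 128 q8_1 blocks of quantized activations
    // iby              = 2 * 128 + 5 * (32 / 32)  = 261
    // dst index        = vec * nrows + row        = 2 * 11008 + row   (row-major [nvecs, nrows])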
| template<typename scalar_t> | ||||
| static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
| static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { | ||||
|     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; | ||||
|     const dim3 block_nums(block_num_y, 1, 1); | ||||
|     const dim3 block_nums(block_num_y, nvecs, 1); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); | ||||
|     mul_mat_vec_q<scalar_t, QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1> | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs); | ||||
| } | ||||
|  | ||||
| template<typename scalar_t> | ||||
| static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
| static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { | ||||
|     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; | ||||
|     const dim3 block_nums(block_num_y, 1, 1); | ||||
|     const dim3 block_nums(block_num_y, nvecs, 1); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); | ||||
|     mul_mat_vec_q<scalar_t, QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1> | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs); | ||||
| } | ||||
|  | ||||
| template<typename scalar_t> | ||||
| static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
| static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { | ||||
|     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; | ||||
|     const dim3 block_nums(block_num_y, 1, 1); | ||||
|     const dim3 block_nums(block_num_y, nvecs, 1); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); | ||||
|     mul_mat_vec_q<scalar_t, QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1> | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs); | ||||
| } | ||||
|  | ||||
| template<typename scalar_t> | ||||
| static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
| static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { | ||||
|     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; | ||||
|     const dim3 block_nums(block_num_y, 1, 1); | ||||
|     const dim3 block_nums(block_num_y, nvecs, 1); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); | ||||
|     mul_mat_vec_q<scalar_t, QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1> | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs); | ||||
| } | ||||
|  | ||||
| template<typename scalar_t> | ||||
| static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
| static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { | ||||
|     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; | ||||
|     const dim3 block_nums(block_num_y, 1, 1); | ||||
|     const dim3 block_nums(block_num_y, nvecs, 1); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); | ||||
|     mul_mat_vec_q<scalar_t, QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1> | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs); | ||||
| } | ||||
|  | ||||
| template<typename scalar_t> | ||||
| static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
| static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { | ||||
|     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; | ||||
|     const dim3 block_nums(block_num_y, 1, 1); | ||||
|     const dim3 block_nums(block_num_y, nvecs, 1); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); | ||||
|     mul_mat_vec_q<scalar_t, QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1> | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs); | ||||
| } | ||||
|  | ||||
| template<typename scalar_t> | ||||
| static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
| static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { | ||||
|     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; | ||||
|     const dim3 block_nums(block_num_y, 1, 1); | ||||
|     const dim3 block_nums(block_num_y, nvecs, 1); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); | ||||
|     mul_mat_vec_q<scalar_t, QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1> | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs); | ||||
| } | ||||
|  | ||||
| template<typename scalar_t> | ||||
| static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
| static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { | ||||
|     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; | ||||
|     const dim3 block_nums(block_num_y, 1, 1); | ||||
|     const dim3 block_nums(block_num_y, nvecs, 1); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); | ||||
|     mul_mat_vec_q<scalar_t, QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1> | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs); | ||||
| } | ||||
|  | ||||
| template<typename scalar_t> | ||||
| static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
| static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { | ||||
|     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; | ||||
|     const dim3 block_nums(block_num_y, 1, 1); | ||||
|     const dim3 block_nums(block_num_y, nvecs, 1); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); | ||||
|     mul_mat_vec_q<scalar_t, QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1> | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs); | ||||
| } | ||||
|  | ||||
| template<typename scalar_t> | ||||
| static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
| static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { | ||||
|     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; | ||||
|     const dim3 block_nums(block_num_y, 1, 1); | ||||
|     const dim3 block_nums(block_num_y, nvecs, 1); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); | ||||
|     mul_mat_vec_q<scalar_t, QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1> | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs); | ||||
| } | ||||
|  | ||||
| template<typename scalar_t> | ||||
| static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
| static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { | ||||
|     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; | ||||
|     const dim3 block_nums(block_num_y, 1, 1); | ||||
|     const dim3 block_nums(block_num_y, nvecs, 1); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); | ||||
|     mul_mat_vec_q<scalar_t, QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1> | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs); | ||||
| } | ||||
|  | ||||
| template<typename scalar_t> | ||||
| static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
| static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { | ||||
|     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; | ||||
|     const dim3 block_nums(block_num_y, 1, 1); | ||||
|     const dim3 block_nums(block_num_y, nvecs, 1); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); | ||||
|     mul_mat_vec_q<scalar_t, QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1> | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs); | ||||
| } | ||||
|  | ||||
| template<typename scalar_t> | ||||
| static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
| static void mul_mat_vec_iq2_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { | ||||
|     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; | ||||
|     const dim3 block_nums(block_num_y, 1, 1); | ||||
|     const dim3 block_nums(block_num_y, nvecs, 1); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); | ||||
|     mul_mat_vec_q<scalar_t, QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1> | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs); | ||||
| } | ||||
|  | ||||
| template<typename scalar_t> | ||||
| static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
| static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { | ||||
|     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; | ||||
|     const dim3 block_nums(block_num_y, 1, 1); | ||||
|     const dim3 block_nums(block_num_y, nvecs, 1); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); | ||||
|     mul_mat_vec_q<scalar_t, QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1> | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs); | ||||
| } | ||||
|  | ||||
| template<typename scalar_t> | ||||
| static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
| static void mul_mat_vec_iq1_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { | ||||
|     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; | ||||
|     const dim3 block_nums(block_num_y, 1, 1); | ||||
|     const dim3 block_nums(block_num_y, nvecs, 1); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); | ||||
|     mul_mat_vec_q<scalar_t, QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1> | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs); | ||||
| } | ||||
|  | ||||
| template<typename scalar_t> | ||||
| static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
| static void mul_mat_vec_iq1_m_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { | ||||
|     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; | ||||
|     const dim3 block_nums(block_num_y, 1, 1); | ||||
|     const dim3 block_nums(block_num_y, nvecs, 1); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); | ||||
|     mul_mat_vec_q<scalar_t, QK_K, QI1_M, block_iq1_m, 1, vec_dot_iq1_m_q8_1> | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs); | ||||
| } | ||||
|  | ||||
| template<typename scalar_t> | ||||
| static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
| static void mul_mat_vec_iq4_nl_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { | ||||
|     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; | ||||
|     const dim3 block_nums(block_num_y, 1, 1); | ||||
|     const dim3 block_nums(block_num_y, nvecs, 1); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); | ||||
|     mul_mat_vec_q<scalar_t, QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1> | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs); | ||||
| } | ||||
|  | ||||
| template<typename scalar_t> | ||||
| static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
| static void mul_mat_vec_iq4_xs_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { | ||||
|     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; | ||||
|     const dim3 block_nums(block_num_y, 1, 1); | ||||
|     const dim3 block_nums(block_num_y, nvecs, 1); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); | ||||
|     mul_mat_vec_q<scalar_t, QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1> | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs); | ||||
| } | ||||
|  | ||||
| template<typename scalar_t> | ||||
| static void mul_mat_vec_iq3_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
| static void mul_mat_vec_iq3_s_q8_1_cuda(const void * vx, const void * vy, scalar_t * dst, const int ncols, const int nrows, const int nvecs, cudaStream_t stream) { | ||||
|     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; | ||||
|     const dim3 block_nums(block_num_y, 1, 1); | ||||
|     const dim3 block_nums(block_num_y, nvecs, 1); | ||||
|     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); | ||||
|     mul_mat_vec_q<scalar_t, QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1> | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows); | ||||
|         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows, nvecs); | ||||
| } | ||||
|  | ||||
| @ -206,8 +206,6 @@ __global__ void gemm_half_q_half_gptq_4bit_kernel( | ||||
|   auto offset_m = blockIdx.y * m_count; | ||||
|   auto offset_k = blockIdx.z * BLOCK_KN_SIZE; | ||||
|  | ||||
|   [[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); | ||||
|   [[maybe_unused]] int end_m = min(offset_m + m_count, size_m); | ||||
|   int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); | ||||
|  | ||||
|   int n = offset_n + t * 4; | ||||
| @ -344,8 +342,6 @@ __global__ void gemm_half_q_half_gptq_2bit_kernel( | ||||
|   auto offset_m = blockIdx.y * m_count; | ||||
|   auto offset_k = blockIdx.z * BLOCK_KN_SIZE; | ||||
|  | ||||
|   [[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); | ||||
|   [[maybe_unused]] int end_m = min(offset_m + m_count, size_m); | ||||
|   int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); | ||||
|  | ||||
|   int n = offset_n + t * 4; | ||||
| @ -465,8 +461,6 @@ __global__ void gemm_half_q_half_gptq_3bit_kernel( | ||||
|   auto offset_m = blockIdx.y * m_count; | ||||
|   auto offset_k = blockIdx.z * BLOCK_KN_SIZE; | ||||
|  | ||||
|   [[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); | ||||
|   [[maybe_unused]] int end_m = min(offset_m + m_count, size_m); | ||||
|   int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); | ||||
|  | ||||
|   int n = offset_n + t * 4; | ||||
| @ -593,8 +587,6 @@ __global__ void gemm_half_q_half_gptq_8bit_kernel( | ||||
|   auto offset_m = blockIdx.y * m_count; | ||||
|   auto offset_k = blockIdx.z * BLOCK_KN_SIZE; | ||||
|  | ||||
|   [[maybe_unused]] int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); | ||||
|   [[maybe_unused]] int end_m = min(offset_m + m_count, size_m); | ||||
|   int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); | ||||
|  | ||||
|   int n = offset_n + t * 4; | ||||
|  | ||||
| @ -1113,8 +1113,6 @@ __global__ void Marlin( | ||||
|     if constexpr (has_zp && !is_zp_float) { | ||||
|       if (is_new_zp) { | ||||
|         if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; | ||||
|         FragB frag_zp_0; | ||||
|         FragB frag_zp_1; | ||||
|         int zp_quant_0, zp_quant_1; | ||||
|  | ||||
|         if constexpr (w_type.size_bits() == 4) { | ||||
|  | ||||
| @ -1003,7 +1003,7 @@ struct MacheteCollectiveMma { | ||||
|     static constexpr int A_CPY_VEC = | ||||
|         decltype(max_common_vector(tCsA, tCrA_load)){}; | ||||
|  | ||||
|     static constexpr int COVERSION_WIDTH = | ||||
|     static constexpr int CONVERSION_WIDTH = | ||||
|         std::min(A_CPY_VEC, int(size<0>(tCrA_mma))); | ||||
|  | ||||
|     auto load_A_to_registers = [&](int read_stage) { | ||||
| @ -1026,8 +1026,8 @@ struct MacheteCollectiveMma { | ||||
|     // PIPELINED MAIN LOOP | ||||
|     // | ||||
|  | ||||
|     auto convert_A = [&, a_vec = Int<COVERSION_WIDTH>{}](int k_block, | ||||
|                                                          int read_stage) { | ||||
|     auto convert_A = [&, a_vec = Int<CONVERSION_WIDTH>{}](int k_block, | ||||
|                                                           int read_stage) { | ||||
|       load_extra_info_to_registers(partitioned_extra_info, | ||||
|                                    copy_partitions_extra_info, k_block, | ||||
|                                    read_stage); | ||||
|  | ||||
							
								
								
									
csrc/quantization/vectorization_utils.cuh (new file, 75 lines)
							| @ -0,0 +1,75 @@ | ||||
| #pragma once | ||||
| #include "vectorization.cuh" | ||||
|  | ||||
| namespace vllm { | ||||
|  | ||||
| template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp> | ||||
| struct DefaultVecOp { | ||||
|   ScaOp scalar_op; | ||||
|  | ||||
|   __device__ __forceinline__ void operator()( | ||||
|       vec_n_t<OutT, VEC_SIZE>& dst, const vec_n_t<InT, VEC_SIZE>& src) const { | ||||
| #pragma unroll | ||||
|     for (int i = 0; i < VEC_SIZE; ++i) { | ||||
|       scalar_op(dst.val[i], src.val[i]); | ||||
|     } | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <int VEC_SIZE, typename InT, typename OutT, typename VecOp, | ||||
|           typename ScaOp> | ||||
| __device__ inline void vectorize_with_alignment( | ||||
|     const InT* in, OutT* out, int len, int tid, int stride, | ||||
|     VecOp&& vec_op,       // vec_n_t<InT,16> -> vec_n_t<OutT,16> | ||||
|     ScaOp&& scalar_op) {  // InT -> OutT | ||||
|   static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0, | ||||
|                 "VEC_SIZE must be a positive power-of-two"); | ||||
|   constexpr int WIDTH = VEC_SIZE * sizeof(InT);  // e.g. 64 B | ||||
|   uintptr_t addr = reinterpret_cast<uintptr_t>(in); | ||||
|  | ||||
|   int misalignment_offset = addr & (WIDTH - 1);       // addr % 64 | ||||
|   int alignment_bytes = WIDTH - misalignment_offset;  // 64 - (addr % 64) | ||||
|   int prefix_elems = alignment_bytes & (WIDTH - 1);   // 0 if already aligned | ||||
|   prefix_elems /= sizeof(InT); | ||||
|   prefix_elems = min(prefix_elems, len);  // 0 ≤ prefix < 16 | ||||
|  | ||||
|   // 1. scalar-process the misaligned prefix where vector loads would be unsafe | ||||
|   for (int i = tid; i < prefix_elems; i += stride) { | ||||
|     scalar_op(out[i], in[i]); | ||||
|   } | ||||
|  | ||||
|   in += prefix_elems; | ||||
|   out += prefix_elems; | ||||
|   len -= prefix_elems; | ||||
|  | ||||
|   int num_vec = len / VEC_SIZE; | ||||
|   using vin_t = vec_n_t<InT, VEC_SIZE>; | ||||
|   using vout_t = vec_n_t<OutT, VEC_SIZE>; | ||||
|   auto* v_in = reinterpret_cast<const vin_t*>(in); | ||||
|   auto* v_out = reinterpret_cast<vout_t*>(out); | ||||
|  | ||||
|   // 2. vectorize the main part | ||||
|   for (int i = tid; i < num_vec; i += stride) { | ||||
|     vout_t tmp; | ||||
|     vec_op(tmp, v_in[i]); | ||||
|     v_out[i] = tmp; | ||||
|   } | ||||
|  | ||||
|   // 3. handle the tail | ||||
|   int tail_start = num_vec * VEC_SIZE; | ||||
|   for (int i = tid + tail_start; i < len; i += stride) { | ||||
|     scalar_op(out[i], in[i]); | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <int VEC_SIZE, typename InT, typename OutT, typename ScaOp> | ||||
| __device__ __forceinline__ void vectorize_with_alignment(const InT* in, | ||||
|                                                          OutT* out, int len, | ||||
|                                                          int tid, int stride, | ||||
|                                                          ScaOp&& scalar_op) { | ||||
|   using Vec = DefaultVecOp<VEC_SIZE, InT, OutT, std::decay_t<ScaOp>>; | ||||
|   vectorize_with_alignment<VEC_SIZE>(in, out, len, tid, stride, Vec{scalar_op}, | ||||
|                                      std::forward<ScaOp>(scalar_op)); | ||||
| } | ||||
|  | ||||
| }  // namespace vllm | ||||
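vectorization_utils.cuh factors the usual aligned-prefix / vector-body / scalar-tail pattern out of element-wise quantization kernels. A minimal usage sketch follows; the kernel name, functor, and include path are illustrative and not part of the patch. Only the per-element scalar op has to be written, and the helper handles misaligned edges:

#include <cuda_fp16.h>
#include "quantization/vectorization_utils.cuh"  // illustrative include path

namespace example {

// Scalar op applied per element; the vectorized path wraps it automatically
// through DefaultVecOp.
struct ScaleAndCast {
  float scale;
  __device__ __forceinline__ void operator()(__half& dst,
                                             const float& src) const {
    dst = __float2half(src * scale);
  }
};

// Threads of a block cooperate with stride blockDim.x; 8 floats (32 B) are
// handled per vector where alignment permits, scalars at the edges.
__global__ void scale_and_cast_kernel(const float* __restrict__ in,
                                      __half* __restrict__ out, int len,
                                      float scale) {
  vllm::vectorize_with_alignment<8>(in, out, len, threadIdx.x, blockDim.x,
                                    ScaleAndCast{scale});
}

}  // namespace example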
							
								
								
									
csrc/quickreduce/base.h (new file, 338 lines)
							| @ -0,0 +1,338 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <cstdint> | ||||
| #include <hip/hip_runtime.h> | ||||
| #include <hip/hip_fp16.h> | ||||
| #include <hip/hip_bf16.h> | ||||
|  | ||||
| #define __quickreduce_device_inline__ __device__ __forceinline__ | ||||
| #define __quickreduce_launch_bounds_two_shot__ __launch_bounds__(256, 4) | ||||
| #define __quickreduce_launch_bounds_one_shot__ __launch_bounds__(512, 4) | ||||
|  | ||||
| namespace quickreduce { | ||||
|  | ||||
| typedef __hip_bfloat16 nv_bfloat16; | ||||
| typedef __hip_bfloat162 nv_bfloat162; | ||||
|  | ||||
| using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; | ||||
| using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; | ||||
|  | ||||
| // Setup acquire-release semantics for vector memory reads (mubuf instruction) | ||||
| // as per architecture. | ||||
| #if defined(__gfx942__) | ||||
| // CDNA3: Scope bits sc0, sc1 | ||||
|   #define MUBUF_ACQUIRE 16 | ||||
|   #define MUBUF_RELEASE 16 | ||||
| #elif (defined(__gfx908__) || defined(__gfx90a__)) | ||||
| // CDNA1 and CDNA2 - glc bit | ||||
|   #define MUBUF_ACQUIRE 1 | ||||
|   #define MUBUF_RELEASE 0 | ||||
| #endif | ||||
|  | ||||
| static constexpr int kNegOne = 0xBC00BC00;  // {-1, -1}, fp16x2_t | ||||
|  | ||||
| // Number of atoms (4xf16x2_t) processed by a single thread | ||||
| static constexpr int kAtoms = 8; | ||||
|  | ||||
| // We use a workgroup of 256 threads | ||||
| static constexpr int kBlockSize = 256; | ||||
| static constexpr int kAtomStride = kBlockSize; | ||||
|  | ||||
| // Size and atom stride of source/destination data that the block will | ||||
| // process. | ||||
| // Workgroup scope = Tile = (256 threads x 8 atoms x 16B) | ||||
| static constexpr int kTileSize = kBlockSize * kAtoms * sizeof(int32x4_t); | ||||
|  | ||||
| // Max number of blocks. 304 CUs on MI300 | ||||
| static constexpr int kMaxNumBlocks = 304 * 4; | ||||
|  | ||||
| // Standard CDNA wavefront size. | ||||
| static constexpr int kWavefront = 64; | ||||
|  | ||||
| // 256 thread, 4 wavefronts. | ||||
| static dim3 constexpr kBlockTwoShot = {kWavefront, kBlockSize / kWavefront, 1}; | ||||
|  | ||||
| // Number of threads in a group for quantization | ||||
| // It corresponds to 32 F16 elements in quantization block | ||||
| static constexpr int kThreadGroupSize = 8; | ||||
|  | ||||
| // Methods | ||||
| __quickreduce_device_inline__ __host__ unsigned long divceil(unsigned long x, | ||||
|                                                              unsigned long y) { | ||||
|   return ((x + y - 1) / y); | ||||
| } | ||||
|  | ||||
| union BufferResource { | ||||
|   __quickreduce_device_inline__ constexpr BufferResource() | ||||
|       : config(0x00020000U) {} | ||||
|  | ||||
|   __quickreduce_device_inline__ constexpr BufferResource(void* buffer_address, | ||||
|                                                          uint32_t buffer_size) | ||||
|       : address(buffer_address), range(buffer_size), config(0x00020000U) {} | ||||
|  | ||||
|   int32x4_t descriptor; | ||||
|   struct { | ||||
|     void* address;  // 8B, out of which first 48b is address, and 16b is stride | ||||
|     // (unused) | ||||
|     uint32_t range;   // Byte range for the buffer resource | ||||
|     uint32_t config;  // Constant, DFMT=32b | ||||
|   }; | ||||
| }; | ||||
|  | ||||
| __quickreduce_device_inline__ static int32x4_t buffer_load_dwordx4( | ||||
|     int32x4_t srsrc, int32_t voffset, int32_t soffset, | ||||
|     int32_t aux) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); | ||||
|  | ||||
| __quickreduce_device_inline__ static void buffer_store_dwordx4( | ||||
|     int32x4_t data, int32x4_t srsrc, int32_t voffset, int32_t soffset, | ||||
|     int32_t aux) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); | ||||
|  | ||||
| __quickreduce_device_inline__ static void set_fp16_ovfl(bool const value) { | ||||
| #if defined(__gfx942__) | ||||
|   if (value) { | ||||
|     asm volatile("s_setreg_imm32_b32 0xdc1, 1;" ::); | ||||
|   } else { | ||||
|     asm volatile("s_setreg_imm32_b32 0xdc1, 0;" ::); | ||||
|   } | ||||
| #endif | ||||
| } | ||||
| union bf162_int_union { | ||||
|   int i; | ||||
|   nv_bfloat162 bf2; | ||||
| }; | ||||
|  | ||||
| template <typename T> | ||||
| __quickreduce_device_inline__ void packed_assign_add(int32x4_t* A, | ||||
|                                                      int32x4_t* B); | ||||
|  | ||||
| template <> | ||||
| __quickreduce_device_inline__ void packed_assign_add<half>(int32x4_t* A, | ||||
|                                                            int32x4_t* B) { | ||||
|   int32x4_t& tR_fragment = A[0]; | ||||
|   int32x4_t& tA_fragment = B[0]; | ||||
|  | ||||
|   asm volatile("v_pk_add_f16 %0, %1, %2" | ||||
|                : "=v"(tR_fragment[0]) | ||||
|                : "v"(tR_fragment[0]), "v"(tA_fragment[0])); | ||||
|   asm volatile("v_pk_add_f16 %0, %1, %2" | ||||
|                : "=v"(tR_fragment[1]) | ||||
|                : "v"(tR_fragment[1]), "v"(tA_fragment[1])); | ||||
|   asm volatile("v_pk_add_f16 %0, %1, %2" | ||||
|                : "=v"(tR_fragment[2]) | ||||
|                : "v"(tR_fragment[2]), "v"(tA_fragment[2])); | ||||
|   asm volatile("v_pk_add_f16 %0, %1, %2" | ||||
|                : "=v"(tR_fragment[3]) | ||||
|                : "v"(tR_fragment[3]), "v"(tA_fragment[3])); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| __quickreduce_device_inline__ void packed_assign_add<nv_bfloat16>( | ||||
|     int32x4_t* A, int32x4_t* B) { | ||||
|   nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(A); | ||||
|   nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(B); | ||||
| #pragma unroll | ||||
|   for (int i = 0; i < 4; i++) { | ||||
|     tA[i] = __hadd2(tA[i], tB[i]); | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| __quickreduce_device_inline__ int packed_max(int a, int b); | ||||
|  | ||||
| template <> | ||||
| __quickreduce_device_inline__ int packed_max<half>(int a, int b) { | ||||
|   int result; | ||||
|   asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); | ||||
|   return result; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| __quickreduce_device_inline__ int packed_max<nv_bfloat16>(int a, int b) { | ||||
|   bf162_int_union A, B, R; | ||||
|   A.i = a; | ||||
|   B.i = b; | ||||
|   R.bf2 = __hmax2(A.bf2, B.bf2); | ||||
|   return R.i; | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| __quickreduce_device_inline__ int packed_min(int a, int b); | ||||
|  | ||||
| template <> | ||||
| __quickreduce_device_inline__ int packed_min<half>(int a, int b) { | ||||
|   int result; | ||||
|   asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); | ||||
|   return result; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| __quickreduce_device_inline__ int packed_min<nv_bfloat16>(int a, int b) { | ||||
|   bf162_int_union A, B, R; | ||||
|   A.i = a; | ||||
|   B.i = b; | ||||
|   R.bf2 = __hmin2(A.bf2, B.bf2); | ||||
|   return R.i; | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| __quickreduce_device_inline__ int packed_abs_max(int a, int b); | ||||
|  | ||||
| template <> | ||||
| __quickreduce_device_inline__ int packed_abs_max<half>(int a, int b) { | ||||
|   half2 wmaxh2 = __builtin_bit_cast(half2, a); | ||||
|   half2 wminh2 = __builtin_bit_cast(half2, b); | ||||
|   half2 wblockmaxh2; | ||||
|  | ||||
|   wblockmaxh2.x = | ||||
|       __hgt(__habs(wmaxh2.x), __habs(wminh2.x)) ? wmaxh2.x : wminh2.x; | ||||
|   wblockmaxh2.y = | ||||
|       __hgt(__habs(wmaxh2.y), __habs(wminh2.y)) ? wmaxh2.y : wminh2.y; | ||||
|   return __builtin_bit_cast(int, wblockmaxh2); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| __quickreduce_device_inline__ int packed_abs_max<nv_bfloat16>(int a, int b) { | ||||
|   bf162_int_union A, B, R; | ||||
|   A.i = a; | ||||
|   B.i = b; | ||||
|   R.bf2.x = __hgt(__habs(A.bf2.x), __habs(B.bf2.x)) ? A.bf2.x : B.bf2.x; | ||||
|   R.bf2.y = __hgt(__habs(A.bf2.y), __habs(B.bf2.y)) ? A.bf2.y : B.bf2.y; | ||||
|   return R.i; | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| __quickreduce_device_inline__ int packed_add(int a, int b); | ||||
|  | ||||
| template <> | ||||
| __quickreduce_device_inline__ int packed_add<half>(int a, int b) { | ||||
|   int result; | ||||
|   asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); | ||||
|   return result; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| __quickreduce_device_inline__ int packed_add<nv_bfloat16>(int a, int b) { | ||||
|   bf162_int_union A, B, R; | ||||
|   A.i = a; | ||||
|   B.i = b; | ||||
|   R.bf2 = __hadd2(A.bf2, B.bf2); | ||||
|   return R.i; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| __quickreduce_device_inline__ int packed_add<int16_t>(int a, int b) { | ||||
|   int result; | ||||
|   asm volatile("v_pk_add_i16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); | ||||
|   return result; | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| __quickreduce_device_inline__ int packed_sub(int a, int b); | ||||
|  | ||||
| template <> | ||||
| __quickreduce_device_inline__ int packed_sub<half>(int a, int b) { | ||||
|   int result; | ||||
|  | ||||
|   // MI300 lacks a packed fp16 sub instruction, so compute a - b as | ||||
|   // a + (-1) * b via a packed FMA. | ||||
|   asm volatile("v_pk_fma_f16 %0, %1, %2, %3" | ||||
|                : "=v"(result) | ||||
|                : "v"(kNegOne), "v"(b), "v"(a)); | ||||
|   return result; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| __quickreduce_device_inline__ int packed_sub<nv_bfloat16>(int a, int b) { | ||||
|   bf162_int_union A, B, R; | ||||
|   A.i = a; | ||||
|   B.i = b; | ||||
|   R.bf2 = __hsub2(A.bf2, B.bf2); | ||||
|   return R.i; | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| __quickreduce_device_inline__ int packed_mul(int a, int b); | ||||
|  | ||||
| template <> | ||||
| __quickreduce_device_inline__ int packed_mul<half>(int a, int b) { | ||||
|   int result; | ||||
|   asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b)); | ||||
|   return result; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| __quickreduce_device_inline__ int packed_mul<nv_bfloat16>(int a, int b) { | ||||
|   nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(&a); | ||||
|   nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(&b); | ||||
|   nv_bfloat162 tR = __hmul2(*tA, *tB); | ||||
|   return *(reinterpret_cast<int*>(&tR)); | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| __quickreduce_device_inline__ int packed_rcp(int a); | ||||
|  | ||||
| template <> | ||||
| __quickreduce_device_inline__ int packed_rcp<half>(int a) { | ||||
|   return __builtin_bit_cast(int, h2rcp(__builtin_bit_cast(half2, a))); | ||||
| } | ||||
|  | ||||
| template <> | ||||
| __quickreduce_device_inline__ int packed_rcp<nv_bfloat16>(int a) { | ||||
|   bf162_int_union A, R; | ||||
|   A.i = a; | ||||
|   R.bf2 = h2rcp(A.bf2); | ||||
|   return R.i; | ||||
| } | ||||
|  | ||||
| // changes dtype | ||||
| __quickreduce_device_inline__ float T2float_cast(half a) { | ||||
|   return __half2float(a); | ||||
| } | ||||
|  | ||||
| __quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) { | ||||
|   return __bfloat162float(a); | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| __quickreduce_device_inline__ int group_abs_max(int32x4_t atom) { | ||||
|   const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize; | ||||
|  | ||||
|   int wmax, wmin, wblockmax; | ||||
|   int a, b; | ||||
|   a = packed_max<T>(atom[0], atom[1]); | ||||
|   b = packed_max<T>(atom[2], atom[3]); | ||||
|  | ||||
|   wmax = packed_max<T>(a, b); | ||||
|  | ||||
|   a = packed_min<T>(atom[0], atom[1]); | ||||
|   b = packed_min<T>(atom[2], atom[3]); | ||||
|  | ||||
|   wmin = packed_min<T>(a, b); | ||||
|  | ||||
|   // Reduce the max among a group of threads | ||||
|   // Note: This is basically 2 blocks of values setup as the | ||||
|   // upper/lower halves of the f16x2_t | ||||
|   for (int i = 1; i < kThreadGroupSize; i <<= 1) { | ||||
|     int x = __shfl_down(wmax, i); | ||||
|     wmax = packed_max<T>(wmax, x); | ||||
|  | ||||
|     int y = __shfl_down(wmin, i); | ||||
|     wmin = packed_min<T>(wmin, y); | ||||
|   } | ||||
|   wblockmax = packed_abs_max<T>(wmax, wmin); | ||||
|   // Share with the cohort | ||||
|   wblockmax = __shfl(wblockmax, group_leader); | ||||
|   return wblockmax; | ||||
| } | ||||
|  | ||||
| __quickreduce_device_inline__ void set_sync_flag(uint32_t* flag_ptr, | ||||
|                                                  uint32_t flag) { | ||||
|   __atomic_store_n(flag_ptr, flag, __ATOMIC_RELEASE); | ||||
| } | ||||
|  | ||||
| __quickreduce_device_inline__ void wait_sync_flag(uint32_t* flag_ptr, | ||||
|                                                   uint32_t flag) { | ||||
|   while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag) { | ||||
|   } | ||||
| } | ||||
|  | ||||
| }  // namespace quickreduce | ||||
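base.h bundles the packed fp16x2/bf16x2 arithmetic helpers with the release/spin flag pair (set_sync_flag / wait_sync_flag) that the all-reduce kernels use to order cross-rank buffer reads. A minimal single-device sketch of that handshake is below; the kernel name and launch shape are illustrative only, and the real kernels place the flags in the uncached IPC buffers allocated by DeviceComms:

#include <hip/hip_runtime.h>
#include "base.h"  // illustrative include path

// Block 0 publishes a value and releases the flag; block 1 spins on the flag
// and then reads the value. Assumes both blocks are resident concurrently.
__global__ void flag_handshake_sketch(int* data, uint32_t* flag,
                                      uint32_t flag_color, int* out) {
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    data[0] = 42;                                  // payload store
    quickreduce::set_sync_flag(flag, flag_color);  // release: publish payload
  } else if (blockIdx.x == 1 && threadIdx.x == 0) {
    quickreduce::wait_sync_flag(flag, flag_color);  // spin until flag observed
    out[0] = data[0];
  }
}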
							
								
								
									
csrc/quickreduce/quick_reduce.h (new file, 196 lines)
							| @ -0,0 +1,196 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <vector> | ||||
| #include <hip/hip_runtime.h> | ||||
| #include "quick_reduce_impl.cuh" | ||||
|  | ||||
| #define HIP_CHECK(err)                                                     \ | ||||
|   do {                                                                     \ | ||||
|     hipError_t err_ = (err);                                               \ | ||||
|     if (err_ != hipSuccess) {                                              \ | ||||
|       std::printf("HIP error %d at %s:%d. %s\n", err_, __FILE__, __LINE__, \ | ||||
|                   hipGetErrorString(err_));                                \ | ||||
|       throw std::runtime_error("HIP error");                               \ | ||||
|     }                                                                      \ | ||||
|   } while (0) | ||||
|  | ||||
| namespace quickreduce { | ||||
| using fptr_t = int64_t; | ||||
| static_assert(sizeof(void*) == sizeof(fptr_t)); | ||||
|  | ||||
| template <typename AllReduceKernel, typename T> | ||||
| __global__ __quickreduce_launch_bounds_two_shot__ static void | ||||
| allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks, | ||||
|                             int rank, uint8_t** dbuffer_list, | ||||
|                             uint32_t data_offset, uint32_t flag_color) { | ||||
|   int block = blockIdx.x; | ||||
|   int grid = gridDim.x; | ||||
|  | ||||
|   while (block < num_blocks) { | ||||
|     AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset, | ||||
|                          flag_color); | ||||
|     block += grid; | ||||
|     flag_color++; | ||||
|   } | ||||
| } | ||||
|  | ||||
| #define TWOSHOT_DISPATCH(__codec)                                           \ | ||||
|   if (world_size == 2) {                                                    \ | ||||
|     using LineCodec = __codec<T, 2>;                                        \ | ||||
|     using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>;   \ | ||||
|     hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>),   \ | ||||
|                        dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ | ||||
|                        num_blocks, rank, dbuffer_list, data_offset,         \ | ||||
|                        flag_color);                                         \ | ||||
|   } else if (world_size == 4) {                                             \ | ||||
|     using LineCodec = __codec<T, 4>;                                        \ | ||||
|     using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>;   \ | ||||
|     hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>),   \ | ||||
|                        dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ | ||||
|                        num_blocks, rank, dbuffer_list, data_offset,         \ | ||||
|                        flag_color);                                         \ | ||||
|   } else if (world_size == 8) {                                             \ | ||||
|     using LineCodec = __codec<T, 8>;                                        \ | ||||
|     using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>;   \ | ||||
|     hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>),   \ | ||||
|                        dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \ | ||||
|                        num_blocks, rank, dbuffer_list, data_offset,         \ | ||||
|                        flag_color);                                         \ | ||||
|   } | ||||
|  | ||||
| enum QuickReduceQuantLevel { | ||||
|   F16 = 0, | ||||
|   INT8 = 1, | ||||
|   INT6 = 2, | ||||
|   INT4 = 3, | ||||
| }; | ||||
|  | ||||
| struct DeviceComms { | ||||
|   // Max problem size is 2GB (in bytes) or half of uint32_t max value. | ||||
|   int64_t kMaxProblemSize = | ||||
|       static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1; | ||||
|  | ||||
|   // Max TP-8 | ||||
|   static int constexpr kMaxWorldSize = 8; | ||||
|  | ||||
|   bool initialized = false; | ||||
|   uint32_t flag_color = 1; | ||||
|   int world_size; | ||||
|   int rank; | ||||
|  | ||||
|   uint8_t* dbuffer; | ||||
|   uint8_t** dbuffer_list; | ||||
|   hipIpcMemHandle_t buffer_ipc_handle; | ||||
|   std::vector<hipIpcMemHandle_t> all_buffer_ipc_handles; | ||||
|   std::vector<uint8_t*> buffer_list; | ||||
|   uint32_t data_offset; | ||||
|  | ||||
|   DeviceComms() : initialized(false), world_size(1), rank(0) {} | ||||
|   ~DeviceComms() { destroy(); } | ||||
|  | ||||
|   void init(int world_size, int rank, | ||||
|             std::optional<int64_t> max_problem_size = std::nullopt) { | ||||
|     destroy(); | ||||
|     this->world_size = world_size; | ||||
|     this->rank = rank; | ||||
|     if (max_problem_size.has_value() && max_problem_size.value() > 0) { | ||||
|       this->kMaxProblemSize = max_problem_size.value(); | ||||
|     } | ||||
|     // Allocate buffer size for worst case: F16 2-stage buffer. | ||||
|     uint32_t flags_buffer_size = | ||||
|         2 * world_size * kMaxNumBlocks * sizeof(uint32_t); | ||||
|     static int64_t data_buffer_size = 2 * this->kMaxProblemSize; | ||||
|     int64_t total_buffer_size = flags_buffer_size + data_buffer_size; | ||||
|     data_offset = flags_buffer_size; | ||||
|     HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size, | ||||
|                                     hipDeviceMallocUncached)); | ||||
|  | ||||
|     // Clear the flags buffer. | ||||
|     HIP_CHECK(hipMemset(dbuffer, 0, flags_buffer_size)); | ||||
|  | ||||
|     // Device-side list of IPC buffers. | ||||
|     buffer_list.resize(world_size); | ||||
|     HIP_CHECK(hipMalloc(&dbuffer_list, world_size * sizeof(uint8_t*))); | ||||
|  | ||||
|     // Create IPC handles for rank's communication buffer. | ||||
|     all_buffer_ipc_handles.resize(world_size); | ||||
|     HIP_CHECK(hipIpcGetMemHandle(&buffer_ipc_handle, dbuffer)); | ||||
|  | ||||
|     initialized = true; | ||||
|   } | ||||
|   int get_world_size() { return world_size; } | ||||
|   int get_rank() { return rank; } | ||||
|   bool status() { return initialized; } | ||||
|   hipIpcMemHandle_t const get_handle() { return buffer_ipc_handle; } | ||||
|  | ||||
|   void destroy() { | ||||
|     if (initialized) { | ||||
|       for (int i = 0; i < world_size; i++) { | ||||
|         if (i != rank) { | ||||
|           HIP_CHECK(hipIpcCloseMemHandle(dbuffer_list[i])); | ||||
|         } | ||||
|       } | ||||
|  | ||||
|       HIP_CHECK(hipFree(dbuffer)); | ||||
|       HIP_CHECK(hipFree(dbuffer_list)); | ||||
|  | ||||
|       initialized = false; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   void open_ipc_handles(std::vector<hipIpcMemHandle_t> const& ipc_handles) { | ||||
|     assert(ipc_handles.size() == all_buffer_ipc_handles.size()); | ||||
|     for (int i = 0; i < world_size; i++) { | ||||
|       all_buffer_ipc_handles[i] = ipc_handles[i]; | ||||
|     } | ||||
|  | ||||
|     // Open device memory access to the IPC communication buffers. | ||||
|     // Note: For our own rank, we do not need to open a handle. | ||||
|     for (int i = 0; i < world_size; i++) { | ||||
|       if (i != rank) { | ||||
|         HIP_CHECK(hipIpcOpenMemHandle((void**)&buffer_list[i], | ||||
|                                       all_buffer_ipc_handles[i], | ||||
|                                       hipIpcMemLazyEnablePeerAccess)); | ||||
|       } else { | ||||
|         buffer_list[i] = dbuffer; | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     HIP_CHECK(hipMemcpy(dbuffer_list, buffer_list.data(), | ||||
|                         world_size * sizeof(uint8_t*), hipMemcpyHostToDevice)); | ||||
|   } | ||||
|  | ||||
|   template <typename T, bool cast_bf2half> | ||||
|   void allreduce(T const* A, T* B, uint32_t N, int quant_level, | ||||
|                  hipStream_t stream) { | ||||
|     if (world_size != 2 && world_size != 4 && world_size != 8) { | ||||
|       throw std::runtime_error("All Reduce not supported for world_size = " + | ||||
|                                std::to_string(world_size)); | ||||
|     } | ||||
|  | ||||
|     // Configuration. | ||||
|     uint32_t msg_size = N * sizeof(T); | ||||
|     uint32_t num_blocks = divceil(msg_size, kTileSize); | ||||
|     uint32_t grid = min(kMaxNumBlocks, num_blocks); | ||||
|     auto quant_level_ = static_cast<QuickReduceQuantLevel>(quant_level); | ||||
|     switch (quant_level_) { | ||||
|       case QuickReduceQuantLevel::INT8: | ||||
|         TWOSHOT_DISPATCH(CodecQ8) | ||||
|         break; | ||||
|       case QuickReduceQuantLevel::INT6: | ||||
|         TWOSHOT_DISPATCH(CodecQ6) | ||||
|         break; | ||||
|       case QuickReduceQuantLevel::INT4: | ||||
|         TWOSHOT_DISPATCH(CodecQ4) | ||||
|         break; | ||||
|       default: | ||||
|         TWOSHOT_DISPATCH(CodecFP) | ||||
|         break; | ||||
|     } | ||||
|     HIP_CHECK(cudaGetLastError()); | ||||
|     // Rotate the flag color. | ||||
|     flag_color += divceil(N, grid); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| }  // namespace quickreduce | ||||
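A hedged host-side sketch of how DeviceComms is meant to be driven on each rank. The handle exchange (gather_handles below) is a placeholder for whatever out-of-band mechanism the caller uses; in vLLM it happens in the Python layer and is not shown here.

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <vector>
#include "quick_reduce.h"  // illustrative include path

// Placeholder: all-gather one hipIpcMemHandle_t per rank (e.g. over MPI or
// torch.distributed); not implemented in this sketch.
std::vector<hipIpcMemHandle_t> gather_handles(hipIpcMemHandle_t mine,
                                              int world_size);

void run_quick_allreduce(const half* A, half* B, uint32_t N, int world_size,
                         int rank, hipStream_t stream) {
  quickreduce::DeviceComms comms;
  comms.init(world_size, rank);      // allocate flag + data buffers, create IPC handle
  auto handles = gather_handles(comms.get_handle(), world_size);
  comms.open_ipc_handles(handles);   // map every peer's buffer into this rank
  // quant_level: 0 = FP16 codec, 1 = INT8, 2 = INT6, 3 = INT4 (see enum above)
  comms.allreduce<half, /*cast_bf2half=*/false>(A, B, N, /*quant_level=*/1,
                                                stream);
}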
							
								
								
									
csrc/quickreduce/quick_reduce_impl.cuh (new file, 698 lines)
							| @ -0,0 +1,698 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include <hip/hip_runtime.h> | ||||
| #include "base.h" | ||||
|  | ||||
| namespace quickreduce { | ||||
|  | ||||
| struct CodecBase { | ||||
|   const int thread; | ||||
|   const int rank; | ||||
|   const int group_leader; | ||||
|   __quickreduce_device_inline__ CodecBase(int thread, int rank) | ||||
|       : thread(thread), | ||||
|         rank(rank), | ||||
|         group_leader((threadIdx.x / kThreadGroupSize) * kThreadGroupSize) { | ||||
|     set_fp16_ovfl(true); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| // Default full precision codec. | ||||
| template <typename T, int world_size> | ||||
| struct CodecFP : public CodecBase { | ||||
|   static constexpr int kWorldSize = world_size; | ||||
|   static constexpr int kRankAtoms = kAtoms / kWorldSize; | ||||
|  | ||||
|   // Codec tile size processed by this workgroup. | ||||
|   // Each thread processes atoms of f16x8_t (16B). | ||||
|   static constexpr int kRankTransmittedTileSize = | ||||
|       kBlockSize * kRankAtoms * sizeof(int32x4_t); | ||||
|   static_assert(kRankTransmittedTileSize % 16 == 0, | ||||
|                 "kRankTransmittedTileSize must be 16B aligned."); | ||||
|  | ||||
|   // Total tile size for the collective communication. | ||||
|   static constexpr int kTransmittedTileSize = | ||||
|       kRankTransmittedTileSize * kWorldSize; | ||||
|  | ||||
|   __quickreduce_device_inline__ CodecFP(int thread, int rank) | ||||
|       : CodecBase(thread, rank) {} | ||||
|  | ||||
|   __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, | ||||
|                                           const int32x4_t* __restrict__ data) { | ||||
|     for (int i = 0; i < kRankAtoms; i++) { | ||||
|       __builtin_nontemporal_store(data[i], send_buffer + thread); | ||||
|       send_buffer += kAtomStride; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, | ||||
|                                           int32x4_t* __restrict__ data) { | ||||
|     for (int i = 0; i < kRankAtoms; i++) { | ||||
|       data[i] = __builtin_nontemporal_load(*recv_buffer + thread); | ||||
|       *recv_buffer += kAtomStride; | ||||
|     } | ||||
|   } | ||||
| }; | ||||
|  | ||||
| // Int4 symmetric quantization codec. | ||||
| // We quantize the FP16 data to block-scaled Int4 in blocks of 4 * | ||||
| // kThreadGroupSize. | ||||
| template <typename T, int world_size> | ||||
| struct CodecQ4 : public CodecBase { | ||||
|   static constexpr int kWorldSize = world_size; | ||||
|  | ||||
|   // Codec tile size processed by this workgroup. | ||||
|   // Each thread quantizes a fragment of fp16x8_t (16B) | ||||
|   // into an int4x8_t (4B) and an fp16 scale shared among 32 values. | ||||
|   static constexpr int kRankAtoms = kAtoms / kWorldSize; | ||||
|   static constexpr int kRankTileStride = 1152; | ||||
|   static constexpr int kRankTileScaleOffset = 1024; | ||||
|   static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; | ||||
|   static_assert(kRankTransmittedTileSize % 16 == 0, | ||||
|                 "kRankTransmittedTileSize must be 16B aligned."); | ||||
|  | ||||
|   static constexpr int kRankBufferTileStride = | ||||
|       kRankTileStride / sizeof(int32x4_t); | ||||
|  | ||||
|   // Total tile size for the collective communication. | ||||
|   static constexpr int kTransmittedTileSize = | ||||
|       kRankTransmittedTileSize * kWorldSize; | ||||
|  | ||||
|   // Constants configuration | ||||
|  | ||||
|   // {-1/8.0h, -1/8.0h}, f16x2_t | ||||
|   static constexpr int kScaleFactor = | ||||
|       std::is_same<T, half>::value ? 0xB000B000 : 0xBE00BE00; | ||||
|  | ||||
|   // {1e-7, 1e-7}, f16x2_t | ||||
|   static constexpr int kScaleEpsilon = | ||||
|       std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7; | ||||
|  | ||||
|   // {-8, -8}, f16x2_t | ||||
|   static constexpr int kRangeMin = | ||||
|       std::is_same<T, half>::value ? 0xC800C800 : 0xC100C100; | ||||
|  | ||||
|   // {+7, +7}, f16x2_t | ||||
|   static constexpr int kRangeMax = | ||||
|       std::is_same<T, half>::value ? 0x47004700 : 0x40E040E0; | ||||
|  | ||||
|   // {+8, +8}, int16x2_t | ||||
|   static constexpr int kRangeBias = 0x00080008; | ||||
|  | ||||
|   __quickreduce_device_inline__ CodecQ4(int thread, int rank) | ||||
|       : CodecBase(thread, rank) {} | ||||
|  | ||||
|   __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, | ||||
|                                           const int32x4_t* __restrict__ data) { | ||||
|     for (int k = 0; k < kRankAtoms; k++) { | ||||
|       int32x4_t const atom = data[k]; | ||||
|  | ||||
|       // Compute the absolute maximum of the atom in the thread group | ||||
|       // In 2 blocks of values, upper/lower halves of the f16x2_t | ||||
|       int wblockmax = group_abs_max<T>(atom); | ||||
|  | ||||
|       // Derive scales | ||||
|       int decoding_scale; | ||||
|       int encoding_scale; | ||||
|       decoding_scale = packed_mul<T>(wblockmax, kScaleFactor); | ||||
|       encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon); | ||||
|       encoding_scale = packed_rcp<T>(encoding_scale); | ||||
|  | ||||
|       // Apply scales to get quantized values | ||||
|       int32x4_t w; | ||||
|       for (int i = 0; i < 4; i++) { | ||||
|         w[i] = packed_mul<T>(atom[i], encoding_scale); | ||||
|         w[i] = packed_max<T>(w[i], kRangeMin); | ||||
|         w[i] = packed_min<T>(w[i], kRangeMax); | ||||
|       } | ||||
|  | ||||
|       // Convert from f16x2_t to uint16x2_t | ||||
|       int32x4_t q; | ||||
|       { | ||||
|         int16_t* qi = reinterpret_cast<int16_t*>(&q); | ||||
|         T* wh = reinterpret_cast<T*>(&w); | ||||
|         for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i])); | ||||
|  | ||||
|         for (int i = 0; i < 4; i++) { | ||||
|           q[i] = packed_add<int16_t>(q[i], kRangeBias); | ||||
|         } | ||||
|       } | ||||
|  | ||||
|       // Pack 8 x q4 into int32_t | ||||
|       int qw = q[0] | (q[1] << 4) | (q[2] << 8) | (q[3] << 12); | ||||
|  | ||||
|       // Write quantized atom to send_buffer | ||||
|       // note: only the group leader stores the scale | ||||
|       uint8_t* atom_ptr = | ||||
|           reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride); | ||||
|       int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread; | ||||
|       int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + | ||||
|                     (thread / 8); | ||||
|  | ||||
|       __builtin_nontemporal_store(qw, qw_ptr); | ||||
|       if (threadIdx.x == group_leader) { | ||||
|         __builtin_nontemporal_store(decoding_scale, qs_ptr); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, | ||||
|                                           int32x4_t* __restrict__ data) { | ||||
|     for (int k = 0; k < kRankAtoms; k++) { | ||||
|       // Directly read quantized atom from recv_buffer | ||||
|       uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer); | ||||
|       int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread; | ||||
|       int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + | ||||
|                     (thread / 8); | ||||
|  | ||||
|       int32_t qw = __builtin_nontemporal_load(qw_ptr); | ||||
|       int qs = __builtin_nontemporal_load(qs_ptr); | ||||
|  | ||||
|       *recv_buffer += kRankBufferTileStride; | ||||
|  | ||||
|       // Unpack q4 into f16x8_t | ||||
|       int32x4_t w; | ||||
|       { | ||||
|         static constexpr uint kMask000F = 0x000F000F; | ||||
|         static constexpr uint kHalf2_1024 = | ||||
|             0x64006400;  // {1024.0, 1024.0}, fp16x2_t | ||||
|         static uint constexpr kHalf2_1032 = | ||||
|             0xE408E408;  // {-1032.0, -1032.0}, fp16x2_t | ||||
|  | ||||
|         for (int i = 0; i < 4; i++) { | ||||
|           if constexpr (std::is_same<T, half>::value) { | ||||
|             int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kHalf2_1024; | ||||
|             w[i] = packed_add<half>(q4, kHalf2_1032); | ||||
|           } else { | ||||
|             int32_t int16_2 = (qw >> (i * 4)) & kMask000F; | ||||
|             int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF); | ||||
|             int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF); | ||||
|             nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low)); | ||||
|             nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high)); | ||||
|             nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); | ||||
|             int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2); | ||||
|             w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin); | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|  | ||||
|       // Apply decoding scales | ||||
|       for (int i = 0; i < 4; i++) { | ||||
|         w[i] = packed_mul<T>(w[i], qs); | ||||
|       } | ||||
|  | ||||
|       data[k] = w; | ||||
|     } | ||||
|   } | ||||
| }; | ||||
|  | ||||
| // Int6 symmetric quantization codec. | ||||
| // We quantize the FP16 data to block-scaled Int6 in blocks of 4 * | ||||
| // kThreadGroupSize. | ||||
| template <typename T, int world_size> | ||||
| struct CodecQ6 : public CodecBase { | ||||
|   static constexpr int kWorldSize = world_size; | ||||
|  | ||||
|   // Codec tile size processed by this workgroup. | ||||
|   // Each thread quantizes a fragment of fp16x8_t (16B) | ||||
|   // into an int6x8_t (4B + 2B) and an fp16 scale shared among 32 values. | ||||
|   static constexpr int kRankAtoms = kAtoms / kWorldSize; | ||||
|   static constexpr int kRankTileStride = 1664; | ||||
|   static constexpr int kRankTileQ2Offset = 1024; | ||||
|   static constexpr int kRankTileScaleOffset = 1536; | ||||
|   static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; | ||||
|   static_assert(kRankTransmittedTileSize % 16 == 0, | ||||
|                 "kRankTransmittedTileSize must be 16B aligned."); | ||||
|  | ||||
|   static constexpr int kRankBufferTileStride = | ||||
|       kRankTileStride / sizeof(int32x4_t); | ||||
|  | ||||
|   // Total tile size for the collective communication. | ||||
|   static constexpr int kTransmittedTileSize = | ||||
|       kRankTransmittedTileSize * kWorldSize; | ||||
|  | ||||
|   // Constants configuration | ||||
|  | ||||
|   // {-1/32.0h, -1/32.0h}, fp16x2_t | ||||
|   static constexpr int kScaleFactor = | ||||
|       std::is_same<T, half>::value ? 0xA800A800 : 0xBD00BD00; | ||||
|  | ||||
|   // {1e-7, 1e-7}, fp16x2_t | ||||
|   static constexpr int kScaleEpsilon = | ||||
|       std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7; | ||||
|  | ||||
|   // {-32, -32}, fp16x2_t | ||||
|   static constexpr int kRangeMin = | ||||
|       std::is_same<T, half>::value ? 0xD000D000 : 0xC200C200; | ||||
|  | ||||
|   // {+31, +31}, fp16x2_t | ||||
|   static constexpr int kRangeMax = | ||||
|       std::is_same<T, half>::value ? 0x4FC04FC0 : 0x41F841F8; | ||||
|  | ||||
|   // {+32, +32}, int16x2_t | ||||
|   static constexpr int kRangeBias = 0x00200020; | ||||
|  | ||||
|   __quickreduce_device_inline__ CodecQ6(int thread, int rank) | ||||
|       : CodecBase(thread, rank) {} | ||||
|  | ||||
|   __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, | ||||
|                                           const int32x4_t* __restrict__ data) { | ||||
|     for (int k = 0; k < kRankAtoms; k++) { | ||||
|       int32x4_t const atom = data[k]; | ||||
|  | ||||
|       // Compute the absolute maximum of the atom in the thread group | ||||
|       // In 2 blocks of values, upper/lower halves of the f16x2_t | ||||
|       int wblockmax = group_abs_max<T>(atom); | ||||
|  | ||||
|       // Derive scales | ||||
|       int decoding_scale; | ||||
|       int encoding_scale; | ||||
|       decoding_scale = packed_mul<T>(wblockmax, kScaleFactor); | ||||
|       encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon); | ||||
|       encoding_scale = packed_rcp<T>(encoding_scale); | ||||
|  | ||||
|       // Apply scales to get quantized values | ||||
|       int32x4_t w; | ||||
|       for (int i = 0; i < 4; i++) { | ||||
|         w[i] = packed_mul<T>(atom[i], encoding_scale); | ||||
|         w[i] = packed_max<T>(w[i], kRangeMin); | ||||
|         w[i] = packed_min<T>(w[i], kRangeMax); | ||||
|       } | ||||
|  | ||||
|       // Convert from f16x2_t to uint16x2_t | ||||
|       int32x4_t q; | ||||
|       { | ||||
|         int16_t* qi = reinterpret_cast<int16_t*>(&q); | ||||
|         T* wh = reinterpret_cast<T*>(&w); | ||||
|         for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i])); | ||||
|  | ||||
|         for (int i = 0; i < 4; i++) { | ||||
|           q[i] = packed_add<int16_t>(q[i], kRangeBias); | ||||
|         } | ||||
|       } | ||||
|  | ||||
|       // Pack 8 x q6 into int32_t + int16_t | ||||
|       uint32_t q4w; | ||||
|       uint16_t q2w = 0; | ||||
|       q4w = (q[0] & 0x000F000F) | ((q[1] & 0x000F000F) << 4) | | ||||
|             ((q[2] & 0x000F000F) << 8) | ((q[3] & 0x000F000F) << 12); | ||||
|       { | ||||
|         int16_t* tw = reinterpret_cast<int16_t*>(&q); | ||||
| #pragma unroll | ||||
|         for (int i = 0; i < 8; i++) { | ||||
|           q2w |= (tw[i] >> 4) << (i * 2); | ||||
|         } | ||||
|       } | ||||
|       // Write quantized atom to send_buffer | ||||
|       // note: only the group leader stores the scale | ||||
|       uint8_t* atom_ptr = | ||||
|           reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride); | ||||
|       uint32_t* q4w_ptr = reinterpret_cast<uint32_t*>(atom_ptr) + thread; | ||||
|       uint16_t* q2w_ptr = | ||||
|           reinterpret_cast<uint16_t*>(atom_ptr + kRankTileQ2Offset) + thread; | ||||
|       int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + | ||||
|                     (thread / 8); | ||||
|  | ||||
|       __builtin_nontemporal_store(q4w, q4w_ptr); | ||||
|       __builtin_nontemporal_store(q2w, q2w_ptr); | ||||
|       if (threadIdx.x == group_leader) { | ||||
|         __builtin_nontemporal_store(decoding_scale, qs_ptr); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, | ||||
|                                           int32x4_t* __restrict__ data) { | ||||
|     for (int k = 0; k < kRankAtoms; k++) { | ||||
|       // Directly read quantized atom from recv_buffer | ||||
|       uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer); | ||||
|       uint32_t* q4w_ptr = reinterpret_cast<uint32_t*>(atom_ptr) + thread; | ||||
|       uint16_t* q2w_ptr = | ||||
|           reinterpret_cast<uint16_t*>(atom_ptr + kRankTileQ2Offset) + thread; | ||||
|       int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + | ||||
|                     (thread / 8); | ||||
|  | ||||
|       uint32_t q4w = __builtin_nontemporal_load(q4w_ptr); | ||||
|       uint16_t q2w = __builtin_nontemporal_load(q2w_ptr); | ||||
|       int qs = __builtin_nontemporal_load(qs_ptr); | ||||
|  | ||||
|       *recv_buffer += kRankBufferTileStride; | ||||
|  | ||||
|       // Unpack q6 into fp16x8_t | ||||
|       int32x4_t w; | ||||
|       { | ||||
|         static uint constexpr kMask000F = 0x000F000F; | ||||
|         static uint constexpr kHalf2_1024 = | ||||
|             0x64006400;  // {1024.0, 1024.0}, fp16x2_t | ||||
|         static uint constexpr kHalf2_1056 = | ||||
|             0xE420E420;  // {-1056.0, -1056.0}, fp16x2_t | ||||
|  | ||||
| #pragma unroll | ||||
|         for (int i = 0; i < 4; i++) { | ||||
|           int32_t q4 = q4w & kMask000F; | ||||
|           int32_t q2 = (q2w & 0x3) | ((q2w & 0xC) << 14); | ||||
|           q4w >>= 4; | ||||
|           q2w >>= 4; | ||||
|           if constexpr (std::is_same<T, half>::value) { | ||||
|             int32_t q6 = q4 | (q2 << 4) | kHalf2_1024; | ||||
|             asm volatile("v_pk_add_f16 %0, %1, %2" | ||||
|                          : "=v"(w[i]) | ||||
|                          : "v"(q6), "v"(kHalf2_1056)); | ||||
|           } else { | ||||
|             int32_t int16_2 = q4 | (q2 << 4); | ||||
|             int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF); | ||||
|             int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF); | ||||
|  | ||||
|             nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low)); | ||||
|             nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high)); | ||||
|             nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); | ||||
|             int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2); | ||||
|             w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin); | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|  | ||||
|       // Apply decoding scales | ||||
|       for (int i = 0; i < 4; i++) { | ||||
|         w[i] = packed_mul<T>(w[i], qs); | ||||
|       } | ||||
|  | ||||
|       // That's pretty much it... | ||||
|       data[k] = w; | ||||
|     } | ||||
|   } | ||||
| }; | ||||
|  | ||||
| // Int8 symmetric quantization codec. | ||||
| // We quantize the FP16 data to block-scaled Int8 in blocks of 4 * | ||||
| // kThreadGroupSize. | ||||
| template <typename T, int world_size> | ||||
| struct CodecQ8 : public CodecBase { | ||||
|   static constexpr int kWorldSize = world_size; | ||||
|  | ||||
|   // Codec tile size processed by this workgroup. | ||||
|   // Each thread quantizes a fragment of f16x8_t (16B) | ||||
|   // into an int8x8_t (8B) and an f16 scale shared among 32 values. | ||||
|   static constexpr int kRankAtoms = kAtoms / kWorldSize; | ||||
|   static constexpr int kRankTileStride = 2176; | ||||
|   static constexpr int kRankTileScaleOffset = 2048; | ||||
|   static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms; | ||||
|   static_assert(kRankTransmittedTileSize % 16 == 0, | ||||
|                 "kRankTransmittedTileSize must be 16B aligned."); | ||||
|  | ||||
|   static constexpr int kRankBufferTileStride = | ||||
|       kRankTileStride / sizeof(int32x4_t); | ||||
|  | ||||
|   // Total tile size for the collective communication. | ||||
|   static constexpr int kTransmittedTileSize = | ||||
|       kRankTransmittedTileSize * kWorldSize; | ||||
|  | ||||
|   // Constants configuration | ||||
|  | ||||
|   // {-1/128.0h, -1/128.0h}, f16x2_t | ||||
|   static constexpr int kScaleFactor = | ||||
|       std::is_same<T, half>::value ? 0xA000A000 : 0xBC00BC00; | ||||
|  | ||||
|   // {1e-7, 1e-7}, f16x2_t | ||||
|   static constexpr int kScaleEpsilon = | ||||
|       std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7; | ||||
|  | ||||
|   // {-128, -128}, f16x2_t | ||||
|   static constexpr int kRangeMin = | ||||
|       std::is_same<T, half>::value ? 0xD800D800 : 0xC300C300; | ||||
|   // {+127, +127}, f16x2_t | ||||
|   static constexpr int kRangeMax = | ||||
|       std::is_same<T, half>::value ? 0x57F057F0 : 0x42FE42FE; | ||||
|  | ||||
|   // {+128, +128}, int16x2_t | ||||
|   static constexpr int kRangeBias = 0x00800080; | ||||
|  | ||||
|   __quickreduce_device_inline__ CodecQ8(int thread, int rank) | ||||
|       : CodecBase(thread, rank) {} | ||||
|  | ||||
|   __quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer, | ||||
|                                           int32x4_t const* __restrict__ data) { | ||||
|     for (int k = 0; k < kRankAtoms; k++) { | ||||
|       int32x4_t const atom = data[k]; | ||||
|       // Compute the absolute maximum of the atom in the thread group | ||||
|       // In 2 blocks of values, upper/lower halves of the f16x2_t | ||||
|       int wblockmax = group_abs_max<T>(atom); | ||||
|  | ||||
|       // Derive scales | ||||
|       int decoding_scale; | ||||
|       int encoding_scale; | ||||
|       decoding_scale = packed_mul<T>(wblockmax, kScaleFactor); | ||||
|       encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon); | ||||
|       encoding_scale = packed_rcp<T>(encoding_scale); | ||||
|  | ||||
|       // Apply scales to get quantized values | ||||
|       int32x4_t w; | ||||
|       for (int i = 0; i < 4; i++) { | ||||
|         w[i] = packed_mul<T>(atom[i], encoding_scale); | ||||
|         w[i] = packed_max<T>(w[i], kRangeMin); | ||||
|         w[i] = packed_min<T>(w[i], kRangeMax); | ||||
|       } | ||||
|  | ||||
|       // Convert from f16x2_t to uint16x2_t | ||||
|       int32x4_t q; | ||||
|       { | ||||
|         int16_t* qi = reinterpret_cast<int16_t*>(&q); | ||||
|         T* wh = reinterpret_cast<T*>(&w); | ||||
|         for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i])); | ||||
|  | ||||
|         for (int i = 0; i < 4; i++) { | ||||
|           q[i] = packed_add<int16_t>(q[i], kRangeBias); | ||||
|         } | ||||
|       } | ||||
|  | ||||
|       // Pack 8 x q8 into int32x2_t | ||||
|       int32x2_t qw; | ||||
|       qw[0] = q[0] | (q[1] << 8); | ||||
|       qw[1] = q[2] | (q[3] << 8); | ||||
|  | ||||
|       // Write quantized atom to send_buffer | ||||
|       // note: only the group leader stores the scale | ||||
|       uint8_t* atom_ptr = | ||||
|           reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride); | ||||
|       int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread; | ||||
|       int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + | ||||
|                     (thread / 8); | ||||
|  | ||||
|       __builtin_nontemporal_store(qw, qw_ptr); | ||||
|       if (threadIdx.x == group_leader) { | ||||
|         __builtin_nontemporal_store(decoding_scale, qs_ptr); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   __quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer, | ||||
|                                           int32x4_t* __restrict__ data) { | ||||
|     for (int k = 0; k < kRankAtoms; k++) { | ||||
|       // Directly read quantized atom from recv_buffer | ||||
|       uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer); | ||||
|       int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread; | ||||
|       int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) + | ||||
|                     (thread / 8); | ||||
|  | ||||
|       int32x2_t qw = __builtin_nontemporal_load(qw_ptr); | ||||
|       int qs = __builtin_nontemporal_load(qs_ptr); | ||||
|  | ||||
|       *recv_buffer += kRankBufferTileStride; | ||||
|  | ||||
|       // Unpack q8 into fp16x8_t | ||||
|       int32x4_t w; | ||||
|       { | ||||
|         static uint constexpr kMask00FF = 0x00FF00FF; | ||||
|  | ||||
|         // {1024.0, 1024.0}, fp16x2_t | ||||
|         static uint constexpr kHalf2_1024 = 0x64006400; | ||||
|  | ||||
|         // {-1152.0, -1152.0}, fp16x2_t | ||||
|         static uint constexpr kHalf2_1152 = 0xE480E480; | ||||
|  | ||||
| #pragma unroll | ||||
|         for (int i = 0; i < 4; i++) { | ||||
|           if constexpr (std::is_same<T, half>::value) { | ||||
|             int32_t q8 = | ||||
|                 ((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024; | ||||
|             w[i] = packed_add<half>(q8, kHalf2_1152); | ||||
|           } else { | ||||
|             int32_t int16_2 = (qw[i / 2] >> ((i % 2) * 8)) & kMask00FF; | ||||
|             int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF); | ||||
|             int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF); | ||||
|             nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low)); | ||||
|             nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high)); | ||||
|             nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high); | ||||
|             int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2); | ||||
|             w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin); | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|  | ||||
|       // Apply decoding scales | ||||
|       for (int i = 0; i < 4; i++) { | ||||
|         w[i] = packed_mul<T>(w[i], qs); | ||||
|       } | ||||
|  | ||||
|       data[k] = w; | ||||
|     } | ||||
|   } | ||||
| }; | ||||
|  | ||||
| // Twoshot All Reduce | ||||
| template <typename T, class Codec, bool cast_bf2half> | ||||
| struct AllReduceTwoshot { | ||||
|   static_assert(sizeof(T) == 2); | ||||
|  | ||||
|   static constexpr int kWorldSize = Codec::kWorldSize; | ||||
|  | ||||
|   __device__ static void run( | ||||
|       T const* __restrict__ input, T* __restrict__ output, | ||||
|       uint32_t const N,                    // number of elements | ||||
|       int const block,                     // block index | ||||
|       int const rank,                      // rank index | ||||
|       uint8_t** __restrict__ buffer_list,  // communication buffers | ||||
|       uint32_t const data_offset,          // offset to start of the data buffer | ||||
|       uint32_t flag_color) { | ||||
|     // Topology | ||||
|     int thread = threadIdx.x + threadIdx.y * kWavefront; | ||||
|     uint8_t* rank_buffer = buffer_list[rank]; | ||||
|     Codec codec(thread, rank); | ||||
|     int block_id = blockIdx.x; | ||||
|     int grid_size = gridDim.x; | ||||
|     // -------------------------------------------------------- | ||||
|     // Read input into registers | ||||
|     int32x4_t tA[kAtoms]; | ||||
|  | ||||
|     BufferResource src_buffer(const_cast<T*>(input), N * sizeof(T)); | ||||
|     uint32_t src_offset = block * kTileSize + thread * sizeof(int32x4_t); | ||||
|  | ||||
|     for (int i = 0; i < kAtoms; i++) { | ||||
|       tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0); | ||||
|       src_offset += kAtomStride * sizeof(int32x4_t); | ||||
|       if constexpr (cast_bf2half) { | ||||
|         const nv_bfloat162* bf_buf = | ||||
|             reinterpret_cast<const nv_bfloat162*>(&tA[i]); | ||||
|         half2 half_buf[4]; | ||||
| #pragma unroll | ||||
|         for (int j = 0; j < 4; ++j) { | ||||
|           float2 f = __bfloat1622float2(bf_buf[j]); | ||||
|           half_buf[j] = __float22half2_rn(f); | ||||
|         } | ||||
|         tA[i] = *reinterpret_cast<const int32x4_t*>(half_buf); | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     // -------------------------------------------------------- | ||||
|     // Phase-1A: Write segment data into the communication buffer of the target | ||||
|     // rank responsible for this segment. | ||||
|     uint32_t comm_data0_offset = | ||||
|         data_offset + block_id * Codec::kTransmittedTileSize; | ||||
|     uint32_t comm_data1_offset = | ||||
|         grid_size * Codec::kTransmittedTileSize + comm_data0_offset; | ||||
|  | ||||
|     uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t)); | ||||
|     uint32_t comm_flags1_offset = | ||||
|         grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset; | ||||
|  | ||||
|     for (int r = 0; r < kWorldSize; r++) { | ||||
|       int32x4_t* send_buffer = | ||||
|           reinterpret_cast<int32x4_t*>(buffer_list[r] + comm_data0_offset + | ||||
|                                        rank * Codec::kRankTransmittedTileSize); | ||||
|       codec.send(send_buffer, &tA[r * Codec::kRankAtoms]); | ||||
|     } | ||||
|  | ||||
|     __syncthreads(); | ||||
|     if (thread < kWorldSize) { | ||||
|       int r = thread; | ||||
|       uint32_t* flag_ptr = reinterpret_cast<uint32_t*>( | ||||
|           buffer_list[r] + comm_flags0_offset + rank * sizeof(uint32_t)); | ||||
|       set_sync_flag(flag_ptr, flag_color); | ||||
|     } | ||||
|     // -------------------------------------------------------- | ||||
|     // Phase-1B: Reduce the segment data from the communication buffers. | ||||
|     int32x4_t tR[Codec::kRankAtoms] = {}; | ||||
|     { | ||||
|       // Read the data from the communication buffer. | ||||
|       int32x4_t* recv_buffer = | ||||
|           reinterpret_cast<int32x4_t*>(rank_buffer + comm_data0_offset); | ||||
|       uint32_t* flag_ptr = | ||||
|           reinterpret_cast<uint32_t*>(rank_buffer + comm_flags0_offset); | ||||
|  | ||||
|       for (int r = 0; r < kWorldSize; r++) { | ||||
|         // Wait for the flags to be set. | ||||
|         if (thread == 0) { | ||||
|           wait_sync_flag(&flag_ptr[r], flag_color); | ||||
|         } | ||||
|         __syncthreads(); | ||||

        // note: we reuse tA as temp buffer here
        codec.recv(&recv_buffer, tA);

        for (int i = 0; i < Codec::kRankAtoms; i++) {
          packed_assign_add<T>(&tR[i], &tA[i]);
        }
      }
    }

    // Phase-2: Write the reduced segment to every other rank
    for (int r = 0; r < kWorldSize; r++) {
      int32x4_t* send_buffer =
          reinterpret_cast<int32x4_t*>(buffer_list[r] + comm_data1_offset +
                                       rank * Codec::kRankTransmittedTileSize);
      codec.send(send_buffer, tR);
    }

    __syncthreads();
    if (thread < kWorldSize) {
      int r = thread;
      uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(
          buffer_list[r] + comm_flags1_offset + rank * sizeof(uint32_t));
      set_sync_flag(flag_ptr, flag_color);
    }

    // Phase-2: Read the gather segments from the rank's communication buffer.
    {
      // Read the data from the communication buffer.
      int32x4_t* recv_buffer =
          reinterpret_cast<int32x4_t*>(rank_buffer + comm_data1_offset);
      uint32_t* flag_ptr =
          reinterpret_cast<uint32_t*>(rank_buffer + comm_flags1_offset);

      for (int r = 0; r < kWorldSize; r++) {
        // Wait for the flags to be set.
        if (thread == 0) {
          wait_sync_flag(&flag_ptr[r], flag_color);
        }
        __syncthreads();

        // Gather all reduced and final rank segments into tA.
        codec.recv(&recv_buffer, &tA[r * Codec::kRankAtoms]);
      }
    }

    // --------------------------------------------------------
    // Write the result to output.
    BufferResource dst_buffer(output, N * sizeof(T));
    uint32_t dst_offset = block * kTileSize + thread * sizeof(int32x4_t);

    for (int i = 0; i < kAtoms; i++) {
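      // If the tile was repacked to fp16 on load, convert it back to bf16
      // before storing the final result.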
      if constexpr (cast_bf2half) {
        const half2* half_buf = reinterpret_cast<const half2*>(&tA[i]);
        nv_bfloat162 bf16_buf[4];
#pragma unroll
        for (int j = 0; j < 4; ++j) {
          float2 f = __half22float2(half_buf[j]);
          bf16_buf[j] = __float22bfloat162_rn(f);
        }
        buffer_store_dwordx4(*reinterpret_cast<const int32x4_t*>(bf16_buf),
                             dst_buffer.descriptor, dst_offset, 0, 0);
      } else {
        buffer_store_dwordx4(tA[i], dst_buffer.descriptor, dst_offset, 0, 0);
      }
      dst_offset += kAtomStride * sizeof(int32x4_t);
    }
  }
};

}  // namespace quickreduce
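For orientation, the kernel above follows a two-phase pattern: in phase 1 each rank reduces only the segment it owns (a reduce-scatter), and in phase 2 the reduced segments are sent back to every rank (an all-gather). The minimal host-side C++ below sketches that data movement only; all names in it are local to the sketch and are not vLLM/quickreduce APIs.

// Host-side sketch of a two-phase all-reduce (reduce-scatter + all-gather).
// Illustration only; not part of the quickreduce kernel above.
#include <cstdio>
#include <vector>

int main() {
  constexpr int kWorldSize = 4;  // number of ranks (hypothetical)
  constexpr int kSegment = 2;    // elements owned per rank (hypothetical)
  constexpr int N = kWorldSize * kSegment;

  // Each rank starts with its own full-length input; rank r contributes r+1.
  std::vector<std::vector<float>> input(kWorldSize, std::vector<float>(N));
  for (int r = 0; r < kWorldSize; ++r)
    for (int i = 0; i < N; ++i) input[r][i] = float(r + 1);

  // Phase 1 (reduce-scatter): rank r sums segment r across all ranks.
  std::vector<std::vector<float>> reduced(kWorldSize,
                                          std::vector<float>(kSegment, 0.f));
  for (int r = 0; r < kWorldSize; ++r)
    for (int src = 0; src < kWorldSize; ++src)
      for (int i = 0; i < kSegment; ++i)
        reduced[r][i] += input[src][r * kSegment + i];

  // Phase 2 (all-gather): every rank collects all reduced segments.
  std::vector<std::vector<float>> output(kWorldSize, std::vector<float>(N));
  for (int dst = 0; dst < kWorldSize; ++dst)
    for (int r = 0; r < kWorldSize; ++r)
      for (int i = 0; i < kSegment; ++i)
        output[dst][r * kSegment + i] = reduced[r][i];

  // Every rank now holds the fully reduced vector: 1 + 2 + 3 + 4 = 10.
  printf("rank 0, element 0: %.1f\n", output[0][0]);
  return 0;
}

In the kernel, the per-rank flag writes and waits (set_sync_flag / wait_sync_flag) stand in for the implicit ordering this sequential host loop gets for free.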
@@ -136,11 +136,6 @@ __device__ __forceinline__ T from_float(const float& inp) {

template <typename T>
__device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) {
  [[maybe_unused]] union tmpcvt {
    uint16_t u;
    _Float16 f;
    __hip_bfloat16 b;
  } t16;
  _B16x4 ret;
  if constexpr (std::is_same<T, _Float16>::value) {
    union h2cvt {
@@ -169,11 +164,6 @@ __device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) {
template <typename T>
__device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1,
                                        const _B16x4& inp2) {
  [[maybe_unused]] union tmpcvt {
    uint16_t u;
    _Float16 f;
    __hip_bfloat16 b;
  } t1, t2, res;
  _B16x4 ret;
  if constexpr (std::is_same<T, _Float16>::value) {
    union h2cvt {
@@ -325,8 +315,6 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(

  constexpr int GQA_RATIO4 = DIVIDE_ROUND_UP(GQA_RATIO, 4);

  [[maybe_unused]] __shared__ float shared_qk_max[NWARPS][16 + 1];
  [[maybe_unused]] __shared__ float shared_exp_sum[NWARPS][16 + 1];
  // shared_logits is used for multiple purposes
  __shared__ _B16x4 shared_logits[NWARPS][4][16][4];

@@ -444,8 +432,6 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
    const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride;
    const int klocal_token_idx =
        TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
    [[maybe_unused]] const int kglobal_token_idx =
        partition_start_token_idx + klocal_token_idx;
    const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE;
    const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX;

@@ -1309,9 +1295,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(

  const int context_len = context_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
  [[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const auto warpid = threadIdx.x / WARP_SIZE;
  [[maybe_unused]] const auto laneid = threadIdx.x % WARP_SIZE;

  __shared__ float shared_global_exp_sum;
  // max num partitions supported is warp_size * NPAR_LOOPS
@@ -1614,7 +1598,6 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
  const int warpid = threadIdx.x / WARP_SIZE;
  const int laneid = threadIdx.x % WARP_SIZE;
  const int lane2id = laneid % 2;
  const int lane4id = laneid % 4;
  const int lane16id = laneid % 16;
  const int rowid = laneid / 16;

@@ -1761,7 +1744,6 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
    const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride;
    const int klocal_token_idx =
        TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
    const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
    const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE;
    const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX;

@@ -2080,9 +2062,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(

  const int context_len = context_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
  [[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int warpid = threadIdx.x / WARP_SIZE;
  [[maybe_unused]] const int laneid = threadIdx.x % WARP_SIZE;

  __shared__ float shared_global_exp_sum;
  // max num partitions supported is warp_size * NPAR_LOOPS
@@ -2386,7 +2366,6 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
  const int warpid = threadIdx.x / WARP_SIZE;
  const int laneid = threadIdx.x % WARP_SIZE;
  const int lane2id = laneid % 2;
  const int lane4id = laneid % 4;
  const int lane16id = laneid % 16;
  const int rowid = laneid / 16;

@@ -2532,7 +2511,6 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
    const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride;
    const int klocal_token_idx =
        TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
    const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
    const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE;
    const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX;

@@ -2816,9 +2794,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(

  const int context_len = context_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
  [[maybe_unused]] constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int warpid = threadIdx.x / WARP_SIZE;
  [[maybe_unused]] const int laneid = threadIdx.x % WARP_SIZE;

  __shared__ float shared_global_exp_sum;
  // max num partitions supported is warp_size * NPAR_LOOPS