mirror of
				https://github.com/vllm-project/vllm.git
				synced 2025-10-20 23:03:52 +08:00 
			
		
		
		
	Compare commits
	
		
			288 Commits
		
	
	
		
			v0.10.1.1
			...
			khluu/test
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 91e382c935 | |||
| 6446677839 | |||
| 69244e67e6 | |||
| 8dbf6ed7be | |||
| 9de25c294b | |||
| fce10dbed5 | |||
| d272415e57 | |||
| 142ac08030 | |||
| 3210264421 | |||
| 644d57d531 | |||
| c905684cfe | |||
| 786835807b | |||
| fecbb7c782 | |||
| 6dab89b8ec | |||
| de02b07db4 | |||
| eb1995167e | |||
| 2c2b140ae8 | |||
| c7c80af084 | |||
| 6891205b16 | |||
| b1625dbe9c | |||
| 585e0bde36 | |||
| 714872f1a9 | |||
| 5f1af97f86 | |||
| c3b0fd1ee6 | |||
| 6421b66bf4 | |||
| 2f13319f47 | |||
| d696f86e7b | |||
| 9816b81f5f | |||
| c37c0af990 | |||
| 9715f7bb0f | |||
| 98aa16ff41 | |||
| 227e231b55 | |||
| 730d0ac8b9 | |||
| 9b0187003e | |||
| 44ac25eae2 | |||
| 7ea22e42d5 | |||
| 9d4183dd2e | |||
| 513298f1b4 | |||
| 379f828fba | |||
| 1fdc732419 | |||
| f58675bfb3 | |||
| 7c04779afa | |||
| f66673a39d | |||
| b78bed1bc5 | |||
| 164b2273c8 | |||
| 2b4fc9bd9b | |||
| ebd5a77bb5 | |||
| 384dd1b0a8 | |||
| fdeb3dac13 | |||
| d52358c1e0 | |||
| 6ace2f72b0 | |||
| b00e69f8ca | |||
| 50fede6634 | |||
| b5d34af328 | |||
| 9b5f64238f | |||
| ff77764f86 | |||
| bfc1edc9f5 | |||
| 3ecbb14b81 | |||
| 7d67a9d9f9 | |||
| 959783fb99 | |||
| ce0e9dbd43 | |||
| b395b3b0a3 | |||
| 6fad29b11b | |||
| 6fd45e7b8a | |||
| 56dcf4e7e9 | |||
| ae067888d6 | |||
| 906e461ed6 | |||
| 2a97ffc33d | |||
| efc88cf64a | |||
| 7b6a837275 | |||
| c34c82b7fe | |||
| 8a044754bd | |||
| 9188ae7cb5 | |||
| 8a3cd90af5 | |||
| 2a167b2eeb | |||
| 0ff902f3b4 | |||
| a9082a4d14 | |||
| e0329ed4b4 | |||
| 6879cd80ae | |||
| e269be2ba2 | |||
| 5c4b6e66fe | |||
| d0a4a3f645 | |||
| ebafb0936d | |||
| 0cb7b065c3 | |||
| 2da02dd0d8 | |||
| d765cf01fe | |||
| 712d0f88d8 | |||
| 49ab23b3cc | |||
| c9abb10489 | |||
| 787cdb3829 | |||
| a5203d04df | |||
| 99f8094400 | |||
| 170e8ea9ea | |||
| a71e4765cc | |||
| 39971db3aa | |||
| 504d914314 | |||
| 47455c424f | |||
| c7fc6b1354 | |||
| ad78868450 | |||
| e2db1164a1 | |||
| 416f05929a | |||
| 5e021b4981 | |||
| 1b9b16649c | |||
| e76e233540 | |||
| a75277285b | |||
| 9dc30b7068 | |||
| 053278a5dc | |||
| c55c028998 | |||
| 65197a5fb3 | |||
| b8f17f5d98 | |||
| d9a55204ba | |||
| b4e9fd811f | |||
| 308fa287a8 | |||
| fa78de9dc3 | |||
| f6818a92cb | |||
| 23c939fd30 | |||
| add1adfec7 | |||
| c80c53a30f | |||
| 24d0c9e6ed | |||
| cc7ae5e7ca | |||
| 0313cf854d | |||
| 0483fabc74 | |||
| da65bec309 | |||
| 4645024d3a | |||
| cd7a3df26f | |||
| 32d2b4064f | |||
| 22cf679aad | |||
| b6d7d34fc6 | |||
| 341923b982 | |||
| 424fb7a5d2 | |||
| 88491c1b6b | |||
| 613a23b57f | |||
| 51a215300b | |||
| ebe14621e3 | |||
| 325aa3dee9 | |||
| a073be6d87 | |||
| 695e7adcd2 | |||
| 281710ef9a | |||
| 808d2e9aa0 | |||
| 285178b3b8 | |||
| 88016c372a | |||
| 998720859c | |||
| 0ba1b54ac6 | |||
| 53415653ff | |||
| 17373dcd93 | |||
| 5964069367 | |||
| de9c085e17 | |||
| 111692bb8c | |||
| 394591e343 | |||
| 3ac849665d | |||
| 0b9cc56fac | |||
| 8896eb72eb | |||
| 19fe1a0510 | |||
| 480bdf5a7b | |||
| 5368f76855 | |||
| 8ef6b8a38c | |||
| 3bbe11cc13 | |||
| c5041f899f | |||
| 8b5fe6eb51 | |||
| 800349c2a5 | |||
| 044931f97b | |||
| 1d353b6352 | |||
| 3496274663 | |||
| 8a19303173 | |||
| 603fbbbce0 | |||
| 10f535c086 | |||
| 48bfb0c9b7 | |||
| f8ce022948 | |||
| 0278f1ac3a | |||
| a482e4e769 | |||
| e0b056e443 | |||
| 79f05e4436 | |||
| f8daddcc4c | |||
| c8e33c72c6 | |||
| d70a16625d | |||
| 5cc54f7c5b | |||
| 0c6e40bbaa | |||
| 2e2000f352 | |||
| 31282401b6 | |||
| 0c31e28e95 | |||
| f571ff8eb6 | |||
| f64ee61d9e | |||
| 8993073dc1 | |||
| 655a09f653 | |||
| f94bf9b924 | |||
| 3663870c72 | |||
| 2461d9e562 | |||
| 7be5d113d8 | |||
| b029de9902 | |||
| bbea1cefdd | |||
| f5aa307d77 | |||
| 4b795020ed | |||
| c86af22f31 | |||
| 10cc12ba66 | |||
| a4fbb32fab | |||
| 1b125004be | |||
| 4fbda0b20c | |||
| 4e51fa8cba | |||
| bf7c99dfc4 | |||
| b95697d731 | |||
| 582bbe6bd7 | |||
| 0cdbf5e61c | |||
| ebe56a0064 | |||
| f77a0802b7 | |||
| c4477f55e5 | |||
| dfd2382039 | |||
| 3b11b26b50 | |||
| d6d13bd49e | |||
| 5efd6905bc | |||
| b17109beea | |||
| 4449235843 | |||
| 38217877aa | |||
| c6d80a7a96 | |||
| 7cd17e22d7 | |||
| 50df09fe13 | |||
| 68fcd3fa73 | |||
| 83e69a09d6 | |||
| 3aa8c10038 | |||
| 103f1ec8d3 | |||
| d983769c41 | |||
| 8fd920924c | |||
| de7b67a023 | |||
| f729023272 | |||
| 1a3079a15e | |||
| 941f56858a | |||
| a634733f67 | |||
| 64ab3c7253 | |||
| e58c5a9768 | |||
| d46d417b58 | |||
| 0167efe20d | |||
| c32e6ad1f6 | |||
| 1630cc8d0f | |||
| 14e2b0730b | |||
| 0f4f0191d8 | |||
| a38b8af4c3 | |||
| 21dce80ea9 | |||
| e61bac87ee | |||
| 80141bbf2f | |||
| b94faf9d50 | |||
| 5b5f350d67 | |||
| f7cf5b512e | |||
| 03d4235fd2 | |||
| d6a1a20973 | |||
| a70d0bd0a3 | |||
| 24f4d1a224 | |||
| 4f510bc2a1 | |||
| 1298c67795 | |||
| 4d9c61993a | |||
| b87cb97a53 | |||
| f856c33ce9 | |||
| 03752dba8f | |||
| 40f26734b9 | |||
| 2c3f557f08 | |||
| 21bcc8263f | |||
| 5bfe0dea7a | |||
| 31fd3265c8 | |||
| 31436e8b4f | |||
| 4efd43e9b4 | |||
| 3c8a787247 | |||
| 01a08739e0 | |||
| fda9537c5e | |||
| 90bbe0a5ad | |||
| e75f342261 | |||
| 78dba404ad | |||
| e9d6a3db69 | |||
| a4454e9401 | |||
| 14006840ea | |||
| 6603288736 | |||
| 95e3095136 | |||
| c9b38be8aa | |||
| 0dd3f4f5ab | |||
| 498259ccce | |||
| 6d25e3fd6e | |||
| ac6eb49de3 | |||
| bf756321c7 | |||
| 0e3bb543f0 | |||
| 569aefd134 | |||
| d3f71f1224 | |||
| 5a30bd10d8 | |||
| 27e8d1ea3e | |||
| 5c79b0d648 | |||
| 5f5664b3e4 | |||
| 89657a557c | |||
| 08d5f7113a | |||
| b2fd0b81e0 | |||
| 9f1c642254 | |||
| 7be3a59d8e | |||
| 8ea0c2753a | 
| @ -8,7 +8,8 @@ template = """<!DOCTYPE html> | ||||
| <html> | ||||
|     <body> | ||||
|     <h1>Links for vLLM</h1/> | ||||
|         <a href="../{wheel_html_escaped}">{wheel}</a><br/> | ||||
|         <a href="../{x86_wheel_html_escaped}">{x86_wheel}</a><br/> | ||||
|         <a href="../{arm_wheel_html_escaped}">{arm_wheel}</a><br/> | ||||
|     </body> | ||||
| </html> | ||||
| """ | ||||
| @ -21,7 +22,25 @@ filename = os.path.basename(args.wheel) | ||||
|  | ||||
| with open("index.html", "w") as f: | ||||
|     print(f"Generated index.html for {args.wheel}") | ||||
|     # sync the abi tag with .buildkite/scripts/upload-wheels.sh | ||||
|     if "x86_64" in filename: | ||||
|         x86_wheel = filename | ||||
|         arm_wheel = filename.replace("x86_64", "aarch64").replace( | ||||
|             "manylinux1", "manylinux2014" | ||||
|         ) | ||||
|     elif "aarch64" in filename: | ||||
|         x86_wheel = filename.replace("aarch64", "x86_64").replace( | ||||
|             "manylinux2014", "manylinux1" | ||||
|         ) | ||||
|         arm_wheel = filename | ||||
|     else: | ||||
|         raise ValueError(f"Unsupported wheel: {filename}") | ||||
|     # cloudfront requires escaping the '+' character | ||||
|     f.write( | ||||
|         template.format(wheel=filename, wheel_html_escaped=filename.replace("+", "%2B")) | ||||
|         template.format( | ||||
|             x86_wheel=x86_wheel, | ||||
|             x86_wheel_html_escaped=x86_wheel.replace("+", "%2B"), | ||||
|             arm_wheel=arm_wheel, | ||||
|             arm_wheel_html_escaped=arm_wheel.replace("+", "%2B"), | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
| @ -1,12 +0,0 @@ | ||||
| # For vllm script, with -t option (tensor parallel size). | ||||
| # bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1 | ||||
| model_name: "HandH1998/QQQ-Llama-3-8b-g128" | ||||
| tasks: | ||||
| - name: "gsm8k" | ||||
|   metrics: | ||||
|   - name: "exact_match,strict-match" | ||||
|     value: 0.419 | ||||
|   - name: "exact_match,flexible-extract" | ||||
|     value: 0.416 | ||||
| limit: 1000 | ||||
| num_fewshot: 5 | ||||
| @ -3,4 +3,3 @@ Meta-Llama-3-70B-Instruct.yaml | ||||
| Mixtral-8x7B-Instruct-v0.1.yaml | ||||
| Qwen2-57B-A14-Instruct.yaml | ||||
| DeepSeek-V2-Lite-Chat.yaml | ||||
| Meta-Llama-3-8B-QQQ.yaml | ||||
|  | ||||
| @ -2,7 +2,7 @@ | ||||
| # We can use this script to compute baseline accuracy on GSM for transformers. | ||||
| # | ||||
| # Make sure you have lm-eval-harness installed: | ||||
| #   pip install lm-eval==0.4.4 | ||||
| #   pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] | ||||
|  | ||||
| usage() { | ||||
|     echo`` | ||||
|  | ||||
| @ -3,7 +3,7 @@ | ||||
| # We use this for fp8, which HF does not support. | ||||
| # | ||||
| # Make sure you have lm-eval-harness installed: | ||||
| #   pip install lm-eval==0.4.4 | ||||
| #   pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] | ||||
|  | ||||
| usage() { | ||||
|     echo`` | ||||
|  | ||||
| @ -141,7 +141,7 @@ When run, benchmark script generates results under `benchmark/results` folder, a | ||||
| `compare-json-results.py` compares two `benchmark_results.json` files and provides performance ratio e.g. for Output Tput, Median TTFT and Median TPOT.   | ||||
| If only one benchmark_results.json is passed, `compare-json-results.py` compares different TP and PP configurations in the benchmark_results.json instead. | ||||
|  | ||||
| Here is an example using the script to compare result_a and result_b with Model, Dataset name, input/output lenght, max concurrency and qps. | ||||
| Here is an example using the script to compare result_a and result_b with Model, Dataset name, input/output length, max concurrency and qps. | ||||
| `python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json` | ||||
|  | ||||
| |   | Model | Dataset Name | Input Len | Output Len | # of max concurrency | qps  | results_a/benchmark_results.json | results_b/benchmark_results.json | perf_ratio        | | ||||
|  | ||||
| @ -17,7 +17,7 @@ Latest reproduction guilde: [github issue link](https://github.com/vllm-project/ | ||||
|     - SGLang: `lmsysorg/sglang:v0.3.2-cu121` | ||||
|     - LMDeploy: `openmmlab/lmdeploy:v0.6.1-cu12` | ||||
|     - TensorRT-LLM: `nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3` | ||||
|         - *NOTE: we uses r24.07 as the current implementation only works for this version. We are going to bump this up.* | ||||
|         - *NOTE: we use r24.07 as the current implementation only works for this version. We are going to bump this up.* | ||||
|     - Check [nightly-pipeline.yaml](nightly-pipeline.yaml) for the concrete docker images, specs and commands we use for the benchmark. | ||||
| - Hardware | ||||
|     - 8x Nvidia A100 GPUs | ||||
|  | ||||
| @ -3,44 +3,129 @@ | ||||
| import argparse | ||||
| import json | ||||
| import os | ||||
| from importlib import util | ||||
|  | ||||
| import pandas as pd | ||||
|  | ||||
| plotly_found = util.find_spec("plotly.express") is not None | ||||
|  | ||||
|  | ||||
| def compare_data_columns( | ||||
|     files, name_column, data_column, info_cols, drop_column, debug=False | ||||
| ): | ||||
|     print("\ncompare_data_column: " + data_column) | ||||
|     """ | ||||
|     Align concatenation by keys derived from info_cols instead of row order. | ||||
|     - Pick one canonical key list: subset of info_cols present in ALL files. | ||||
|     - For each file: set index to those keys, aggregate duplicates | ||||
|     - (mean for metric, first for names). | ||||
|     - Concat along axis=1 (indexes align), then reset_index so callers can | ||||
|     - group by columns. | ||||
|     - If --debug, add a <file_label>_name column per file. | ||||
|     """ | ||||
|     print("\ncompare_data_column:", data_column) | ||||
|  | ||||
|     frames = [] | ||||
|     raw_data_cols = [] | ||||
|     compare_frames = [] | ||||
|  | ||||
|     # 1) choose a canonical key list from info_cols that exists in ALL files | ||||
|     cols_per_file = [] | ||||
|     for f in files: | ||||
|         try: | ||||
|             df_tmp = pd.read_json(f, orient="records") | ||||
|         except Exception as err: | ||||
|             raise ValueError(f"Failed to read {f}") from err | ||||
|         cols_per_file.append(set(df_tmp.columns)) | ||||
|  | ||||
|     key_cols = [c for c in info_cols if all(c in cset for cset in cols_per_file)] | ||||
|     if not key_cols: | ||||
|         # soft fallback: use any info_cols present in the first file | ||||
|         key_cols = [c for c in info_cols if c in list(cols_per_file[0])] | ||||
|     if not key_cols: | ||||
|         raise ValueError( | ||||
|             "No common key columns found from info_cols across the input files." | ||||
|         ) | ||||
|  | ||||
|     # 2) build a single "meta" block (keys as columns) once, aligned by the key index | ||||
|     meta_added = False | ||||
|  | ||||
|     for file in files: | ||||
|         data_df = pd.read_json(file) | ||||
|         serving_df = data_df.dropna(subset=[drop_column], ignore_index=True) | ||||
|         # Show all info columns in the first couple columns | ||||
|         if not frames: | ||||
|             for col in info_cols: | ||||
|                 if col not in serving_df.columns: | ||||
|                     print(f"Skipping missing column: {col}") | ||||
|                     continue | ||||
|                 frames.append(serving_df[col]) | ||||
|         # only show test name under debug mode | ||||
|         if debug is True: | ||||
|             serving_df = serving_df.rename(columns={name_column: file + "_name"}) | ||||
|             frames.append(serving_df[file + "_name"]) | ||||
|         df = pd.read_json(file, orient="records") | ||||
|  | ||||
|         file = "/".join(file.split("/")[:-1]) | ||||
|         serving_df = serving_df.rename(columns={data_column: file}) | ||||
|         frames.append(serving_df[file]) | ||||
|         raw_data_cols.append(file) | ||||
|         compare_frames.append(serving_df[file]) | ||||
|         # Keep rows that actually have the compared metric (same as original behavior) | ||||
|         if drop_column in df.columns: | ||||
|             df = df.dropna(subset=[drop_column], ignore_index=True) | ||||
|  | ||||
|         # Stabilize numeric key columns (harmless if missing) | ||||
|         for c in ( | ||||
|             "Input Len", | ||||
|             "Output Len", | ||||
|             "TP Size", | ||||
|             "PP Size", | ||||
|             "# of max concurrency.", | ||||
|             "qps", | ||||
|         ): | ||||
|             if c in df.columns: | ||||
|                 df[c] = pd.to_numeric(df[c], errors="coerce") | ||||
|  | ||||
|         # Ensure all key columns exist | ||||
|         for c in key_cols: | ||||
|             if c not in df.columns: | ||||
|                 df[c] = pd.NA | ||||
|  | ||||
|         # Set index = key_cols and aggregate duplicates → unique MultiIndex | ||||
|         df_idx = df.set_index(key_cols, drop=False) | ||||
|  | ||||
|         # meta (key columns), unique per key | ||||
|         meta = df_idx[key_cols] | ||||
|         if not meta.index.is_unique: | ||||
|             meta = meta.groupby(level=key_cols, dropna=False).first() | ||||
|  | ||||
|         # metric series for this file, aggregated to one row per key | ||||
|         file_label = "/".join(file.split("/")[:-1]) or os.path.basename(file) | ||||
|         s = df_idx[data_column] | ||||
|         if not s.index.is_unique: | ||||
|             s = s.groupby(level=key_cols, dropna=False).mean() | ||||
|         s.name = file_label  # column label like original | ||||
|  | ||||
|         # add meta once (from first file) so keys are the leftmost columns | ||||
|         if not meta_added: | ||||
|             frames.append(meta) | ||||
|             meta_added = True | ||||
|  | ||||
|         # (NEW) debug: aligned test-name column per file | ||||
|         if debug and name_column in df_idx.columns: | ||||
|             name_s = df_idx[name_column] | ||||
|             if not name_s.index.is_unique: | ||||
|                 name_s = name_s.groupby(level=key_cols, dropna=False).first() | ||||
|             name_s.name = f"{file_label}_name" | ||||
|             frames.append(name_s) | ||||
|  | ||||
|         frames.append(s) | ||||
|         raw_data_cols.append(file_label) | ||||
|         compare_frames.append(s) | ||||
|  | ||||
|         # Generalize ratio: for any file N>=2, add ratio (fileN / file1) | ||||
|         if len(compare_frames) >= 2: | ||||
|             # Compare numbers among two files | ||||
|             ratio_df = compare_frames[1] / compare_frames[0] | ||||
|             frames.append(ratio_df) | ||||
|             compare_frames.pop(1) | ||||
|             base = compare_frames[0] | ||||
|             current = compare_frames[-1] | ||||
|             ratio = current / base | ||||
|             ratio = ratio.mask(base == 0)  # avoid inf when baseline is 0 | ||||
|             ratio.name = f"Ratio 1 vs {len(compare_frames)}" | ||||
|             frames.append(ratio) | ||||
|  | ||||
|     # 4) concat on columns with aligned MultiIndex; | ||||
|     # then reset_index to return keys as columns | ||||
|     concat_df = pd.concat(frames, axis=1) | ||||
|     concat_df = concat_df.reset_index(drop=True).reset_index() | ||||
|     if "index" in concat_df.columns: | ||||
|         concat_df = concat_df.drop(columns=["index"]) | ||||
|  | ||||
|     # Ensure key/info columns appear first (in your info_cols order) | ||||
|     front = [c for c in info_cols if c in concat_df.columns] | ||||
|     rest = [c for c in concat_df.columns if c not in front] | ||||
|     concat_df = concat_df[front + rest] | ||||
|  | ||||
|     print(raw_data_cols) | ||||
|     return concat_df, raw_data_cols | ||||
|  | ||||
| @ -67,6 +152,15 @@ def split_json_by_tp_pp( | ||||
|  | ||||
|     df = pd.DataFrame(data) | ||||
|  | ||||
|     # Keep only "serving" tests | ||||
|     name_col = next( | ||||
|         (c for c in ["Test name", "test_name", "Test Name"] if c in df.columns), None | ||||
|     ) | ||||
|     if name_col: | ||||
|         df = df[ | ||||
|             df[name_col].astype(str).str.contains(r"serving", case=False, na=False) | ||||
|         ].copy() | ||||
|  | ||||
|     # Handle alias column names | ||||
|     rename_map = { | ||||
|         "tp_size": "TP Size", | ||||
| @ -181,7 +275,6 @@ if __name__ == "__main__": | ||||
|                     f"Expected subset: {filtered_info_cols}, " | ||||
|                     f"but DataFrame has: {list(output_df.columns)}" | ||||
|                 ) | ||||
|  | ||||
|             output_df_sorted = output_df.sort_values(by=existing_group_cols) | ||||
|             output_groups = output_df_sorted.groupby(existing_group_cols, dropna=False) | ||||
|             for name, group in output_groups: | ||||
| @ -189,8 +282,7 @@ if __name__ == "__main__": | ||||
|                 text_file.write(html_msgs_for_data_cols[i]) | ||||
|                 text_file.write(html) | ||||
|  | ||||
|                 if plot is True: | ||||
|                     import pandas as pd | ||||
|                 if plot and plotly_found: | ||||
|                     import plotly.express as px | ||||
|  | ||||
|                     df = group[raw_data_cols] | ||||
|  | ||||
| @ -382,7 +382,7 @@ run_genai_perf_tests() { | ||||
|       client_command="genai-perf profile \ | ||||
|         -m $model \ | ||||
|         --service-kind openai \ | ||||
|         --backend vllm \ | ||||
|         --backend "$backend" \ | ||||
|         --endpoint-type chat \ | ||||
|         --streaming \ | ||||
|         --url localhost:$port \ | ||||
|  | ||||
| @ -27,7 +27,12 @@ steps: | ||||
|     env: | ||||
|       DOCKER_BUILDKIT: "1" | ||||
|  | ||||
|   - block: "Build CUDA 12.6 wheel" | ||||
|     key: block-build-cu126-wheel | ||||
|     depends_on: ~ | ||||
|  | ||||
|   - label: "Build wheel - CUDA 12.6" | ||||
|     depends_on: block-build-cu126-wheel | ||||
|     id: build-wheel-cuda-12-6 | ||||
|     agents: | ||||
|       queue: cpu_queue_postmerge | ||||
| @ -68,7 +73,7 @@ steps: | ||||
|       queue: cpu_queue_postmerge | ||||
|     commands: | ||||
|       - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" | ||||
|       - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ." | ||||
|       - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg FLASHINFER_AOT_COMPILE=true --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ." | ||||
|       - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" | ||||
|  | ||||
|   - label: "Annotate release workflow" | ||||
|  | ||||
| @ -46,6 +46,11 @@ function cpu_tests() { | ||||
|     set -e | ||||
|     python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m" | ||||
|  | ||||
|   # Run kernel tests | ||||
|   docker exec cpu-test-"$NUMA_NODE" bash -c " | ||||
|     set -e | ||||
|     pytest -v -s tests/kernels/test_onednn.py" | ||||
|  | ||||
|   # Run basic model test | ||||
|   docker exec cpu-test-"$NUMA_NODE" bash -c " | ||||
|     set -e | ||||
| @ -99,4 +104,4 @@ function cpu_tests() { | ||||
|  | ||||
| # All of CPU tests are expected to be finished less than 40 mins. | ||||
| export -f cpu_tests | ||||
| timeout 1.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" | ||||
| timeout 2h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" | ||||
|  | ||||
| @ -61,7 +61,7 @@ echo "Results will be stored in: $RESULTS_DIR" | ||||
| echo "--- Installing Python dependencies ---" | ||||
| python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \ | ||||
|     && python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \ | ||||
|     && python3 -m pip install --progress-bar off lm_eval[api]==0.4.4 \ | ||||
|     && python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \ | ||||
|     && python3 -m pip install --progress-bar off hf-transfer | ||||
| echo "--- Python dependencies installed ---" | ||||
| export VLLM_USE_V1=1 | ||||
|  | ||||
| @ -61,7 +61,7 @@ echo "Results will be stored in: $RESULTS_DIR" | ||||
| echo "--- Installing Python dependencies ---" | ||||
| python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \ | ||||
|     && python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \ | ||||
|     && python3 -m pip install --progress-bar off lm_eval[api]==0.4.4 \ | ||||
|     && python3 -m pip install --progress-bar off "lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d" \ | ||||
|     && python3 -m pip install --progress-bar off hf-transfer | ||||
| echo "--- Python dependencies installed ---" | ||||
| export VLLM_USE_V1=1 | ||||
|  | ||||
| @ -23,10 +23,15 @@ docker run \ | ||||
|     --device /dev/dri \ | ||||
|     -v /dev/dri/by-path:/dev/dri/by-path \ | ||||
|     --entrypoint="" \ | ||||
|     -e "HF_TOKEN=${HF_TOKEN}" \ | ||||
|     -e "ZE_AFFINITY_MASK=${ZE_AFFINITY_MASK}" \ | ||||
|     --name "${container_name}" \ | ||||
|     "${image_name}" \ | ||||
|     sh -c ' | ||||
|     bash -c ' | ||||
|     set -e | ||||
|     echo $ZE_AFFINITY_MASK | ||||
|     VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager | ||||
|     VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE | ||||
|     VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray | ||||
|     VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp | ||||
|     cd tests | ||||
| @ -35,8 +40,8 @@ docker run \ | ||||
|     pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py | ||||
|     pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py | ||||
|     pytest -v -s v1/structured_output | ||||
|     pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py | ||||
|     pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py | ||||
|     pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py --ignore=v1/spec_decode/test_tree_attention.py | ||||
|     pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py | ||||
|     pytest -v -s v1/test_serial_utils.py | ||||
|     pytest -v -s v1/test_utils.py | ||||
|     pytest -v -s v1/test_metrics_reader.py | ||||
|  | ||||
| @ -17,7 +17,7 @@ if [ "$disk_usage" -gt "$threshold" ]; then | ||||
|   # Remove dangling images (those that are not tagged and not used by any container) | ||||
|   docker image prune -f | ||||
|   # Remove unused volumes / force the system prune for old images as well. | ||||
|   docker volume prune -f && docker system prune --force --filter "until=72h" --all | ||||
|   docker volume prune -f && docker system prune --force --filter "until=24h" --all | ||||
|   echo "Docker images and volumes cleanup completed." | ||||
| else | ||||
|   echo "Disk usage is below $threshold%. No cleanup needed." | ||||
|  | ||||
| @ -14,8 +14,19 @@ fi | ||||
| # Get the single wheel file | ||||
| wheel="${wheel_files[0]}" | ||||
|  | ||||
| # Rename 'linux' to 'manylinux1' in the wheel filename | ||||
| new_wheel="${wheel/linux/manylinux1}" | ||||
| # Detect architecture and rename 'linux' to appropriate manylinux version | ||||
| arch=$(uname -m) | ||||
| if [[ $arch == "x86_64" ]]; then | ||||
|     manylinux_version="manylinux1" | ||||
| elif [[ $arch == "aarch64" ]]; then | ||||
|     manylinux_version="manylinux2014" | ||||
| else | ||||
|     echo "Warning: Unknown architecture $arch, using manylinux1 as default" | ||||
|     manylinux_version="manylinux1" | ||||
| fi | ||||
|  | ||||
| # Rename 'linux' to the appropriate manylinux version in the wheel filename | ||||
| new_wheel="${wheel/linux/$manylinux_version}" | ||||
| mv -- "$wheel" "$new_wheel" | ||||
| wheel="$new_wheel" | ||||
|  | ||||
|  | ||||
| @ -88,15 +88,6 @@ steps: | ||||
|   - pytest -v -s basic_correctness/test_cpu_offload.py | ||||
|   - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py | ||||
|  | ||||
| - label: Chunked Prefill Test | ||||
|   mirror_hardwares: [amdexperimental] | ||||
|   source_file_dependencies: | ||||
|   - vllm/ | ||||
|   - tests/basic_correctness/test_chunked_prefill | ||||
|   commands: | ||||
|   - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py | ||||
|   - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py | ||||
|  | ||||
| - label: Core Test # 10min | ||||
|   mirror_hardwares: [amdexperimental] | ||||
|   fast_check: true | ||||
| @ -135,7 +126,8 @@ steps: | ||||
|   - tests/entrypoints/test_chat_utils | ||||
|   commands: | ||||
|   - export VLLM_WORKER_MULTIPROC_METHOD=spawn | ||||
|   - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ | ||||
|   - PYTHONPATH=/vllm-workspace pytest -v -s entrypoints/openai/test_collective_rpc.py # PYTHONPATH is needed to import custom Worker extension | ||||
|   - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/ --ignore=entrypoints/openai/test_collective_rpc.py | ||||
|   - pytest -v -s entrypoints/test_chat_utils.py | ||||
|  | ||||
| - label: Distributed Tests (4 GPUs) # 10min | ||||
| @ -252,6 +244,7 @@ steps: | ||||
|     - pytest -v -s v1/core | ||||
|     - pytest -v -s v1/engine | ||||
|     - pytest -v -s v1/entrypoints | ||||
|     - pytest -v -s v1/executor | ||||
|     - pytest -v -s v1/sample | ||||
|     - pytest -v -s v1/logits_processors | ||||
|     - pytest -v -s v1/worker | ||||
| @ -295,15 +288,6 @@ steps: | ||||
|     - python3 offline_inference/basic/score.py | ||||
|     - VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2 | ||||
|  | ||||
| - label: Prefix Caching Test # 9min | ||||
|   mirror_hardwares: [amdexperimental] | ||||
|   source_file_dependencies: | ||||
|   - vllm/ | ||||
|   - tests/prefix_caching | ||||
|   commands: | ||||
|     - pytest -v -s prefix_caching | ||||
|  | ||||
|  | ||||
| - label: Platform Tests (CUDA) | ||||
|   mirror_hardwares: [amdexperimental] | ||||
|   source_file_dependencies: | ||||
| @ -345,6 +329,7 @@ steps: | ||||
|     - pytest -v -s compile/test_sequence_parallelism.py | ||||
|     - pytest -v -s compile/test_async_tp.py | ||||
|     - pytest -v -s compile/test_fusion_all_reduce.py | ||||
|     - pytest -v -s compile/test_decorator.py | ||||
|  | ||||
| - label: PyTorch Fullgraph Smoke Test # 9min | ||||
|   mirror_hardwares: [amdexperimental] | ||||
| @ -358,6 +343,7 @@ steps: | ||||
|   - pytest -v -s compile/piecewise/test_simple.py | ||||
|   - pytest -v -s compile/piecewise/test_toy_llama.py | ||||
|   - pytest -v -s compile/piecewise/test_full_cudagraph.py | ||||
|   - pytest -v -s compile/piecewise/test_multiple_graphs.py | ||||
|  | ||||
| - label: PyTorch Fullgraph Test # 18min | ||||
|   mirror_hardwares: [amdexperimental] | ||||
| @ -404,6 +390,7 @@ steps: | ||||
|   - csrc/moe/ | ||||
|   - tests/kernels/moe | ||||
|   - vllm/model_executor/layers/fused_moe/ | ||||
|   - vllm/distributed/device_communicators/ | ||||
|   commands: | ||||
|     - pytest -v -s kernels/moe --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT | ||||
|   parallelism: 2 | ||||
| @ -468,13 +455,11 @@ steps: | ||||
|  | ||||
| - label: LM Eval Small Models # 53min | ||||
|   mirror_hardwares: [amdexperimental] | ||||
|   working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" | ||||
|   source_file_dependencies: | ||||
|   - csrc/ | ||||
|   - vllm/model_executor/layers/quantization | ||||
|   commands: | ||||
|   - export VLLM_WORKER_MULTIPROC_METHOD=spawn | ||||
|   - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 | ||||
|   - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 | ||||
|  | ||||
| - label: OpenAI API correctness | ||||
|   mirror_hardwares: [amdexperimental] | ||||
| @ -562,6 +547,15 @@ steps: | ||||
|   commands: | ||||
|     - pytest -v -s models/language/pooling -m 'not core_model' | ||||
|  | ||||
| - label: Multi-Modal Processor Test | ||||
|   source_file_dependencies: | ||||
|   - vllm/ | ||||
|   - tests/models/multimodal | ||||
|   commands: | ||||
|     - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git | ||||
|     - pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py | ||||
|     - pytest -v -s models/multimodal/processing/test_tensor_schema.py | ||||
|  | ||||
| - label: Multi-Modal Models Test (Standard) | ||||
|   mirror_hardwares: [amdexperimental] | ||||
|   torch_nightly: true | ||||
| @ -571,9 +565,7 @@ steps: | ||||
|   commands: | ||||
|     - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git | ||||
|     - pip freeze | grep -E 'torch' | ||||
|     - pytest -v -s models/multimodal/processing | ||||
|     - pytest -v -s --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/test_tensor_schema.py models/multimodal -m core_model | ||||
|     - pytest -v -s models/multimodal/test_tensor_schema.py -m core_model  # Needs mp_method="spawn" | ||||
|     - pytest -v -s models/multimodal -m core_model --ignore models/multimodal/generation/test_whisper.py --ignore models/multimodal/processing | ||||
|     - cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model  # Otherwise, mp_method="spawn" doesn't work | ||||
|  | ||||
| - label: Multi-Modal Models Test (Extended) 1 | ||||
| @ -584,7 +576,7 @@ steps: | ||||
|   - tests/models/multimodal | ||||
|   commands: | ||||
|     - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git | ||||
|     - pytest -v -s --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing models/multimodal -m 'not core_model' | ||||
|     - pytest -v -s models/multimodal -m 'not core_model' --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing | ||||
|  | ||||
| - label: Multi-Modal Models Test (Extended) 2 | ||||
|   mirror_hardwares: [amdexperimental] | ||||
| @ -647,8 +639,10 @@ steps: | ||||
|   - vllm/model_executor/layers/fused_moe/cutlass_moe.py | ||||
|   - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py | ||||
|   - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py | ||||
|   - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py | ||||
|   - vllm/v1/attention/backends/flashinfer.py | ||||
|   - vllm/compilation/fusion.py | ||||
|   - vllm/compilation/fusion_attn.py | ||||
|   commands: | ||||
|     - nvidia-smi | ||||
|     - python3 examples/offline_inference/basic/chat.py | ||||
| @ -661,10 +655,14 @@ steps: | ||||
|     - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8' | ||||
|     - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py | ||||
|     - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py | ||||
|     - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py | ||||
|     - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py | ||||
|     - pytest -v -s tests/kernels/moe/test_nvfp4_moe.py | ||||
|     - pytest -v -s tests/kernels/moe/test_mxfp4_moe.py | ||||
|     # Fusion | ||||
|     - pytest -v -s tests/compile/test_fusion_all_reduce.py | ||||
|     - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern | ||||
|     - pytest -v -s tests/kernels/moe/test_flashinfer.py | ||||
|  | ||||
| #####  1 GPU test  ##### | ||||
| #####  multi gpus test  ##### | ||||
| @ -847,3 +845,10 @@ steps: | ||||
|   commands: | ||||
|   - export VLLM_WORKER_MULTIPROC_METHOD=spawn | ||||
|   - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 | ||||
|  | ||||
| - label: Qwen MoE EP Test # optional | ||||
|   gpu: h200 | ||||
|   optional: true | ||||
|   num_gpus: 2 | ||||
|   commands: | ||||
|     - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 /vllm-workspace/examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1  --dp-size=2 --max-model-len 2048 | ||||
|  | ||||
							
								
								
									
										16
									
								
								.github/CODEOWNERS
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										16
									
								
								.github/CODEOWNERS
									
									
									
									
										vendored
									
									
								
							| @ -10,6 +10,7 @@ | ||||
| /vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill | ||||
| /vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill | ||||
| /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 | ||||
| /vllm/model_executor/layers/mamba @tdoublep | ||||
| /vllm/multimodal @DarkLight1337 @ywang96 | ||||
| /vllm/vllm_flash_attn @LucasWilkinson | ||||
| /vllm/lora @jeejeelee | ||||
| @ -25,11 +26,11 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson | ||||
| # vLLM V1 | ||||
| /vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat | ||||
| /vllm/v1/structured_output @mgoin @russellb @aarnphm | ||||
| /vllm/v1/attention/backends/triton_attn.py @tdoublep | ||||
|  | ||||
| # Test ownership | ||||
| /.buildkite/lm-eval-harness @mgoin @simon-mo | ||||
| /tests/async_engine @njhill @robertgshaw2-redhat @simon-mo | ||||
| /tests/basic_correctness/test_chunked_prefill @rkooo567 @comaniac | ||||
| /tests/distributed/test_multi_node_assignment.py @youkaichao | ||||
| /tests/distributed/test_pipeline_parallel.py @youkaichao | ||||
| /tests/distributed/test_same_node.py @youkaichao | ||||
| @ -44,6 +45,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson | ||||
| /tests/v1/structured_output @mgoin @russellb @aarnphm | ||||
| /tests/weight_loading @mgoin @youkaichao @yewentao256 | ||||
| /tests/lora @jeejeelee | ||||
| /tests/models/language/generation/test_hybrid.py @tdoublep | ||||
|  | ||||
| # Docs | ||||
| /docs @hmellor | ||||
| @ -72,3 +74,15 @@ mkdocs.yaml @hmellor | ||||
| /vllm/model_executor/models/pixtral*.py @patrickvonplaten | ||||
| /vllm/transformers_utils/configs/mistral.py @patrickvonplaten | ||||
| /vllm/transformers_utils/tokenizers/mistral.py @patrickvonplaten | ||||
|  | ||||
| # Kernels | ||||
| /vllm/attention/ops/chunked_prefill_paged_decode.py @tdoublep | ||||
| /vllm/attention/ops/triton_unified_attention.py @tdoublep | ||||
|  | ||||
| # ROCm related: specify owner with write access to notify AMD folks for careful code review | ||||
| /docker/Dockerfile.rocm* @gshtras | ||||
| /vllm/v1/attention/backends/rocm*.py @gshtras | ||||
| /vllm/v1/attention/backends/mla/rocm*.py @gshtras | ||||
| /vllm/attention/ops/rocm*.py @gshtras | ||||
| /vllm/model_executor/layers/fused_moe/rocm*.py @gshtras | ||||
|  | ||||
|  | ||||
							
								
								
									
										3
									
								
								.github/PULL_REQUEST_TEMPLATE.md
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/PULL_REQUEST_TEMPLATE.md
									
									
									
									
										vendored
									
									
								
							| @ -7,8 +7,6 @@ PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTT | ||||
|  | ||||
| ## Test Result | ||||
|  | ||||
| ## (Optional) Documentation Update | ||||
|  | ||||
| --- | ||||
| <details> | ||||
| <summary> Essential Elements of an Effective PR Description Checklist </summary> | ||||
| @ -17,6 +15,7 @@ PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS (AT THE BOTT | ||||
| - [ ] The test plan, such as providing test command. | ||||
| - [ ] The test results, such as pasting the results comparison before and after, or e2e results | ||||
| - [ ] (Optional) The necessary documentation update, such as updating `supported_models.md` and `examples` for a new model. | ||||
| - [ ] (Optional) Release notes update. If your change is user facing, please update the release notes draft in the [Google Doc](https://docs.google.com/document/d/1YyVqrgX4gHTtrstbq8oWUImOyPCKSGnJ7xtTpmXzlRs/edit?tab=t.0). | ||||
| </details> | ||||
|  | ||||
| **BEFORE SUBMITTING, PLEASE READ <https://docs.vllm.ai/en/latest/contributing>** (anything written below this line will be removed by GitHub Actions) | ||||
|  | ||||
							
								
								
									
										305
									
								
								.github/workflows/issue_autolabel.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										305
									
								
								.github/workflows/issue_autolabel.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @ -0,0 +1,305 @@ | ||||
| name: Label issues based on keywords | ||||
| on: | ||||
|   issues: | ||||
|     types: [opened, edited, reopened] | ||||
| permissions: | ||||
|   issues: write          # needed so the workflow can add labels | ||||
|   contents: read | ||||
| concurrency: | ||||
|   group: issue-labeler-${{ github.event.issue.number }} | ||||
|   cancel-in-progress: true | ||||
| jobs: | ||||
|   add-labels: | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|       - name: Label issues based on keywords | ||||
|         uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea  # v7.0.1 | ||||
|         with: | ||||
|           script: | | ||||
|             // Configuration: Add new labels and keywords here | ||||
|             const labelConfig = { | ||||
|               rocm: { | ||||
|                 // Keyword search - matches whole words only (with word boundaries) | ||||
|                 keywords: [ | ||||
|                   { | ||||
|                     term: "composable kernel", | ||||
|                     searchIn: "both" | ||||
|                   }, | ||||
|                   { | ||||
|                     term: "rccl", | ||||
|                     searchIn: "body"  // only search in body | ||||
|                   }, | ||||
|                   { | ||||
|                     term: "migraphx", | ||||
|                     searchIn: "title"  // only search in title | ||||
|                   }, | ||||
|                   { | ||||
|                     term: "hipgraph", | ||||
|                     searchIn: "both" | ||||
|                   }, | ||||
|                   { | ||||
|                     term: "ROCm System Management Interface", | ||||
|                     searchIn: "body" | ||||
|                   }, | ||||
|                 ], | ||||
|                  | ||||
|                 // Substring search - matches anywhere in text (partial matches) | ||||
|                 substrings: [ | ||||
|                   { | ||||
|                     term: "VLLM_ROCM_", | ||||
|                     searchIn: "both" | ||||
|                   }, | ||||
|                   { | ||||
|                     term: "rocm", | ||||
|                     searchIn: "title" | ||||
|                   }, | ||||
|                   { | ||||
|                     term: "amd", | ||||
|                     searchIn: "title" | ||||
|                   }, | ||||
|                   { | ||||
|                     term: "hip-", | ||||
|                     searchIn: "both" | ||||
|                   }, | ||||
|                   { | ||||
|                     term: "gfx", | ||||
|                     searchIn: "both" | ||||
|                   }, | ||||
|                   { | ||||
|                     term: "cdna", | ||||
|                     searchIn: "both" | ||||
|                   }, | ||||
|                   { | ||||
|                     term: "rdna", | ||||
|                     searchIn: "both" | ||||
|                   }, | ||||
|                   { | ||||
|                     term: "torch_hip", | ||||
|                     searchIn: "body"  // only in body | ||||
|                   }, | ||||
|                   { | ||||
|                     term: "_hip", | ||||
|                     searchIn: "both" | ||||
|                   }, | ||||
|                   { | ||||
|                     term: "hip_", | ||||
|                     searchIn: "both" | ||||
|                   }, | ||||
|                    | ||||
|                   // ROCm tools and libraries | ||||
|                   { | ||||
|                     term: "hipify", | ||||
|                     searchIn: "both" | ||||
|                   }, | ||||
|                 ], | ||||
|                  | ||||
|                 // Regex patterns - for complex pattern matching | ||||
|                 regexPatterns: [ | ||||
|                   { | ||||
|                     pattern: "\\bmi\\d{3}[a-z]*\\b", | ||||
|                     description: "AMD GPU names (mi + 3 digits + optional letters)", | ||||
|                     flags: "gi", | ||||
|                     searchIn: "both"  // "title", "body", or "both" | ||||
|                   } | ||||
|                 ], | ||||
|               }, | ||||
|             }; | ||||
|              | ||||
|             // Helper function to create regex based on search type | ||||
|             function createSearchRegex(term, type) { | ||||
|               // Escape special regex characters in the term | ||||
|               const escapedTerm = term.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); | ||||
|                | ||||
|               switch (type) { | ||||
|                 case 'keyword': | ||||
|                   // Word boundary search - matches whole words only | ||||
|                   return new RegExp(`\\b${escapedTerm}\\b`, "gi"); | ||||
|                 case 'substring': | ||||
|                   // Substring search - matches anywhere in the text | ||||
|                   return new RegExp(escapedTerm, "gi"); | ||||
|                 default: | ||||
|                   throw new Error(`Unknown search type: ${type}`); | ||||
|               } | ||||
|             } | ||||
|              | ||||
|             // Helper function to find matching terms in text with line information | ||||
|             function findMatchingTermsWithLines(text, searchTerms = [], searchType = 'keyword', searchLocation = '') { | ||||
|               const matches = []; | ||||
|               const lines = text.split('\n'); | ||||
|                | ||||
|               for (const termConfig of searchTerms) { | ||||
|                 let regex; | ||||
|                 let term, searchIn, pattern, description, flags; | ||||
|                  | ||||
|                 // Handle different input formats (string or object) | ||||
|                 if (typeof termConfig === 'string') { | ||||
|                   term = termConfig; | ||||
|                   searchIn = 'both'; // default | ||||
|                 } else { | ||||
|                   term = termConfig.term; | ||||
|                   searchIn = termConfig.searchIn || 'both'; | ||||
|                   pattern = termConfig.pattern; | ||||
|                   description = termConfig.description; | ||||
|                   flags = termConfig.flags; | ||||
|                 } | ||||
|                  | ||||
|                 // Skip if this term shouldn't be searched in the current location | ||||
|                 if (searchIn !== 'both' && searchIn !== searchLocation) { | ||||
|                   continue; | ||||
|                 } | ||||
|                  | ||||
|                 // Create appropriate regex | ||||
|                 if (searchType === 'regex') { | ||||
|                   regex = new RegExp(pattern, flags || "gi"); | ||||
|                 } else { | ||||
|                   regex = createSearchRegex(term, searchType); | ||||
|                 } | ||||
|                  | ||||
|                 const termMatches = []; | ||||
|                  | ||||
|                 // Check each line for matches | ||||
|                 lines.forEach((line, lineIndex) => { | ||||
|                   const lineMatches = line.match(regex); | ||||
|                   if (lineMatches) { | ||||
|                     lineMatches.forEach(match => { | ||||
|                       termMatches.push({ | ||||
|                         match: match, | ||||
|                         lineNumber: lineIndex + 1, | ||||
|                         lineContent: line.trim(), | ||||
|                         searchType: searchType, | ||||
|                         searchLocation: searchLocation, | ||||
|                         originalTerm: term || pattern, | ||||
|                         description: description, | ||||
|                         // Show context around the match in the line | ||||
|                         context: line.length > 100 ?  | ||||
|                           line.substring(Math.max(0, line.toLowerCase().indexOf(match.toLowerCase()) - 30),  | ||||
|                                        line.toLowerCase().indexOf(match.toLowerCase()) + match.length + 30) + '...'  | ||||
|                           : line.trim() | ||||
|                       }); | ||||
|                     }); | ||||
|                   } | ||||
|                 }); | ||||
|                  | ||||
|                 if (termMatches.length > 0) { | ||||
|                   matches.push({ | ||||
|                     term: term || (description || pattern), | ||||
|                     searchType: searchType, | ||||
|                     searchLocation: searchLocation, | ||||
|                     searchIn: searchIn, | ||||
|                     pattern: pattern, | ||||
|                     matches: termMatches, | ||||
|                     count: termMatches.length | ||||
|                   }); | ||||
|                 } | ||||
|               } | ||||
|                | ||||
|               return matches; | ||||
|             } | ||||
|              | ||||
|             // Helper function to check if label should be added | ||||
|             async function processLabel(labelName, config) { | ||||
|               const body = context.payload.issue.body || ""; | ||||
|               const title = context.payload.issue.title || ""; | ||||
|                | ||||
|               core.notice(`Processing label: ${labelName}`); | ||||
|               core.notice(`Issue Title: "${title}"`); | ||||
|               core.notice(`Issue Body length: ${body.length} characters`); | ||||
|                | ||||
|               let shouldAddLabel = false; | ||||
|               let allMatches = []; | ||||
|               let reason = ''; | ||||
|                | ||||
|               const keywords = config.keywords || []; | ||||
|               const substrings = config.substrings || []; | ||||
|               const regexPatterns = config.regexPatterns || []; | ||||
|                | ||||
|               core.notice(`Searching with ${keywords.length} keywords, ${substrings.length} substrings, and ${regexPatterns.length} regex patterns`); | ||||
|                | ||||
|               // Search in title | ||||
|               if (title.trim()) { | ||||
|                 core.notice(`Searching in title: "${title}"`); | ||||
|                  | ||||
|                 const titleKeywordMatches = findMatchingTermsWithLines(title, keywords, 'keyword', 'title'); | ||||
|                 const titleSubstringMatches = findMatchingTermsWithLines(title, substrings, 'substring', 'title'); | ||||
|                 const titleRegexMatches = findMatchingTermsWithLines(title, regexPatterns, 'regex', 'title'); | ||||
|                  | ||||
|                 allMatches.push(...titleKeywordMatches, ...titleSubstringMatches, ...titleRegexMatches); | ||||
|               } | ||||
|                | ||||
|               // Search in body | ||||
|               if (body.trim()) { | ||||
|                 core.notice(`Searching in body (${body.length} characters)`); | ||||
|                  | ||||
|                 const bodyKeywordMatches = findMatchingTermsWithLines(body, keywords, 'keyword', 'body'); | ||||
|                 const bodySubstringMatches = findMatchingTermsWithLines(body, substrings, 'substring', 'body'); | ||||
|                 const bodyRegexMatches = findMatchingTermsWithLines(body, regexPatterns, 'regex', 'body'); | ||||
|                  | ||||
|                 allMatches.push(...bodyKeywordMatches, ...bodySubstringMatches, ...bodyRegexMatches); | ||||
|               } | ||||
|                | ||||
|               if (allMatches.length > 0) { | ||||
|                 core.notice(`Found ${allMatches.length} matching term(s):`); | ||||
|                  | ||||
|                 for (const termMatch of allMatches) { | ||||
|                   const locationText = termMatch.searchLocation === 'title' ? 'title' : 'body'; | ||||
|                   const searchInText = termMatch.searchIn === 'both' ? 'both' : termMatch.searchIn; | ||||
|                    | ||||
|                   if (termMatch.searchType === 'regex') { | ||||
|                     core.notice(`  📍 Regex: "${termMatch.term}" (pattern: ${termMatch.pattern}) found ${termMatch.count} time(s) in ${locationText} (configured to search in: ${searchInText}):`); | ||||
|                   } else { | ||||
|                     core.notice(`  📍 Term: "${termMatch.term}" (${termMatch.searchType} search) found ${termMatch.count} time(s) in ${locationText} (configured to search in: ${searchInText}):`); | ||||
|                   } | ||||
|                    | ||||
|                   // Show details for each match | ||||
|                   termMatch.matches.forEach((match, index) => { | ||||
|                     core.notice(`    ${index + 1}. Line ${match.lineNumber} in ${match.searchLocation}: "${match.match}" [${match.searchType}]`); | ||||
|                     if (match.description) { | ||||
|                       core.notice(`       Description: ${match.description}`); | ||||
|                     } | ||||
|                     core.notice(`       Context: ${match.context}`); | ||||
|                     if (match.lineContent !== match.context) { | ||||
|                       core.notice(`       Full line: ${match.lineContent}`); | ||||
|                     } | ||||
|                   }); | ||||
|                 } | ||||
|                  | ||||
|                 shouldAddLabel = true; | ||||
|                 const totalMatches = allMatches.reduce((sum, t) => sum + t.count, 0); | ||||
|                 const titleMatches = allMatches.filter(t => t.searchLocation === 'title').reduce((sum, t) => sum + t.count, 0); | ||||
|                 const bodyMatches = allMatches.filter(t => t.searchLocation === 'body').reduce((sum, t) => sum + t.count, 0); | ||||
|                 const keywordMatches = allMatches.filter(t => t.searchType === 'keyword').reduce((sum, t) => sum + t.count, 0); | ||||
|                 const substringMatches = allMatches.filter(t => t.searchType === 'substring').reduce((sum, t) => sum + t.count, 0); | ||||
|                 const regexMatches = allMatches.filter(t => t.searchType === 'regex').reduce((sum, t) => sum + t.count, 0); | ||||
|                  | ||||
|                 reason = `Found ${totalMatches} total matches (${titleMatches} in title, ${bodyMatches} in body) - ${keywordMatches} keyword matches, ${substringMatches} substring matches, ${regexMatches} regex matches`; | ||||
|               } | ||||
|                | ||||
|               core.notice(`Final decision: ${shouldAddLabel ? 'ADD LABEL' : 'DO NOT ADD LABEL'}`); | ||||
|               core.notice(`Reason: ${reason || 'No matching terms found'}`); | ||||
|                | ||||
|               if (shouldAddLabel) { | ||||
|                 const existingLabels = context.payload.issue.labels.map(l => l.name); | ||||
|                 if (!existingLabels.includes(labelName)) { | ||||
|                   await github.rest.issues.addLabels({ | ||||
|                     owner: context.repo.owner, | ||||
|                     repo: context.repo.repo, | ||||
|                     issue_number: context.issue.number, | ||||
|                     labels: [labelName], | ||||
|                   }); | ||||
|                   core.notice(`Label "${labelName}" added. ${reason}`); | ||||
|                   return true; | ||||
|                 } | ||||
|                 core.notice(`Label "${labelName}" already present.`); | ||||
|                 return false; | ||||
|               } | ||||
|                | ||||
|               core.notice(`No matching terms found for label "${labelName}".`); | ||||
|               return false; | ||||
|             } | ||||
|              | ||||
|             // Process all configured labels | ||||
|             const processLabels = Object.entries(labelConfig) | ||||
|               .map(([labelName, config]) => processLabel(labelName, config)); | ||||
|             const labelsAdded = await Promise.all(processLabels); | ||||
|             const numLabelsAdded = labelsAdded.reduce((x, y) => x + y, 0); | ||||
|             core.notice(`Processing complete. ${numLabelsAdded} label(s) added.`); | ||||
							
								
								
									
										89
									
								
								.github/workflows/lint-and-deploy.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										89
									
								
								.github/workflows/lint-and-deploy.yaml
									
									
									
									
										vendored
									
									
								
							| @ -1,89 +0,0 @@ | ||||
| name: Lint and Deploy Charts | ||||
|  | ||||
| on: pull_request | ||||
|  | ||||
| concurrency: | ||||
|   group: ${{ github.workflow }}-${{ github.ref }} | ||||
|   cancel-in-progress: true | ||||
|  | ||||
| permissions: | ||||
|   contents: read | ||||
|  | ||||
| jobs: | ||||
|   lint-and-deploy: | ||||
|     runs-on: ubuntu-latest | ||||
|     steps: | ||||
|       - name: Checkout | ||||
|         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 | ||||
|         with: | ||||
|           fetch-depth: 0 | ||||
|  | ||||
|       - name: Set up Helm | ||||
|         uses: azure/setup-helm@b9e51907a09c216f16ebe8536097933489208112 # v4.3.0 | ||||
|         with: | ||||
|           version: v3.14.4 | ||||
|  | ||||
|        #Python is required because ct lint runs Yamale and yamllint which require Python. | ||||
|       - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 | ||||
|         with: | ||||
|           python-version: '3.13' | ||||
|  | ||||
|       - name: Set up chart-testing | ||||
|         uses: helm/chart-testing-action@0d28d3144d3a25ea2cc349d6e59901c4ff469b3b # v2.7.0 | ||||
|         with: | ||||
|           version: v3.10.1 | ||||
|  | ||||
|       - name: Run chart-testing (lint) | ||||
|         run: ct lint --target-branch ${{ github.event.repository.default_branch }} --chart-dirs examples/online_serving/chart-helm --charts examples/online_serving/chart-helm | ||||
|  | ||||
|       - name: Setup minio | ||||
|         run: | | ||||
|           docker network create vllm-net | ||||
|           docker run -d -p 9000:9000 --name minio --net vllm-net \ | ||||
|                      -e "MINIO_ACCESS_KEY=minioadmin" \ | ||||
|                      -e "MINIO_SECRET_KEY=minioadmin" \ | ||||
|                      -v /tmp/data:/data \ | ||||
|                      -v /tmp/config:/root/.minio \ | ||||
|                      minio/minio server /data | ||||
|           export AWS_ACCESS_KEY_ID=minioadmin | ||||
|           export AWS_SECRET_ACCESS_KEY=minioadmin | ||||
|           export AWS_EC2_METADATA_DISABLED=true | ||||
|           mkdir opt-125m | ||||
|           cd opt-125m && curl -O -Ls "https://huggingface.co/facebook/opt-125m/resolve/main/{pytorch_model.bin,config.json,generation_config.json,merges.txt,special_tokens_map.json,tokenizer_config.json,vocab.json}" && cd .. | ||||
|           aws --endpoint-url http://127.0.0.1:9000/ s3 mb s3://testbucket | ||||
|           aws --endpoint-url http://127.0.0.1:9000/ s3 cp opt-125m/ s3://testbucket/opt-125m --recursive | ||||
|  | ||||
|       - name: Create kind cluster | ||||
|         uses: helm/kind-action@a1b0e391336a6ee6713a0583f8c6240d70863de3 # v1.12.0 | ||||
|  | ||||
|       - name: Build the Docker image vllm cpu | ||||
|         run: docker buildx build -f docker/Dockerfile.cpu -t vllm-cpu-env . | ||||
|  | ||||
|       - name: Configuration of docker images, network and namespace for the kind cluster | ||||
|         run: | | ||||
|           docker pull amazon/aws-cli:2.6.4 | ||||
|           kind load docker-image  amazon/aws-cli:2.6.4 --name chart-testing | ||||
|           kind load docker-image vllm-cpu-env:latest --name chart-testing | ||||
|           docker network connect vllm-net "$(docker ps -aqf "name=chart-testing-control-plane")" | ||||
|           kubectl create ns ns-vllm | ||||
|  | ||||
|       - name: Run chart-testing (install) | ||||
|         run: | | ||||
|           export AWS_ACCESS_KEY_ID=minioadmin | ||||
|           export AWS_SECRET_ACCESS_KEY=minioadmin | ||||
|           sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" & | ||||
|           helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set image.env[2].name=VLLM_CPU_CI_ENV --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string image.env[2].value="1" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env" | ||||
|  | ||||
|       - name: curl test | ||||
|         run: | | ||||
|           kubectl -n ns-vllm port-forward service/test-vllm-service 8001:80 & | ||||
|           sleep 10 | ||||
|           CODE="$(curl -v -f --location http://localhost:8001/v1/completions \ | ||||
|                   --header "Content-Type: application/json" \ | ||||
|                   --data '{ | ||||
|                           "model": "opt-125m", | ||||
|                           "prompt": "San Francisco is a", | ||||
|                           "max_tokens": 7, | ||||
|                           "temperature": 0 | ||||
|                   }'):$CODE" | ||||
|           echo "$CODE" | ||||
							
								
								
									
										111
									
								
								.github/workflows/publish.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										111
									
								
								.github/workflows/publish.yml
									
									
									
									
										vendored
									
									
								
							| @ -1,111 +0,0 @@ | ||||
| # This workflow will upload a Python Package to Release asset | ||||
| # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions | ||||
|  | ||||
| name: Create Release | ||||
|  | ||||
| on: | ||||
|   push: | ||||
|     tags: | ||||
|       - v* | ||||
|  | ||||
| # Needed to create release and upload assets | ||||
| permissions: | ||||
|   contents: write | ||||
|  | ||||
| jobs: | ||||
|   release: | ||||
|     # Retrieve tag and create release | ||||
|     name: Create Release | ||||
|     runs-on: ubuntu-latest | ||||
|     outputs: | ||||
|       upload_url: ${{ steps.create_release.outputs.upload_url }} | ||||
|     steps: | ||||
|       - name: Checkout | ||||
|         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 | ||||
|  | ||||
|       - name: Extract branch info | ||||
|         shell: bash | ||||
|         run: | | ||||
|           echo "release_tag=${GITHUB_REF#refs/*/}" >> "$GITHUB_ENV" | ||||
|  | ||||
|       - name: Create Release | ||||
|         id: create_release | ||||
|         uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 | ||||
|         env: | ||||
|           RELEASE_TAG: ${{ env.release_tag }} | ||||
|         with: | ||||
|           github-token: "${{ secrets.GITHUB_TOKEN }}" | ||||
|           script: | | ||||
|             const script = require('.github/workflows/scripts/create_release.js') | ||||
|             await script(github, context, core) | ||||
|  | ||||
|   # NOTE(simon): No longer build wheel using GitHub Actions. See buildkite's release workflow.  | ||||
|   # wheel: | ||||
|   #   name: Build Wheel | ||||
|   #   runs-on: ${{ matrix.os }} | ||||
|   #   needs: release | ||||
|  | ||||
|   #   strategy: | ||||
|   #     fail-fast: false | ||||
|   #     matrix: | ||||
|   #         os: ['ubuntu-20.04'] | ||||
|   #         python-version: ['3.9', '3.10', '3.11', '3.12'] | ||||
|   #         pytorch-version: ['2.4.0']  # Must be the most recent version that meets requirements/cuda.txt. | ||||
|   #         cuda-version: ['11.8', '12.1'] | ||||
|  | ||||
|   #   steps: | ||||
|   #     - name: Checkout | ||||
|   #       uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 | ||||
|  | ||||
|   #     - name: Setup ccache | ||||
|   #       uses: hendrikmuhs/ccache-action@ed74d11c0b343532753ecead8a951bb09bb34bc9 # v1.2.14 | ||||
|   #       with: | ||||
|   #         create-symlink: true | ||||
|   #         key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }} | ||||
|  | ||||
|   #     - name: Set up Linux Env | ||||
|   #       if: ${{ runner.os == 'Linux' }} | ||||
|   #       run: | | ||||
|   #         bash -x .github/workflows/scripts/env.sh | ||||
|  | ||||
|   #     - name: Set up Python | ||||
|   #       uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 | ||||
|   #       with: | ||||
|   #           python-version: ${{ matrix.python-version }} | ||||
|  | ||||
|   #     - name: Install CUDA ${{ matrix.cuda-version }} | ||||
|   #       run: | | ||||
|   #         bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }} | ||||
|  | ||||
|   #     - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }} | ||||
|   #       run: | | ||||
|   #         bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }} | ||||
|  | ||||
|   #     - name: Build wheel | ||||
|   #       shell: bash | ||||
|   #       env: | ||||
|   #         CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size | ||||
|   #       run: | | ||||
|   #         bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }} | ||||
|   #         wheel_name=$(find dist -name "*whl" -print0 | xargs -0 -n 1 basename) | ||||
|   #         asset_name=${wheel_name//"linux"/"manylinux1"} | ||||
|   #         echo "wheel_name=${wheel_name}" >> "$GITHUB_ENV" | ||||
|   #         echo "asset_name=${asset_name}" >> "$GITHUB_ENV" | ||||
|  | ||||
|   #     - name: Upload Release Asset | ||||
|   #       uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 # v1.0.2 | ||||
|   #       env: | ||||
|   #         GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} | ||||
|   #       with: | ||||
|   #         upload_url: ${{ needs.release.outputs.upload_url }} | ||||
|   #         asset_path: ./dist/${{ env.wheel_name }} | ||||
|   #         asset_name: ${{ env.asset_name }} | ||||
|   #         asset_content_type: application/* | ||||
|  | ||||
|       # (Danielkinz): This last step will publish the .whl to pypi. Warning: untested | ||||
|       # - name: Publish package | ||||
|       #   uses: pypa/gh-action-pypi-publish@release/v1.8 | ||||
|       #   with: | ||||
|       #     repository-url: https://test.pypi.org/legacy/ | ||||
|       #     password: ${{ secrets.PYPI_API_TOKEN }} | ||||
|       #     skip-existing: true | ||||
							
								
								
									
										49
									
								
								.github/workflows/reminder_comment.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										49
									
								
								.github/workflows/reminder_comment.yml
									
									
									
									
										vendored
									
									
								
							| @ -12,16 +12,43 @@ jobs: | ||||
|         uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 | ||||
|         with: | ||||
|           script: | | ||||
|             github.rest.issues.createComment({ | ||||
|               owner: context.repo.owner, | ||||
|               repo: context.repo.repo, | ||||
|               issue_number: context.issue.number, | ||||
|               body: '👋 Hi! Thank you for contributing to the vLLM project.\n\n' + | ||||
|                 '💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' + | ||||
|                 'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org.\n\n' + | ||||
|                 'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' + | ||||
|                 'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' + | ||||
|                 '🚀' | ||||
|             }) | ||||
|             try { | ||||
|               // Get the PR author | ||||
|               const prAuthor = context.payload.pull_request.user.login; | ||||
|                | ||||
|               // Check if this is the author's first PR in this repository | ||||
|               // Use GitHub's search API to find all PRs by this author | ||||
|               const { data: searchResults } = await github.rest.search.issuesAndPullRequests({ | ||||
|                 q: `repo:${context.repo.owner}/${context.repo.repo} type:pr author:${prAuthor}`, | ||||
|                 per_page: 100   | ||||
|               }); | ||||
|                | ||||
|               const authorPRCount = searchResults.total_count; | ||||
|                | ||||
|               console.log(`Found ${authorPRCount} PRs by ${prAuthor}`); | ||||
|                | ||||
|               // Only post comment if this is the first PR (only one PR by this author) | ||||
|               if (authorPRCount === 1) { | ||||
|                 console.log(`Posting welcome comment for first-time contributor: ${prAuthor}`); | ||||
|                 await github.rest.issues.createComment({ | ||||
|                 owner: context.repo.owner, | ||||
|                 repo: context.repo.repo, | ||||
|                 issue_number: context.issue.number, | ||||
|                 body: '👋 Hi! Thank you for contributing to the vLLM project.\n\n' + | ||||
|                   '💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.\n\n' + | ||||
|                   'Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. \n\n' + | ||||
|                   'You ask your reviewers to trigger select CI tests on top of `fastcheck` CI. \n\n' + | ||||
|                   'Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n' + | ||||
|                   'To run CI, PR reviewers can either: Add `ready` label to the PR or enable auto-merge.\n\n' + | ||||
|                   'If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.\n\n' + | ||||
|                   '🚀' | ||||
|                 }); | ||||
|               } else { | ||||
|                 console.log(`Skipping comment for ${prAuthor} - not their first PR (${authorPRCount} PRs found)`); | ||||
|               } | ||||
|             } catch (error) { | ||||
|               console.error('Error checking PR history or posting comment:', error); | ||||
|               // Don't fail the workflow, just log the error | ||||
|             } | ||||
|         env: | ||||
|           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} | ||||
|  | ||||
| @ -30,7 +30,7 @@ install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) | ||||
| # Supported python versions.  These versions will be searched in order, the | ||||
| # first match will be selected.  These should be kept in sync with setup.py. | ||||
| # | ||||
| set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12") | ||||
| set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12" "3.13") | ||||
|  | ||||
| # Supported AMD GPU architectures. | ||||
| set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201") | ||||
| @ -357,9 +357,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | ||||
|     list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) | ||||
|  | ||||
|     set(MARLIN_SRCS | ||||
|        "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" | ||||
|        "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" | ||||
|        "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu" | ||||
|        "csrc/quantization/gptq_marlin/gptq_marlin.cu" | ||||
|        "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" | ||||
|        "csrc/quantization/gptq_marlin/awq_marlin_repack.cu") | ||||
| @ -752,6 +750,33 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") | ||||
|                      "found in CUDA target architectures") | ||||
|     endif() | ||||
|   endif() | ||||
|  | ||||
|   # Only build W4A8 kernels if we are building for something compatible with sm90a | ||||
|   cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}") | ||||
|   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS) | ||||
|     set(SRCS | ||||
|        "csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu") | ||||
|  | ||||
|     set_gencode_flags_for_srcs( | ||||
|       SRCS "${SRCS}" | ||||
|       CUDA_ARCHS "${W4A8_ARCHS}") | ||||
|  | ||||
|     list(APPEND VLLM_EXT_SRC "${SRCS}") | ||||
|  | ||||
|     message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}") | ||||
|   else() | ||||
|     if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 | ||||
|         AND W4A8_ARCHS) | ||||
|       message(STATUS "Not building W4A8 kernels as CUDA Compiler version is " | ||||
|                      "not >= 12.0, we recommend upgrading to CUDA 12.0 or " | ||||
|                      "later if you intend on running w4a16 quantized models on " | ||||
|                      "Hopper.") | ||||
|     else() | ||||
|       message(STATUS "Not building W4A8 kernels as no compatible archs " | ||||
|                      "found in CUDA target architectures") | ||||
|     endif() | ||||
|   endif() | ||||
|  | ||||
| # if CUDA endif | ||||
| endif() | ||||
|  | ||||
| @ -792,7 +817,9 @@ set(VLLM_MOE_EXT_SRC | ||||
|   "csrc/moe/topk_softmax_kernels.cu") | ||||
|  | ||||
| if(VLLM_GPU_LANG STREQUAL "CUDA") | ||||
|   list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu") | ||||
|   list(APPEND VLLM_MOE_EXT_SRC | ||||
|     "csrc/moe/moe_wna16.cu" | ||||
|     "csrc/moe/grouped_topk_kernels.cu") | ||||
| endif() | ||||
|  | ||||
| if(VLLM_GPU_LANG STREQUAL "CUDA") | ||||
|  | ||||
| @ -18,14 +18,15 @@ Easy, fast, and cheap LLM serving for everyone | ||||
|  | ||||
| *Latest News* 🔥 | ||||
|  | ||||
| - [2025/08] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg) focusing on building, developing, and integrating with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH). | ||||
| - [2025/08] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA) focusing on large-scale LLM deployment! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) and the recording [here](https://www.chaspark.com/#/live/1166916873711665152). | ||||
| - [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing). | ||||
| - [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/). | ||||
| - [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html). | ||||
|  | ||||
| <details> | ||||
| <summary>Previous News</summary> | ||||
|  | ||||
| - [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing). | ||||
| - [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). | ||||
| - [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing). | ||||
| - [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing). | ||||
|  | ||||
| @ -42,4 +42,9 @@ For certain security issues of CRITICAL, HIGH, or MODERATE severity level, we ma | ||||
|  | ||||
| * If you wish to be added to the prenotification group, please send an email copying all the members of the [vulnerability management team](https://docs.vllm.ai/en/latest/contributing/vulnerability_management.html). Each vendor contact will be analyzed on a case-by-case basis. | ||||
|  | ||||
| * Organizations and vendors who either ship or use vLLM, are eligible to join the prenotification group if they meet at least one of the following qualifications | ||||
|     * Substantial internal deployment leveraging the upstream vLLM project. | ||||
|     * Established internal security teams and comprehensive compliance measures. | ||||
|     * Active and consistent contributions to the upstream vLLM project. | ||||
|  | ||||
| * We may withdraw organizations from receiving future prenotifications if they release fixes or any other information about issues before they are public. Group membership may also change based on policy refinements for who may be included. | ||||
|  | ||||
| @ -32,6 +32,14 @@ become available. | ||||
|         <div>Note that the images need to be downloaded separately. For example, to download COCO's 2017 Train images:</div> | ||||
|         <code>wget http://images.cocodataset.org/zips/train2017.zip</code> | ||||
|       </td> | ||||
|     </tr> | ||||
|         <tr> | ||||
|       <td><strong>ShareGPT4Video (Video)</strong></td> | ||||
|       <td style="text-align: center;">✅</td> | ||||
|       <td style="text-align: center;">✅</td> | ||||
|       <td> | ||||
|         <code>git clone https://huggingface.co/datasets/ShareGPT4Video/ShareGPT4Video</code> | ||||
|       </td> | ||||
|     </tr> | ||||
|     <tr> | ||||
|       <td><strong>BurstGPT</strong></td> | ||||
| @ -51,6 +59,12 @@ become available. | ||||
|       <td style="text-align: center;">✅</td> | ||||
|       <td><code>synthetic</code></td> | ||||
|     </tr> | ||||
|     <tr> | ||||
|       <td><strong>RandomMultiModal (Image/Video)</strong></td> | ||||
|       <td style="text-align: center;">🟡</td> | ||||
|       <td style="text-align: center;">🚧</td> | ||||
|       <td><code>synthetic</code> </td> | ||||
|     </tr> | ||||
|     <tr> | ||||
|       <td><strong>Prefix Repetition</strong></td> | ||||
|       <td style="text-align: center;">✅</td> | ||||
| @ -194,6 +208,7 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct | ||||
| ```bash | ||||
| vllm bench serve \ | ||||
|   --backend openai-chat \ | ||||
|   --endpoint-type openai-chat \ | ||||
|   --model Qwen/Qwen2-VL-7B-Instruct \ | ||||
|   --endpoint /v1/chat/completions \ | ||||
|   --dataset-name hf \ | ||||
| @ -230,6 +245,7 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct | ||||
| ```bash | ||||
| vllm bench serve \ | ||||
|   --backend openai-chat \ | ||||
|   --endpoint-type openai-chat \ | ||||
|   --model Qwen/Qwen2-VL-7B-Instruct \ | ||||
|   --endpoint /v1/chat/completions \ | ||||
|   --dataset-name hf \ | ||||
| @ -244,6 +260,7 @@ vllm bench serve \ | ||||
| ```bash | ||||
| vllm bench serve \ | ||||
|   --backend openai-chat \ | ||||
|   --endpoint-type openai-chat \ | ||||
|   --model Qwen/Qwen2-VL-7B-Instruct \ | ||||
|   --endpoint /v1/chat/completions \ | ||||
|   --dataset-name hf \ | ||||
| @ -609,7 +626,7 @@ vllm bench serve \ | ||||
|   --prefix-repetition-prefix-len 512 \ | ||||
|   --prefix-repetition-suffix-len 128 \ | ||||
|   --prefix-repetition-num-prefixes 5 \ | ||||
|   --prefix-repetition-output-len 128  | ||||
|   --prefix-repetition-output-len 128 | ||||
| ``` | ||||
|  | ||||
| </details> | ||||
| @ -684,4 +701,102 @@ python benchmarks/benchmark_serving.py \ | ||||
|   --endpoint /v1/chat/completion | ||||
| ``` | ||||
|  | ||||
| ### Videos (ShareGPT4Video) | ||||
|  | ||||
| Start vLLM: | ||||
|  | ||||
| ```bash | ||||
| python -m vllm.entrypoints.openai.api_server \ | ||||
|   --model Qwen/Qwen2.5-VL-7B-Instruct \ | ||||
|   --dtype bfloat16 \ | ||||
|   --limit-mm-per-prompt '{"video": 1}' \ | ||||
|   --allowed-local-media-path /path/to/sharegpt4video/videos | ||||
| ``` | ||||
|  | ||||
| Send requests with videos: | ||||
|  | ||||
| ```bash | ||||
| python benchmarks/benchmark_serving.py \ | ||||
|   --backend openai-chat \ | ||||
|   --model Qwen/Qwen2.5-VL-7B-Instruct \ | ||||
|   --dataset-name sharegpt \ | ||||
|   --dataset-path /path/to/ShareGPT4Video/llava_v1_5_mix665k_with_video_chatgpt72k_share4video28k.json \ | ||||
|   --num-prompts 100 \ | ||||
|   --save-result \ | ||||
|   --result-dir ~/vllm_benchmark_results \ | ||||
|   --save-detailed \ | ||||
|   --endpoint /v1/chat/completion | ||||
| ``` | ||||
|  | ||||
| ### Synthetic Random Images (random-mm) | ||||
|  | ||||
| Generate synthetic image inputs alongside random text prompts to stress-test vision models without external datasets. | ||||
|  | ||||
| Notes: | ||||
|  | ||||
| - Works only with online benchmark via the OpenAI  backend (`--backend openai-chat`) and endpoint `/v1/chat/completions`. | ||||
| - Video sampling is not yet implemented. | ||||
|  | ||||
| Start the server (example): | ||||
|  | ||||
| ```bash | ||||
| vllm serve Qwen/Qwen2.5-VL-3B-Instruct \ | ||||
|   --dtype bfloat16 \ | ||||
|   --max-model-len 16384 \ | ||||
|   --limit-mm-per-prompt '{"image": 3, "video": 0}' \ | ||||
|   --mm-processor-kwargs max_pixels=1003520 | ||||
| ``` | ||||
|  | ||||
| Benchmark. It is recommended to use the flag `--ignore-eos` to simulate real responses. You can set the size of the output via the arg `random-output-len`. | ||||
|  | ||||
| Ex.1: Fixed number of items and a single image resolution, enforcing generation of approx 40 tokens: | ||||
|  | ||||
| ```bash | ||||
| vllm bench serve \ | ||||
|   --backend openai-chat \ | ||||
|   --model Qwen/Qwen2.5-VL-3B-Instruct \ | ||||
|   --endpoint /v1/chat/completions \ | ||||
|   --dataset-name random-mm \ | ||||
|   --num-prompts 100 \ | ||||
|   --max-concurrency 10 \ | ||||
|   --random-prefix-len 25 \ | ||||
|   --random-input-len 300 \ | ||||
|   --random-output-len 40 \ | ||||
|   --random-range-ratio 0.2 \ | ||||
|   --random-mm-base-items-per-request 2 \ | ||||
|   --random-mm-limit-mm-per-prompt '{"image": 3, "video": 0}' \ | ||||
|   --random-mm-bucket-config '{(224, 224, 1): 1.0}' \ | ||||
|   --request-rate inf \ | ||||
|   --ignore-eos \ | ||||
|   --seed 42 | ||||
| ``` | ||||
|  | ||||
| The number of items per request can be controlled by passing multiple image buckets: | ||||
|  | ||||
| ```bash | ||||
|   --random-mm-base-items-per-request 2 \ | ||||
|   --random-mm-num-mm-items-range-ratio 0.5 \ | ||||
|   --random-mm-limit-mm-per-prompt '{"image": 4, "video": 0}' \ | ||||
|   --random-mm-bucket-config '{(256, 256, 1): 0.7, (720, 1280, 1): 0.3}' \ | ||||
| ``` | ||||
|  | ||||
| Flags specific to `random-mm`: | ||||
|  | ||||
| - `--random-mm-base-items-per-request`: base number of multimodal items per request. | ||||
| - `--random-mm-num-mm-items-range-ratio`: vary item count uniformly in the closed integer range [floor(n·(1−r)), ceil(n·(1+r))]. Set r=0 to keep it fixed; r=1 allows 0 items. | ||||
| - `--random-mm-limit-mm-per-prompt`: per-modality hard caps, e.g. '{"image": 3, "video": 0}'. | ||||
| - `--random-mm-bucket-config`: dict mapping (H, W, T) → probability. Entries with probability 0 are removed; remaining probabilities are renormalized to sum to 1. Use T=1 for images. Set any T>1 for videos (video sampling not yet supported). | ||||
|  | ||||
| Behavioral notes: | ||||
|  | ||||
| - If the requested base item count cannot be satisfied under the provided per-prompt limits, the tool raises an error rather than silently clamping. | ||||
|  | ||||
| How sampling works: | ||||
|  | ||||
| - Determine per-request item count k by sampling uniformly from the integer range defined by `--random-mm-base-items-per-request` and `--random-mm-num-mm-items-range-ratio`, then clamp k to at most the sum of per-modality limits. | ||||
| - For each of the k items, sample a bucket (H, W, T) according to the normalized probabilities in `--random-mm-bucket-config`, while tracking how many items of each modality have been added. | ||||
| - If a modality (e.g., image) reaches its limit from `--random-mm-limit-mm-per-prompt`, all buckets of that modality are excluded and the remaining bucket probabilities are renormalized before continuing. | ||||
| This should be seen as an edge case, and if this behavior can be avoided by setting `--random-mm-limit-mm-per-prompt` to a large number. Note that this might result in errors due to engine config `--limit-mm-per-prompt`. | ||||
| - The resulting request contains synthetic image data in `multi_modal_data` (OpenAI Chat format). When `random-mm` is used with the OpenAI Chat backend, prompts remain text and MM content is attached via `multi_modal_data`. | ||||
|  | ||||
| </details> | ||||
|  | ||||
| @ -34,6 +34,7 @@ class RequestFuncInput: | ||||
|     multi_modal_content: Optional[dict | list[dict]] = None | ||||
|     ignore_eos: bool = False | ||||
|     language: Optional[str] = None | ||||
|     request_id: Optional[str] = None | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| @ -71,6 +72,9 @@ async def async_request_tgi( | ||||
|             "inputs": request_func_input.prompt, | ||||
|             "parameters": params, | ||||
|         } | ||||
|         headers = None | ||||
|         if request_func_input.request_id: | ||||
|             headers = {"x-request-id": request_func_input.request_id} | ||||
|         output = RequestFuncOutput() | ||||
|         output.prompt_len = request_func_input.prompt_len | ||||
|         if request_func_input.ignore_eos: | ||||
| @ -82,7 +86,9 @@ async def async_request_tgi( | ||||
|         st = time.perf_counter() | ||||
|         most_recent_timestamp = st | ||||
|         try: | ||||
|             async with session.post(url=api_url, json=payload) as response: | ||||
|             async with session.post( | ||||
|                 url=api_url, json=payload, headers=headers | ||||
|             ) as response: | ||||
|                 if response.status == 200: | ||||
|                     async for chunk_bytes in response.content: | ||||
|                         chunk_bytes = chunk_bytes.strip() | ||||
| @ -145,6 +151,9 @@ async def async_request_trt_llm( | ||||
|         } | ||||
|         if request_func_input.ignore_eos: | ||||
|             payload["min_length"] = request_func_input.output_len | ||||
|         headers = None | ||||
|         if request_func_input.request_id: | ||||
|             headers = {"x-request-id": request_func_input.request_id} | ||||
|         output = RequestFuncOutput() | ||||
|         output.prompt_len = request_func_input.prompt_len | ||||
|  | ||||
| @ -152,7 +161,9 @@ async def async_request_trt_llm( | ||||
|         st = time.perf_counter() | ||||
|         most_recent_timestamp = st | ||||
|         try: | ||||
|             async with session.post(url=api_url, json=payload) as response: | ||||
|             async with session.post( | ||||
|                 url=api_url, json=payload, headers=headers | ||||
|             ) as response: | ||||
|                 if response.status == 200: | ||||
|                     async for chunk_bytes in response.content: | ||||
|                         chunk_bytes = chunk_bytes.strip() | ||||
| @ -211,6 +222,8 @@ async def async_request_deepspeed_mii( | ||||
|             "top_p": 1.0, | ||||
|         } | ||||
|         headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} | ||||
|         if request_func_input.request_id: | ||||
|             headers["x-request-id"] = request_func_input.request_id | ||||
|  | ||||
|         output = RequestFuncOutput() | ||||
|         output.prompt_len = request_func_input.prompt_len | ||||
| @ -283,6 +296,8 @@ async def async_request_openai_completions( | ||||
|         if request_func_input.extra_body: | ||||
|             payload.update(request_func_input.extra_body) | ||||
|         headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} | ||||
|         if request_func_input.request_id: | ||||
|             headers["x-request-id"] = request_func_input.request_id | ||||
|  | ||||
|         output = RequestFuncOutput() | ||||
|         output.prompt_len = request_func_input.prompt_len | ||||
| @ -395,6 +410,8 @@ async def async_request_openai_chat_completions( | ||||
|             "Content-Type": "application/json", | ||||
|             "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", | ||||
|         } | ||||
|         if request_func_input.request_id: | ||||
|             headers["x-request-id"] = request_func_input.request_id | ||||
|  | ||||
|         output = RequestFuncOutput() | ||||
|         output.prompt_len = request_func_input.prompt_len | ||||
| @ -491,6 +508,8 @@ async def async_request_openai_audio( | ||||
|         headers = { | ||||
|             "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", | ||||
|         } | ||||
|         if request_func_input.request_id: | ||||
|             headers["x-request-id"] = request_func_input.request_id | ||||
|  | ||||
|         # Send audio file | ||||
|         def to_bytes(y, sr): | ||||
|  | ||||
| @ -19,6 +19,7 @@ import logging | ||||
| import random | ||||
| from abc import ABC, abstractmethod | ||||
| from collections.abc import Mapping | ||||
| from copy import deepcopy | ||||
| from dataclasses import dataclass | ||||
| from functools import cache | ||||
| from io import BytesIO | ||||
| @ -54,6 +55,7 @@ class SampleRequest: | ||||
|     expected_output_len: int | ||||
|     multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None | ||||
|     lora_request: Optional[LoRARequest] = None | ||||
|     request_id: Optional[str] = None | ||||
|  | ||||
|  | ||||
| # ----------------------------------------------------------------------------- | ||||
| @ -155,7 +157,10 @@ class BenchmarkDataset(ABC): | ||||
|  | ||||
|     @abstractmethod | ||||
|     def sample( | ||||
|         self, tokenizer: PreTrainedTokenizerBase, num_requests: int | ||||
|         self, | ||||
|         tokenizer: PreTrainedTokenizerBase, | ||||
|         num_requests: int, | ||||
|         request_id_prefix: str = "", | ||||
|     ) -> list[SampleRequest]: | ||||
|         """ | ||||
|         Abstract method to generate sample requests from the dataset. | ||||
| @ -167,6 +172,7 @@ class BenchmarkDataset(ABC): | ||||
|             tokenizer (PreTrainedTokenizerBase): The tokenizer to be used | ||||
|              for processing the dataset's text. | ||||
|             num_requests (int): The number of sample requests to generate. | ||||
|             request_id_prefix (str) The prefix of request_id. | ||||
|  | ||||
|         Returns: | ||||
|             list[SampleRequest]: A list of sample requests generated from the | ||||
| @ -175,7 +181,10 @@ class BenchmarkDataset(ABC): | ||||
|         raise NotImplementedError("sample must be implemented in subclasses.") | ||||
|  | ||||
|     def maybe_oversample_requests( | ||||
|         self, requests: list[SampleRequest], num_requests: int | ||||
|         self, | ||||
|         requests: list[SampleRequest], | ||||
|         num_requests: int, | ||||
|         request_id_prefix: str = "", | ||||
|     ) -> None: | ||||
|         """ | ||||
|         Oversamples the list of requests if its size is less than the desired | ||||
| @ -183,11 +192,18 @@ class BenchmarkDataset(ABC): | ||||
|  | ||||
|         Args: | ||||
|             requests (List[SampleRequest]): The current list of sampled | ||||
|             requests.  num_requests (int): The target number of requests. | ||||
|             requests. | ||||
|             num_requests (int): The target number of requests. | ||||
|             request_id_prefix (str) The prefix of the request ids. | ||||
|         """ | ||||
|         if len(requests) < num_requests: | ||||
|             random.seed(self.random_seed) | ||||
|             additional = random.choices(requests, k=num_requests - len(requests)) | ||||
|             additional = deepcopy( | ||||
|                 random.choices(requests, k=num_requests - len(requests)) | ||||
|             ) | ||||
|             for i in range(len(additional)): | ||||
|                 req = additional[i] | ||||
|                 req.request_id = request_id_prefix + str(len(requests) + i) | ||||
|             requests.extend(additional) | ||||
|             logger.info("Oversampled requests to reach %d total samples.", num_requests) | ||||
|  | ||||
| @ -277,6 +293,41 @@ def process_image(image: Any) -> Mapping[str, Any]: | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def process_video(video: Any) -> Mapping[str, Any]: | ||||
|     """ | ||||
|     Process a single video input and return a multimedia content dictionary. | ||||
|  | ||||
|     Supports the following input types: | ||||
|  | ||||
|     1. Dictionary with raw video bytes: - Expects a dict with a 'bytes' key | ||||
|        containing raw video data. | ||||
|  | ||||
|     2. String input: - Treats the string as a URL or local file path.  - | ||||
|        Prepends "file://" if the string doesn't start with "http://" or | ||||
|        "file://".  - Returns a dictionary with the image URL. | ||||
|  | ||||
|     Raises: | ||||
|         ValueError: If the input is not a supported type. | ||||
|     """ | ||||
|     if isinstance(video, dict) and "bytes" in video: | ||||
|         video_bytes = video["bytes"] | ||||
|         video_base64 = base64.b64encode(video_bytes).decode("utf-8") | ||||
|         return { | ||||
|             "type": "video_url", | ||||
|             "video_url": {"url": f"data:video/mp4;base64,{video_base64}"}, | ||||
|         } | ||||
|  | ||||
|     if isinstance(video, str): | ||||
|         video_url = ( | ||||
|             video if video.startswith(("http://", "file://")) else f"file://{video}" | ||||
|         ) | ||||
|         return {"type": "video_url", "video_url": {"url": video_url}} | ||||
|  | ||||
|     raise ValueError( | ||||
|         f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`."  # noqa: E501 | ||||
|     ) | ||||
|  | ||||
|  | ||||
| # ----------------------------------------------------------------------------- | ||||
| # Random Dataset Implementation (Synthetic Data) | ||||
| # ----------------------------------------------------------------------------- | ||||
| @ -303,6 +354,7 @@ class RandomDataset(BenchmarkDataset): | ||||
|         range_ratio: float = DEFAULT_RANGE_RATIO, | ||||
|         input_len: int = DEFAULT_INPUT_LEN, | ||||
|         output_len: int = DEFAULT_OUTPUT_LEN, | ||||
|         request_id_prefix: str = "", | ||||
|         **kwargs, | ||||
|     ) -> list[SampleRequest]: | ||||
|         # Enforce range_ratio < 1 | ||||
| @ -363,8 +415,10 @@ class RandomDataset(BenchmarkDataset): | ||||
|                     prompt=prompt, | ||||
|                     prompt_len=total_input_len, | ||||
|                     expected_output_len=int(output_lens[i]), | ||||
|                     request_id=request_id_prefix + str(i), | ||||
|                 ) | ||||
|             ) | ||||
|  | ||||
|         return requests | ||||
|  | ||||
|  | ||||
| @ -406,9 +460,11 @@ class ShareGPTDataset(BenchmarkDataset): | ||||
|         max_loras: Optional[int] = None, | ||||
|         output_len: Optional[int] = None, | ||||
|         enable_multimodal_chat: bool = False, | ||||
|         request_id_prefix: str = "", | ||||
|         **kwargs, | ||||
|     ) -> list: | ||||
|         samples: list = [] | ||||
|         ind = 0 | ||||
|         for entry in self.data: | ||||
|             if len(samples) >= num_requests: | ||||
|                 break | ||||
| @ -430,9 +486,10 @@ class ShareGPTDataset(BenchmarkDataset): | ||||
|                 skip_min_output_len_check=output_len is not None, | ||||
|             ): | ||||
|                 continue | ||||
|             # TODO: Also support ShareGPT4Video. | ||||
|             if image_path := entry.get("image"): | ||||
|                 mm_content = process_image(image_path) | ||||
|             elif video_path := entry.get("video"): | ||||
|                 mm_content = process_video(video_path) | ||||
|             else: | ||||
|                 mm_content = None | ||||
|             if enable_multimodal_chat: | ||||
| @ -444,9 +501,11 @@ class ShareGPTDataset(BenchmarkDataset): | ||||
|                     expected_output_len=new_output_len, | ||||
|                     lora_request=lora_request, | ||||
|                     multi_modal_data=mm_content, | ||||
|                     request_id=request_id_prefix + str(ind), | ||||
|                 ) | ||||
|             ) | ||||
|         self.maybe_oversample_requests(samples, num_requests) | ||||
|             ind += 1 | ||||
|         self.maybe_oversample_requests(samples, num_requests, request_id_prefix) | ||||
|         return samples | ||||
|  | ||||
|  | ||||
| @ -512,10 +571,11 @@ class CustomDataset(BenchmarkDataset): | ||||
|         output_len: Optional[int] = None, | ||||
|         enable_multimodal_chat: bool = False, | ||||
|         skip_chat_template: bool = False, | ||||
|         request_id_prefix: str = "", | ||||
|         **kwargs, | ||||
|     ) -> list: | ||||
|         sampled_requests = [] | ||||
|         for item in self.data: | ||||
|         for i, item in enumerate(self.data): | ||||
|             if len(sampled_requests) >= num_requests: | ||||
|                 break | ||||
|             prompt = item["prompt"] | ||||
| @ -534,9 +594,12 @@ class CustomDataset(BenchmarkDataset): | ||||
|                     prompt=prompt, | ||||
|                     prompt_len=prompt_len, | ||||
|                     expected_output_len=output_len, | ||||
|                     request_id=request_id_prefix + str(i), | ||||
|                 ) | ||||
|             ) | ||||
|         self.maybe_oversample_requests(sampled_requests, num_requests) | ||||
|         self.maybe_oversample_requests( | ||||
|             sampled_requests, num_requests, request_id_prefix | ||||
|         ) | ||||
|  | ||||
|         return sampled_requests | ||||
|  | ||||
| @ -578,6 +641,7 @@ class SonnetDataset(BenchmarkDataset): | ||||
|         input_len: int = DEFAULT_INPUT_LEN, | ||||
|         output_len: int = DEFAULT_OUTPUT_LEN, | ||||
|         return_prompt_formatted: bool = False, | ||||
|         request_id_prefix: str = "", | ||||
|         **kwargs, | ||||
|     ) -> list: | ||||
|         # Calculate average token length for a poem line. | ||||
| @ -603,6 +667,7 @@ class SonnetDataset(BenchmarkDataset): | ||||
|         prefix_lines = self.data[:num_prefix_lines] | ||||
|  | ||||
|         samples = [] | ||||
|         ind = 0 | ||||
|         while len(samples) < num_requests: | ||||
|             extra_lines = random.choices( | ||||
|                 self.data, k=num_input_lines - num_prefix_lines | ||||
| @ -613,14 +678,17 @@ class SonnetDataset(BenchmarkDataset): | ||||
|                 msg, add_generation_prompt=True, tokenize=False | ||||
|             ) | ||||
|             prompt_len = len(tokenizer(prompt_formatted).input_ids) | ||||
|  | ||||
|             if prompt_len <= input_len: | ||||
|                 samples.append( | ||||
|                     SampleRequest( | ||||
|                         prompt=prompt_formatted if return_prompt_formatted else prompt, | ||||
|                         prompt_len=prompt_len, | ||||
|                         expected_output_len=output_len, | ||||
|                         request_id=request_id_prefix + str(ind), | ||||
|                     ) | ||||
|                 ) | ||||
|                 ind += 1 | ||||
|         return samples | ||||
|  | ||||
|  | ||||
| @ -672,6 +740,7 @@ class BurstGPTDataset(BenchmarkDataset): | ||||
|         num_requests: int, | ||||
|         max_loras: Optional[int] = None, | ||||
|         lora_path: Optional[str] = None, | ||||
|         request_id_prefix: str = "", | ||||
|         **kwargs, | ||||
|     ) -> list[SampleRequest]: | ||||
|         samples = [] | ||||
| @ -693,6 +762,7 @@ class BurstGPTDataset(BenchmarkDataset): | ||||
|                     prompt_len=input_len, | ||||
|                     expected_output_len=output_len, | ||||
|                     lora_request=lora_req, | ||||
|                     request_id=request_id_prefix + str(i), | ||||
|                 ) | ||||
|             ) | ||||
|         return samples | ||||
| @ -752,12 +822,14 @@ class ConversationDataset(HuggingFaceDataset): | ||||
|         num_requests: int, | ||||
|         output_len: Optional[int] = None, | ||||
|         enable_multimodal_chat: bool = False, | ||||
|         request_id_prefix: str = "", | ||||
|         **kwargs, | ||||
|     ) -> list: | ||||
|         # Filter examples with at least 2 conversations | ||||
|         filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2) | ||||
|         sampled_requests = [] | ||||
|         dynamic_output = output_len is None | ||||
|         ind = 0 | ||||
|  | ||||
|         for item in filtered_data: | ||||
|             if len(sampled_requests) >= num_requests: | ||||
| @ -785,9 +857,13 @@ class ConversationDataset(HuggingFaceDataset): | ||||
|                     prompt_len=prompt_len, | ||||
|                     expected_output_len=output_len, | ||||
|                     multi_modal_data=mm_content, | ||||
|                     request_id=request_id_prefix + str(ind), | ||||
|                 ) | ||||
|             ) | ||||
|         self.maybe_oversample_requests(sampled_requests, num_requests) | ||||
|             ind += 1 | ||||
|         self.maybe_oversample_requests( | ||||
|             sampled_requests, num_requests, request_id_prefix | ||||
|         ) | ||||
|         return sampled_requests | ||||
|  | ||||
|  | ||||
| @ -814,11 +890,12 @@ class VisionArenaDataset(HuggingFaceDataset): | ||||
|         num_requests: int, | ||||
|         output_len: Optional[int] = None, | ||||
|         enable_multimodal_chat: bool = False, | ||||
|         request_id_prefix: str = "", | ||||
|         **kwargs, | ||||
|     ) -> list: | ||||
|         output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN | ||||
|         sampled_requests = [] | ||||
|         for item in self.data: | ||||
|         for i, item in enumerate(self.data): | ||||
|             if len(sampled_requests) >= num_requests: | ||||
|                 break | ||||
|             parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) | ||||
| @ -838,9 +915,12 @@ class VisionArenaDataset(HuggingFaceDataset): | ||||
|                     prompt_len=prompt_len, | ||||
|                     expected_output_len=output_len, | ||||
|                     multi_modal_data=mm_content, | ||||
|                     request_id=request_id_prefix + str(i), | ||||
|                 ) | ||||
|             ) | ||||
|         self.maybe_oversample_requests(sampled_requests, num_requests) | ||||
|         self.maybe_oversample_requests( | ||||
|             sampled_requests, num_requests, request_id_prefix | ||||
|         ) | ||||
|         return sampled_requests | ||||
|  | ||||
|  | ||||
| @ -870,15 +950,18 @@ class InstructCoderDataset(HuggingFaceDataset): | ||||
|         num_requests: int, | ||||
|         output_len: Optional[int] = None, | ||||
|         enable_multimodal_chat: bool = False, | ||||
|         request_id_prefix: str = "", | ||||
|         **kwargs, | ||||
|     ) -> list: | ||||
|         output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN | ||||
|         sampled_requests = [] | ||||
|         for item in self.data: | ||||
|         for i, item in enumerate(self.data): | ||||
|             if len(sampled_requests) >= num_requests: | ||||
|                 break | ||||
|             prompt = f"{item['input']}\n\n{item['instruction']} Just output \ | ||||
|             the code, do not include any explanation." | ||||
|             prompt = ( | ||||
|                 f"{item['input']}\n\n{item['instruction']} Just output " | ||||
|                 "the code, do not include any explanation." | ||||
|             ) | ||||
|  | ||||
|             # apply template | ||||
|             prompt = tokenizer.apply_chat_template( | ||||
| @ -892,9 +975,12 @@ class InstructCoderDataset(HuggingFaceDataset): | ||||
|                     prompt=prompt, | ||||
|                     prompt_len=prompt_len, | ||||
|                     expected_output_len=output_len, | ||||
|                     request_id=request_id_prefix + str(i), | ||||
|                 ) | ||||
|             ) | ||||
|         self.maybe_oversample_requests(sampled_requests, num_requests) | ||||
|         self.maybe_oversample_requests( | ||||
|             sampled_requests, num_requests, request_id_prefix | ||||
|         ) | ||||
|         return sampled_requests | ||||
|  | ||||
|  | ||||
| @ -924,12 +1010,13 @@ class MTBenchDataset(HuggingFaceDataset): | ||||
|         num_requests: int, | ||||
|         output_len: Optional[int] = None, | ||||
|         enable_multimodal_chat: bool = False, | ||||
|         request_id_prefix: str = "", | ||||
|         **kwargs, | ||||
|     ) -> list: | ||||
|         output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN | ||||
|         sampled_requests = [] | ||||
|  | ||||
|         for item in self.data: | ||||
|         for i, item in enumerate(self.data): | ||||
|             if len(sampled_requests) >= num_requests: | ||||
|                 break | ||||
|             prompt = item["turns"][0] | ||||
| @ -947,9 +1034,12 @@ class MTBenchDataset(HuggingFaceDataset): | ||||
|                     prompt=prompt, | ||||
|                     prompt_len=prompt_len, | ||||
|                     expected_output_len=output_len, | ||||
|                     request_id=request_id_prefix + str(i), | ||||
|                 ) | ||||
|             ) | ||||
|         self.maybe_oversample_requests(sampled_requests, num_requests) | ||||
|         self.maybe_oversample_requests( | ||||
|             sampled_requests, num_requests, request_id_prefix | ||||
|         ) | ||||
|         return sampled_requests | ||||
|  | ||||
|  | ||||
| @ -974,10 +1064,12 @@ class AIMODataset(HuggingFaceDataset): | ||||
|         tokenizer: PreTrainedTokenizerBase, | ||||
|         num_requests: int, | ||||
|         output_len: Optional[int] = None, | ||||
|         request_id_prefix: str = "", | ||||
|         **kwargs, | ||||
|     ) -> list: | ||||
|         sampled_requests = [] | ||||
|         dynamic_output = output_len is None | ||||
|         ind = 0 | ||||
|  | ||||
|         for item in self.data: | ||||
|             if len(sampled_requests) >= num_requests: | ||||
| @ -1000,9 +1092,13 @@ class AIMODataset(HuggingFaceDataset): | ||||
|                     prompt_len=prompt_len, | ||||
|                     expected_output_len=output_len, | ||||
|                     multi_modal_data=None, | ||||
|                     request_id=request_id_prefix + str(ind), | ||||
|                 ) | ||||
|             ) | ||||
|         self.maybe_oversample_requests(sampled_requests, num_requests) | ||||
|             ind += 1 | ||||
|         self.maybe_oversample_requests( | ||||
|             sampled_requests, num_requests, request_id_prefix | ||||
|         ) | ||||
|         return sampled_requests | ||||
|  | ||||
|  | ||||
| @ -1072,12 +1168,18 @@ class NextEditPredictionDataset(HuggingFaceDataset): | ||||
|         "zed-industries/zeta": _format_zeta_prompt, | ||||
|     } | ||||
|  | ||||
|     def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, **kwargs): | ||||
|     def sample( | ||||
|         self, | ||||
|         tokenizer: PreTrainedTokenizerBase, | ||||
|         num_requests: int, | ||||
|         request_id_prefix: str = "", | ||||
|         **kwargs, | ||||
|     ): | ||||
|         formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path) | ||||
|         if formatting_prompt_func is None: | ||||
|             raise ValueError(f"Unsupported dataset path: {self.dataset_path}") | ||||
|         samples = [] | ||||
|         for sample in self.data: | ||||
|         for i, sample in enumerate(self.data): | ||||
|             sample = formatting_prompt_func(sample) | ||||
|             samples.append( | ||||
|                 SampleRequest( | ||||
| @ -1086,11 +1188,12 @@ class NextEditPredictionDataset(HuggingFaceDataset): | ||||
|                     expected_output_len=len( | ||||
|                         tokenizer(sample["expected_output"]).input_ids | ||||
|                     ), | ||||
|                     request_id=request_id_prefix + str(i), | ||||
|                 ) | ||||
|             ) | ||||
|             if len(samples) >= num_requests: | ||||
|                 break | ||||
|         self.maybe_oversample_requests(samples, num_requests) | ||||
|         self.maybe_oversample_requests(samples, num_requests, request_id_prefix) | ||||
|         return samples | ||||
|  | ||||
|  | ||||
| @ -1139,6 +1242,7 @@ class ASRDataset(HuggingFaceDataset): | ||||
|         tokenizer: PreTrainedTokenizerBase, | ||||
|         num_requests: int, | ||||
|         output_len: Optional[int] = None, | ||||
|         request_id_prefix: str = "", | ||||
|         **kwargs, | ||||
|     ) -> list: | ||||
|         import librosa | ||||
| @ -1148,6 +1252,7 @@ class ASRDataset(HuggingFaceDataset): | ||||
|         prompt_len = len(tokenizer(prompt).input_ids) | ||||
|         sampled_requests = [] | ||||
|         skipped = 0 | ||||
|         ind = 0 | ||||
|         for item in self.data: | ||||
|             if len(sampled_requests) >= num_requests: | ||||
|                 break | ||||
| @ -1166,8 +1271,10 @@ class ASRDataset(HuggingFaceDataset): | ||||
|                     prompt_len=prompt_len, | ||||
|                     expected_output_len=output_len, | ||||
|                     multi_modal_data=mm_content, | ||||
|                     request_id=request_id_prefix + str(ind), | ||||
|                 ) | ||||
|             ) | ||||
|             ind += 1 | ||||
|         if skipped: | ||||
|             logger.warning( | ||||
|                 "%d samples discarded from dataset due to" | ||||
| @ -1175,5 +1282,7 @@ class ASRDataset(HuggingFaceDataset): | ||||
|                 " what Whisper supports.", | ||||
|                 skipped, | ||||
|             ) | ||||
|         self.maybe_oversample_requests(sampled_requests, num_requests) | ||||
|         self.maybe_oversample_requests( | ||||
|             sampled_requests, num_requests, request_id_prefix | ||||
|         ) | ||||
|         return sampled_requests | ||||
|  | ||||
| @ -375,11 +375,12 @@ async def benchmark( | ||||
|                     rps_change_events.append({"rps": rps_val, "timestamp": timestamp}) | ||||
|                 last_int_rps = current_int_rps | ||||
|  | ||||
|         prompt, prompt_len, output_len, mm_content = ( | ||||
|         prompt, prompt_len, output_len, mm_content, request_id = ( | ||||
|             request.prompt, | ||||
|             request.prompt_len, | ||||
|             request.expected_output_len, | ||||
|             request.multi_modal_data, | ||||
|             request.request_id, | ||||
|         ) | ||||
|         req_model_id, req_model_name = model_id, model_name | ||||
|         if lora_modules: | ||||
| @ -397,6 +398,7 @@ async def benchmark( | ||||
|             multi_modal_content=mm_content, | ||||
|             ignore_eos=ignore_eos, | ||||
|             extra_body=extra_body, | ||||
|             request_id=request_id, | ||||
|         ) | ||||
|         task = limited_request_func(request_func_input=request_func_input, pbar=pbar) | ||||
|         tasks.append(asyncio.create_task(task)) | ||||
| @ -665,6 +667,7 @@ def main(args: argparse.Namespace): | ||||
|             tokenizer=tokenizer, | ||||
|             output_len=args.custom_output_len, | ||||
|             skip_chat_template=args.custom_skip_chat_template, | ||||
|             request_id_prefix=args.request_id_prefix, | ||||
|         ) | ||||
|  | ||||
|     elif args.dataset_name == "sonnet": | ||||
| @ -678,6 +681,7 @@ def main(args: argparse.Namespace): | ||||
|                 prefix_len=args.sonnet_prefix_len, | ||||
|                 tokenizer=tokenizer, | ||||
|                 return_prompt_formatted=False, | ||||
|                 request_id_prefix=args.request_id_prefix, | ||||
|             ) | ||||
|         else: | ||||
|             assert tokenizer.chat_template or tokenizer.default_chat_template, ( | ||||
| @ -690,6 +694,7 @@ def main(args: argparse.Namespace): | ||||
|                 prefix_len=args.sonnet_prefix_len, | ||||
|                 tokenizer=tokenizer, | ||||
|                 return_prompt_formatted=True, | ||||
|                 request_id_prefix=args.request_id_prefix, | ||||
|             ) | ||||
|  | ||||
|     elif args.dataset_name == "hf": | ||||
| @ -751,6 +756,7 @@ def main(args: argparse.Namespace): | ||||
|             num_requests=args.num_prompts, | ||||
|             tokenizer=tokenizer, | ||||
|             output_len=args.hf_output_len, | ||||
|             request_id_prefix=args.request_id_prefix, | ||||
|         ) | ||||
|  | ||||
|     else: | ||||
| @ -762,10 +768,15 @@ def main(args: argparse.Namespace): | ||||
|                 tokenizer=tokenizer, | ||||
|                 num_requests=args.num_prompts, | ||||
|                 output_len=args.sharegpt_output_len, | ||||
|                 request_id_prefix=args.request_id_prefix, | ||||
|             ), | ||||
|             "burstgpt": lambda: BurstGPTDataset( | ||||
|                 random_seed=args.seed, dataset_path=args.dataset_path | ||||
|             ).sample(tokenizer=tokenizer, num_requests=args.num_prompts), | ||||
|             ).sample( | ||||
|                 tokenizer=tokenizer, | ||||
|                 num_requests=args.num_prompts, | ||||
|                 request_id_prefix=args.request_id_prefix, | ||||
|             ), | ||||
|             "random": lambda: RandomDataset(dataset_path=args.dataset_path).sample( | ||||
|                 tokenizer=tokenizer, | ||||
|                 num_requests=args.num_prompts, | ||||
| @ -773,6 +784,7 @@ def main(args: argparse.Namespace): | ||||
|                 input_len=args.random_input_len, | ||||
|                 output_len=args.random_output_len, | ||||
|                 range_ratio=args.random_range_ratio, | ||||
|                 request_id_prefix=args.request_id_prefix, | ||||
|             ), | ||||
|         } | ||||
|  | ||||
| @ -1118,6 +1130,13 @@ def create_argument_parser(): | ||||
|         "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " | ||||
|         "and the blog: https://hao-ai-lab.github.io/blogs/distserve", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--request-id-prefix", | ||||
|         type=str, | ||||
|         required=False, | ||||
|         default="benchmark-serving", | ||||
|         help="Specify the prefix of request id.", | ||||
|     ) | ||||
|  | ||||
|     # group for dataset specific arguments | ||||
|     custom_group = parser.add_argument_group("custom dataset options") | ||||
|  | ||||
| @ -96,7 +96,6 @@ def run_vllm( | ||||
|         end = time.perf_counter() | ||||
|     else: | ||||
|         assert lora_requests is None, "BeamSearch API does not support LoRA" | ||||
|         prompts = [request.prompt for request in requests] | ||||
|         # output_len should be the same for all requests. | ||||
|         output_len = requests[0].expected_output_len | ||||
|         for request in requests: | ||||
| @ -597,8 +596,8 @@ def validate_args(args): | ||||
|     # https://github.com/vllm-project/vllm/issues/16222 | ||||
|     if args.data_parallel_size > 1: | ||||
|         raise ValueError( | ||||
|             "Data parallel is not supported in offline benchmark, \ | ||||
|             please use benchmark serving instead" | ||||
|             "Data parallel is not supported in offline benchmark, " | ||||
|             "please use benchmark serving instead" | ||||
|         ) | ||||
|  | ||||
|  | ||||
|  | ||||
| @ -80,6 +80,11 @@ def bench_run( | ||||
|         a, score, topk, renormalize=False | ||||
|     ) | ||||
|  | ||||
|     ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) | ||||
|     ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64) | ||||
|     c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64) | ||||
|     c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) | ||||
|  | ||||
|     def run_triton_moe( | ||||
|         a: torch.Tensor, | ||||
|         w1: torch.Tensor, | ||||
| @ -111,6 +116,10 @@ def bench_run( | ||||
|         w2: torch.Tensor, | ||||
|         w1_scale: torch.Tensor, | ||||
|         w2_scale: torch.Tensor, | ||||
|         ab_strides1: torch.Tensor, | ||||
|         ab_strides2: torch.Tensor, | ||||
|         c_strides1: torch.Tensor, | ||||
|         c_strides2: torch.Tensor, | ||||
|         topk_weights: torch.Tensor, | ||||
|         topk_ids: torch.Tensor, | ||||
|         per_act_token: bool, | ||||
| @ -125,6 +134,10 @@ def bench_run( | ||||
|                 topk_ids, | ||||
|                 w1_scale, | ||||
|                 w2_scale, | ||||
|                 ab_strides1, | ||||
|                 ab_strides2, | ||||
|                 c_strides1, | ||||
|                 c_strides2, | ||||
|                 per_act_token, | ||||
|                 a1_scale=None, | ||||
|             ) | ||||
| @ -136,6 +149,10 @@ def bench_run( | ||||
|         w2_q: torch.Tensor, | ||||
|         w1_scale: torch.Tensor, | ||||
|         w2_scale: torch.Tensor, | ||||
|         ab_strides1: torch.Tensor, | ||||
|         ab_strides2: torch.Tensor, | ||||
|         c_strides1: torch.Tensor, | ||||
|         c_strides2: torch.Tensor, | ||||
|         topk_weights: torch.Tensor, | ||||
|         topk_ids: torch.Tensor, | ||||
|     ): | ||||
| @ -150,6 +167,10 @@ def bench_run( | ||||
|                 topk_ids, | ||||
|                 w1_scale, | ||||
|                 w2_scale, | ||||
|                 ab_strides1, | ||||
|                 ab_strides2, | ||||
|                 c_strides1, | ||||
|                 c_strides2, | ||||
|                 per_act_token, | ||||
|                 a1_scale=None, | ||||
|             ) | ||||
| @ -194,6 +215,10 @@ def bench_run( | ||||
|             w2_q, | ||||
|             w1_scale, | ||||
|             w2_scale, | ||||
|             ab_strides1, | ||||
|             ab_strides2, | ||||
|             c_strides1, | ||||
|             c_strides2, | ||||
|             topk_weights, | ||||
|             topk_ids, | ||||
|         ) | ||||
| @ -231,6 +256,10 @@ def bench_run( | ||||
|         "w1_scale": w1_scale, | ||||
|         "w2_scale": w2_scale, | ||||
|         "per_act_token": per_act_token, | ||||
|         "ab_strides1": ab_strides1, | ||||
|         "ab_strides2": ab_strides2, | ||||
|         "c_strides1": c_strides1, | ||||
|         "c_strides2": c_strides2, | ||||
|         # cuda graph params | ||||
|         "cutlass_graph": cutlass_graph, | ||||
|         "triton_graph": triton_graph, | ||||
| @ -289,6 +318,10 @@ def bench_run( | ||||
|         w2_q, | ||||
|         w1_scale, | ||||
|         w2_scale, | ||||
|         ab_strides1, | ||||
|         ab_strides2, | ||||
|         c_strides1, | ||||
|         c_strides2, | ||||
|         topk_weights, | ||||
|         topk_ids, | ||||
|         per_act_token, | ||||
| @ -297,7 +330,7 @@ def bench_run( | ||||
|  | ||||
|     results.append( | ||||
|         benchmark.Timer( | ||||
|             stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501 | ||||
|             stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)",  # noqa: E501 | ||||
|             globals=globals, | ||||
|             label=label, | ||||
|             sub_label=sub_label, | ||||
|  | ||||
| @ -253,28 +253,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: | ||||
|     else: | ||||
|         assert bt.a.dtype == torch.int8 | ||||
|         assert bt.wtype == scalar_types.uint4b8 | ||||
|  | ||||
|         if bt.w_ch_s is not None: | ||||
|             s_ch = bt.w_ch_s.to(torch.float32) | ||||
|         else: | ||||
|             s_ch = torch.ones(bt.w_ref.shape[1], dtype=torch.float32, device=device) | ||||
|  | ||||
|         if bt.w_tok_s is not None: | ||||
|             s_tok = bt.w_tok_s.to(torch.float32) | ||||
|         else: | ||||
|             s_tok = torch.ones(bt.a.shape[0], dtype=torch.float32, device=device) | ||||
|  | ||||
|         fn = lambda: ops.marlin_qqq_gemm( | ||||
|             a=bt.a, | ||||
|             b_q_weight=w_q, | ||||
|             s_group=w_s, | ||||
|             s_tok=s_tok, | ||||
|             s_ch=s_ch, | ||||
|             workspace=workspace.scratch, | ||||
|             size_m=bt.a.shape[0], | ||||
|             size_n=bt.w_ref.shape[1], | ||||
|             size_k=bt.w_ref.shape[0], | ||||
|         ) | ||||
|         raise NotImplementedError("QQQ is not supported anymore") | ||||
|  | ||||
|     return fn | ||||
|  | ||||
| @ -305,6 +284,25 @@ def machete_create_bench_fn( | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def cutlass_w4a8_create_bench_fn( | ||||
|     bt: BenchmarkTensors, out_type=torch.dtype, schedule=None | ||||
| ) -> Callable: | ||||
|     w_q = bt.w_q.t().contiguous().t()  # make col major | ||||
|     w_q = ops.cutlass_encode_and_reorder_int4b(w_q) | ||||
|     # expects fp8 scales | ||||
|     w_s = ops.cutlass_pack_scale_fp8(bt.w_g_s.to(torch.float8_e4m3fn)) | ||||
|  | ||||
|     return lambda: ops.cutlass_w4a8_mm( | ||||
|         a=bt.a, | ||||
|         b_q=w_q, | ||||
|         b_group_scales=w_s, | ||||
|         b_group_size=bt.group_size, | ||||
|         b_channel_scales=bt.w_ch_s, | ||||
|         a_token_scales=bt.w_tok_s, | ||||
|         maybe_schedule=schedule, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| # impl | ||||
|  | ||||
| # bench | ||||
| @ -406,6 +404,20 @@ def bench( | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     # cutlass w4a8 | ||||
|     if types.act_type == torch.float8_e4m3fn and group_size == 128: | ||||
|         timers.append( | ||||
|             bench_fns( | ||||
|                 label, | ||||
|                 sub_label, | ||||
|                 f"cutlass w4a8 ({name_type_string})", | ||||
|                 [ | ||||
|                     cutlass_w4a8_create_bench_fn(bt, out_type=types.output_type) | ||||
|                     for bt in benchmark_tensors | ||||
|                 ], | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     if sweep_schedules: | ||||
|         global _SWEEP_SCHEDULES_RESULTS | ||||
|  | ||||
|  | ||||
| @ -430,7 +430,6 @@ class BenchmarkWorker: | ||||
|                 hidden_size, | ||||
|                 topk, | ||||
|                 dtype_str, | ||||
|                 is_marlin=False, | ||||
|             ) | ||||
|         else: | ||||
|             config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))] | ||||
|  | ||||
							
								
								
									
										77
									
								
								benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,77 @@ | ||||
| #!/usr/bin/env python3 | ||||
| # SPDX-License-Identifier: Apache-2.0 | ||||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||||
| import time | ||||
|  | ||||
| import torch | ||||
|  | ||||
| from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( | ||||
|     silu_mul_fp8_quant_deep_gemm, | ||||
| ) | ||||
| from vllm.platforms import current_platform | ||||
|  | ||||
|  | ||||
| def benchmark(E, T, H, G=128, runs=50): | ||||
|     current_platform.seed_everything(42) | ||||
|     y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda") | ||||
|     tokens_per_expert = torch.randint( | ||||
|         T // 2, T, size=(E,), dtype=torch.int32, device="cuda" | ||||
|     ) | ||||
|  | ||||
|     # Warmup | ||||
|     for _ in range(10): | ||||
|         silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G) | ||||
|         torch.cuda.synchronize() | ||||
|  | ||||
|     # Benchmark | ||||
|     torch.cuda.synchronize() | ||||
|     start = time.perf_counter() | ||||
|     for _ in range(runs): | ||||
|         silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G) | ||||
|     torch.cuda.synchronize() | ||||
|  | ||||
|     avg_time = (time.perf_counter() - start) / runs * 1000 | ||||
|  | ||||
|     # Calculate actual work done (only count valid tokens) | ||||
|     actual_tokens = tokens_per_expert.sum().item() | ||||
|     actual_elements = actual_tokens * H | ||||
|  | ||||
|     # GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops | ||||
|     ops_per_element = 8 | ||||
|     total_ops = actual_elements * ops_per_element | ||||
|     gflops = total_ops / (avg_time / 1000) / 1e9 | ||||
|  | ||||
|     # Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes) | ||||
|     input_bytes = actual_tokens * 2 * H * 2  # 2*H bfloat16 inputs | ||||
|     output_bytes = actual_tokens * H * 1  # H fp8 outputs | ||||
|     scale_bytes = actual_tokens * (H // G) * 4  # scales in float32 | ||||
|     total_bytes = input_bytes + output_bytes + scale_bytes | ||||
|     memory_bw = total_bytes / (avg_time / 1000) / 1e9 | ||||
|  | ||||
|     return avg_time, gflops, memory_bw | ||||
|  | ||||
|  | ||||
| configs = [ | ||||
|     (8, 32, 1024), | ||||
|     (16, 64, 2048), | ||||
|     (32, 128, 4096), | ||||
|     # DeepSeekV3 Configs | ||||
|     (256, 16, 7168), | ||||
|     (256, 32, 7168), | ||||
|     (256, 64, 7168), | ||||
|     (256, 128, 7168), | ||||
|     (256, 256, 7168), | ||||
|     (256, 512, 7168), | ||||
|     (256, 1024, 7168), | ||||
| ] | ||||
|  | ||||
| print(f"GPU: {torch.cuda.get_device_name()}") | ||||
| print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}") | ||||
| print("-" * 50) | ||||
|  | ||||
| for E, T, H in configs: | ||||
|     try: | ||||
|         time_ms, gflops, gbps = benchmark(E, T, H) | ||||
|         print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}") | ||||
|     except Exception: | ||||
|         print(f"E={E:3d},T={T:4d},H={H:4d} FAILED") | ||||
| @ -3,16 +3,17 @@ | ||||
|  | ||||
| import csv | ||||
| import os | ||||
| import random | ||||
| from datetime import datetime | ||||
| from typing import Optional | ||||
|  | ||||
| import flashinfer | ||||
| import torch | ||||
|  | ||||
| FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 | ||||
| from vllm.utils import round_up | ||||
|  | ||||
| # KV Cache Layout for TRT-LLM | ||||
| # kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) | ||||
| FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 | ||||
| FP8_DTYPE = torch.float8_e4m3fn | ||||
| FP4_DTYPE = torch.uint8 | ||||
|  | ||||
|  | ||||
| def to_float8(x, dtype=torch.float8_e4m3fn): | ||||
| @ -26,65 +27,106 @@ def to_float8(x, dtype=torch.float8_e4m3fn): | ||||
|  | ||||
| @torch.no_grad() | ||||
| def benchmark_decode( | ||||
|     num_seqs, | ||||
|     max_seq_len, | ||||
|     page_size=16, | ||||
|     dtype=torch.bfloat16, | ||||
|     kv_layout="HND", | ||||
|     num_kv_heads=8, | ||||
|     kv_cache_dtype="auto", | ||||
|     head_dim=128, | ||||
|     warmup=10, | ||||
|     trials=20, | ||||
|     dtype: torch.dtype, | ||||
|     quant_dtypes: tuple[ | ||||
|         Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype] | ||||
|     ], | ||||
|     batch_size: int, | ||||
|     max_seq_len: int, | ||||
|     num_heads: tuple[int, int] = (64, 8), | ||||
|     head_size: int = 128, | ||||
|     kv_layout: str = "HND", | ||||
|     block_size: int = 16, | ||||
|     warmup: int = 10, | ||||
|     trials: int = 20, | ||||
| ): | ||||
|     torch.set_default_device("cuda") | ||||
|     device = "cuda" | ||||
|     torch.manual_seed(0) | ||||
|  | ||||
|     HEAD_GRP_SIZE = 8 | ||||
|     MAX_SEQ_LEN = max_seq_len | ||||
|     q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes | ||||
|     q_quant_dtype = q_quant_dtype or dtype | ||||
|     kv_quant_dtype = kv_quant_dtype or dtype | ||||
|     o_quant_dtype = o_quant_dtype or dtype | ||||
|  | ||||
|     num_qo_heads, num_kv_heads = num_heads | ||||
|     assert num_qo_heads % num_kv_heads == 0 | ||||
|  | ||||
|     sm_scale = float(1.0 / (head_size**0.5)) | ||||
|  | ||||
|     # large number to reduce kv_cache reuse | ||||
|     NUM_BLOCKS = int(256000 / page_size) | ||||
|     NUM_BLOCKS = int(256000 / block_size) | ||||
|  | ||||
|     workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device) | ||||
|     kv_cache_shape = None | ||||
|     if kv_layout == "NHD": | ||||
|         kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) | ||||
|     elif kv_layout == "HND": | ||||
|         kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) | ||||
|     else: | ||||
|         raise ValueError(f"Invalid kv_layout: {kv_layout}") | ||||
|  | ||||
|     # For decode, batch_size is num_decode_token | ||||
|     num_qo_heads = num_kv_heads * HEAD_GRP_SIZE | ||||
|     sm_scale = float(1.0 / (head_dim**0.5)) | ||||
|     q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype) | ||||
|     kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] | ||||
|     # Always using 1.0 scale to reflect the real perf in benchmarking | ||||
|     q_scale = 1.0 | ||||
|     ref_query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype) | ||||
|     if q_quant_dtype == FP8_DTYPE: | ||||
|         query, _ = to_float8(ref_query) | ||||
|     else: | ||||
|         query = ref_query | ||||
|  | ||||
|     max_kv_len = max(kv_lens) | ||||
|     kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device) | ||||
|     max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size | ||||
|     kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32) | ||||
|     kv_lens[-1] = max_seq_len | ||||
|  | ||||
|     block_tables = torch.randint( | ||||
|         0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 | ||||
|     ) | ||||
|     seq_lens = kv_lens | ||||
|     max_seq_len = torch.max(seq_lens).item() | ||||
|  | ||||
|     kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim) | ||||
|     kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype) | ||||
|     # Always using 1.0 scale to reflect the real perf in benchmarking | ||||
|     k_scale = v_scale = 1.0 | ||||
|     ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype) | ||||
|     if kv_quant_dtype == FP8_DTYPE: | ||||
|         kv_cache, _ = to_float8(ref_kv_cache) | ||||
|     else: | ||||
|         kv_cache = ref_kv_cache | ||||
|  | ||||
|     if kv_cache_dtype.startswith("fp8"): | ||||
|         kv_cache, _ = to_float8(kv_cache) | ||||
|     max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size | ||||
|     block_tables = torch.randint( | ||||
|         0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32 | ||||
|     ) | ||||
|     kv_indptr = [0] | ||||
|     kv_indices = [] | ||||
|     kv_last_page_lens = [] | ||||
|     for i in range(batch_size): | ||||
|         seq_len = seq_lens[i] | ||||
|         assert seq_len > 0 | ||||
|         num_blocks = (seq_len + block_size - 1) // block_size | ||||
|         kv_indices.extend(block_tables[i, :num_blocks]) | ||||
|         kv_indptr.append(kv_indptr[-1] + num_blocks) | ||||
|         kv_last_page_len = seq_len % block_size | ||||
|         if kv_last_page_len == 0: | ||||
|             kv_last_page_len = block_size | ||||
|         kv_last_page_lens.append(kv_last_page_len) | ||||
|  | ||||
|     output_trtllm = torch.empty(q.shape, dtype=dtype) | ||||
|     kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) | ||||
|     kv_indices = torch.tensor(kv_indices, dtype=torch.int32) | ||||
|     kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) | ||||
|     workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8) | ||||
|  | ||||
|     # Benchmark TRT decode | ||||
|     def trt_decode(): | ||||
|         return flashinfer.decode.trtllm_batch_decode_with_kv_cache( | ||||
|             q, | ||||
|             kv_cache, | ||||
|             workspace_buffer, | ||||
|             block_tables, | ||||
|             kv_lens_tensor, | ||||
|             max_kv_len, | ||||
|             bmm1_scale=k_scale * sm_scale, | ||||
|             bmm2_scale=v_scale, | ||||
|             out=output_trtllm, | ||||
|         ) | ||||
|     wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( | ||||
|         workspace_buffer, | ||||
|         kv_layout, | ||||
|         use_tensor_cores=True, | ||||
|     ) | ||||
|     wrapper.plan( | ||||
|         kv_indptr, | ||||
|         kv_indices, | ||||
|         kv_last_page_lens, | ||||
|         num_qo_heads, | ||||
|         num_kv_heads, | ||||
|         head_size, | ||||
|         block_size, | ||||
|         "NONE", | ||||
|         sm_scale=sm_scale, | ||||
|         q_data_type=dtype, | ||||
|         kv_data_type=dtype, | ||||
|     ) | ||||
|  | ||||
|     def time_fn(fn, warmup=10, trials=20): | ||||
|         torch.cuda.synchronize() | ||||
| @ -101,74 +143,72 @@ def benchmark_decode( | ||||
|             times.append(start.elapsed_time(end))  # ms | ||||
|         return sum(times) / len(times), torch.std(torch.tensor(times)) | ||||
|  | ||||
|     # TRT Decode | ||||
|     trt_mean, trt_std = time_fn(trt_decode) | ||||
|  | ||||
|     kv_indptr = [0] | ||||
|     kv_indices = [] | ||||
|     kv_last_page_lens = [] | ||||
|     for i in range(num_seqs): | ||||
|         seq_len = kv_lens[i] | ||||
|         assert seq_len > 0 | ||||
|         num_blocks = (seq_len + page_size - 1) // page_size | ||||
|         kv_indices.extend(block_tables[i, :num_blocks]) | ||||
|         kv_indptr.append(kv_indptr[-1] + num_blocks) | ||||
|         kv_last_page_len = seq_len % page_size | ||||
|         if kv_last_page_len == 0: | ||||
|             kv_last_page_len = page_size | ||||
|         kv_last_page_lens.append(kv_last_page_len) | ||||
|  | ||||
|     kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) | ||||
|     kv_indices = torch.tensor(kv_indices, dtype=torch.int32) | ||||
|     kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) | ||||
|  | ||||
|     output_baseline = torch.empty(q.shape, dtype=dtype) | ||||
|  | ||||
|     wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( | ||||
|         workspace_buffer, | ||||
|         kv_layout, | ||||
|         use_tensor_cores=((num_qo_heads // num_kv_heads) > 4), | ||||
|     ) | ||||
|  | ||||
|     wrapper.plan( | ||||
|         kv_indptr, | ||||
|         kv_indices, | ||||
|         kv_last_page_lens, | ||||
|         num_qo_heads, | ||||
|         num_kv_heads, | ||||
|         head_dim, | ||||
|         page_size, | ||||
|         "NONE", | ||||
|         q_data_type=dtype, | ||||
|         kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype, | ||||
|     ) | ||||
|     o_scale = 1.0 | ||||
|     o_sf_scale = None | ||||
|     output_baseline = torch.empty(ref_query.shape, dtype=dtype) | ||||
|     if o_quant_dtype == FP4_DTYPE: | ||||
|         o_sf_scale = 500.0 | ||||
|         output_trtllm = flashinfer.utils.FP4Tensor( | ||||
|             torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8), | ||||
|             torch.empty( | ||||
|                 ( | ||||
|                     round_up(query.shape[0], 128), | ||||
|                     round_up(query.shape[1] * query.shape[2] // 16, 4), | ||||
|                 ), | ||||
|                 dtype=torch.float8_e4m3fn, | ||||
|             ), | ||||
|         ) | ||||
|     else: | ||||
|         output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) | ||||
|  | ||||
|     def baseline_decode(): | ||||
|         return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline) | ||||
|         return wrapper.run( | ||||
|             ref_query, | ||||
|             ref_kv_cache, | ||||
|             k_scale=k_scale, | ||||
|             v_scale=v_scale, | ||||
|             out=output_baseline, | ||||
|         ) | ||||
|  | ||||
|     def trtllm_decode(): | ||||
|         return flashinfer.decode.trtllm_batch_decode_with_kv_cache( | ||||
|             query=query, | ||||
|             kv_cache=kv_cache, | ||||
|             workspace_buffer=workspace_buffer, | ||||
|             block_tables=block_tables, | ||||
|             seq_lens=seq_lens, | ||||
|             max_seq_len=max_seq_len, | ||||
|             bmm1_scale=q_scale * k_scale * sm_scale, | ||||
|             bmm2_scale=v_scale / o_scale, | ||||
|             o_sf_scale=o_sf_scale, | ||||
|             out=output_trtllm, | ||||
|         ) | ||||
|  | ||||
|     baseline_mean, baseline_std = time_fn(baseline_decode) | ||||
|     trtllm_mean, trtllm_std = time_fn(trtllm_decode) | ||||
|  | ||||
|     # Calculate percentage speedup (positive means TRT is faster) | ||||
|     speedup_percent = (baseline_mean - trt_mean) / baseline_mean | ||||
|     speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean | ||||
|  | ||||
|     print( | ||||
|         f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}" | ||||
|         f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}" | ||||
|         f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}" | ||||
|     ) | ||||
|  | ||||
|     # Return results for CSV writing | ||||
|     return { | ||||
|         "num_seqs": num_seqs, | ||||
|         "trt_mean": trt_mean, | ||||
|         "trt_std": trt_std.item(), | ||||
|         "batch_size": batch_size, | ||||
|         "trtllm_mean": trtllm_mean, | ||||
|         "trtllm_std": trtllm_std.item(), | ||||
|         "baseline_mean": baseline_mean, | ||||
|         "baseline_std": baseline_std.item(), | ||||
|         "speedup_percent": speedup_percent, | ||||
|         "q_dtype": str(dtype), | ||||
|         "kv_cache_dtype": kv_cache_dtype, | ||||
|         "page_size": page_size, | ||||
|         "q_dtype": str(q_quant_dtype), | ||||
|         "kv_cache_dtype": str(kv_quant_dtype), | ||||
|         "output_dtype": str(o_quant_dtype), | ||||
|         "block_size": block_size, | ||||
|         "num_kv_heads": num_kv_heads, | ||||
|         "head_dim": head_dim, | ||||
|         "head_size": head_size, | ||||
|         "max_seq_len": max_seq_len, | ||||
|     } | ||||
|  | ||||
| @ -180,17 +220,18 @@ def write_results_to_csv(results, filename=None): | ||||
|         filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv" | ||||
|  | ||||
|     fieldnames = [ | ||||
|         "num_seqs", | ||||
|         "trt_mean", | ||||
|         "trt_std", | ||||
|         "batch_size", | ||||
|         "trtllm_mean", | ||||
|         "trtllm_std", | ||||
|         "baseline_mean", | ||||
|         "baseline_std", | ||||
|         "speedup_percent", | ||||
|         "q_dtype", | ||||
|         "kv_cache_dtype", | ||||
|         "page_size", | ||||
|         "output_dtype", | ||||
|         "block_size", | ||||
|         "num_kv_heads", | ||||
|         "head_dim", | ||||
|         "head_size", | ||||
|         "max_seq_len", | ||||
|     ] | ||||
|  | ||||
| @ -209,45 +250,43 @@ def write_results_to_csv(results, filename=None): | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     num_seqs = [1, 4, 8, 16, 32, 64, 128, 256] | ||||
|     batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256] | ||||
|     max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] | ||||
|     all_results = [] | ||||
|  | ||||
|     print( | ||||
|         "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, " | ||||
|         "output_dtype: bfloat16" | ||||
|     ) | ||||
|     print( | ||||
|         "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" | ||||
|         "baseline_std\tspeedup_percent" | ||||
|     ) | ||||
|     for max_seq_len in max_seq_lens: | ||||
|         for bs in num_seqs: | ||||
|             result = benchmark_decode( | ||||
|                 bs, | ||||
|                 max_seq_len, | ||||
|                 dtype=torch.bfloat16, | ||||
|                 kv_cache_dtype="auto", | ||||
|             ) | ||||
|             all_results.append(result) | ||||
|     dtype = torch.bfloat16 | ||||
|     quant_dtypes = [ | ||||
|         # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) | ||||
|         (None, None, None), | ||||
|         (None, FP8_DTYPE, None), | ||||
|         (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), | ||||
|         (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), | ||||
|     ] | ||||
|  | ||||
|     print( | ||||
|         "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, " | ||||
|         "output_dtype: bfloat16" | ||||
|     ) | ||||
|     print( | ||||
|         "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" | ||||
|         "baseline_std\tspeedup_percent" | ||||
|     ) | ||||
|     for max_seq_len in max_seq_lens: | ||||
|         for bs in num_seqs: | ||||
|             result = benchmark_decode( | ||||
|                 bs, | ||||
|                 max_seq_len, | ||||
|                 dtype=torch.bfloat16, | ||||
|                 kv_cache_dtype="fp8", | ||||
|             ) | ||||
|             all_results.append(result) | ||||
|     for quant_dtype in quant_dtypes: | ||||
|         q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype | ||||
|         q_quant_dtype = q_quant_dtype or dtype | ||||
|         kv_quant_dtype = kv_quant_dtype or dtype | ||||
|         o_quant_dtype = o_quant_dtype or dtype | ||||
|  | ||||
|         print( | ||||
|             f"Running benchmark for q_dtype = {q_quant_dtype}, " | ||||
|             f"kv_cache_dtype: {kv_quant_dtype}, " | ||||
|             f"output_dtype: {o_quant_dtype}" | ||||
|         ) | ||||
|         print( | ||||
|             "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t" | ||||
|             "baseline_std\tspeedup_percent" | ||||
|         ) | ||||
|         for max_seq_len in max_seq_lens: | ||||
|             for bs in batch_sizes: | ||||
|                 result = benchmark_decode( | ||||
|                     dtype=dtype, | ||||
|                     quant_dtypes=quant_dtype, | ||||
|                     batch_size=bs, | ||||
|                     max_seq_len=max_seq_len, | ||||
|                 ) | ||||
|                 all_results.append(result) | ||||
|  | ||||
|     # Write all results to CSV | ||||
|     write_results_to_csv(all_results) | ||||
|  | ||||
| @ -3,16 +3,17 @@ | ||||
|  | ||||
| import csv | ||||
| import os | ||||
| import random | ||||
| from datetime import datetime | ||||
| from typing import Optional | ||||
|  | ||||
| import flashinfer | ||||
| import torch | ||||
|  | ||||
| FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 | ||||
| from vllm.utils import round_up | ||||
|  | ||||
| # KV Cache Layout for TRT-LLM | ||||
| # kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) | ||||
| FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 | ||||
| FP8_DTYPE = torch.float8_e4m3fn | ||||
| FP4_DTYPE = torch.uint8 | ||||
|  | ||||
|  | ||||
| def to_float8(x, dtype=torch.float8_e4m3fn): | ||||
| @ -26,84 +27,100 @@ def to_float8(x, dtype=torch.float8_e4m3fn): | ||||
|  | ||||
| @torch.no_grad() | ||||
| def benchmark_prefill( | ||||
|     num_seqs, | ||||
|     max_seq_len, | ||||
|     page_size=16, | ||||
|     dtype=torch.bfloat16, | ||||
|     kv_layout="HND", | ||||
|     num_kv_heads=8, | ||||
|     kv_cache_dtype="auto", | ||||
|     head_dim=128, | ||||
|     warmup=10, | ||||
|     trials=20, | ||||
|     dtype: torch.dtype, | ||||
|     quant_dtypes: tuple[ | ||||
|         Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype] | ||||
|     ], | ||||
|     batch_size: int, | ||||
|     max_seq_len: int, | ||||
|     num_heads: tuple[int, int] = (64, 8), | ||||
|     head_size: int = 128, | ||||
|     kv_layout: str = "HND", | ||||
|     block_size: int = 16, | ||||
|     warmup: int = 10, | ||||
|     trials: int = 20, | ||||
| ): | ||||
|     torch.set_default_device("cuda") | ||||
|     torch.manual_seed(0) | ||||
|  | ||||
|     HEAD_GRP_SIZE = 8 | ||||
|     MAX_SEQ_LEN = max_seq_len | ||||
|     q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes | ||||
|     q_quant_dtype = q_quant_dtype or dtype | ||||
|     kv_quant_dtype = kv_quant_dtype or dtype | ||||
|     o_quant_dtype = o_quant_dtype or dtype | ||||
|  | ||||
|     max_q_len = max_kv_len = max_seq_len | ||||
|  | ||||
|     num_qo_heads, num_kv_heads = num_heads | ||||
|     assert num_qo_heads % num_kv_heads == 0 | ||||
|  | ||||
|     sm_scale = float(1.0 / (head_size**0.5)) | ||||
|  | ||||
|     # large number to reduce kv_cache reuse | ||||
|     NUM_BLOCKS = int(256000 / page_size) | ||||
|     NUM_BLOCKS = int(256000 / block_size) | ||||
|  | ||||
|     workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8) | ||||
|     kv_cache_shape = None | ||||
|     if kv_layout == "NHD": | ||||
|         kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) | ||||
|     elif kv_layout == "HND": | ||||
|         kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) | ||||
|     else: | ||||
|         raise ValueError(f"Invalid kv_layout: {kv_layout}") | ||||
|  | ||||
|     num_qo_heads = num_kv_heads * HEAD_GRP_SIZE | ||||
|     sm_scale = float(1.0 / (head_dim**0.5)) | ||||
|  | ||||
|     q_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] | ||||
|     q_lens[-1] = MAX_SEQ_LEN | ||||
|     max_q_len = max(q_lens) | ||||
|     q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32) | ||||
|     q_lens[-1] = max_q_len | ||||
|     q_indptr = torch.cat( | ||||
|         [ | ||||
|             torch.tensor([0], dtype=torch.int32), | ||||
|             torch.cumsum( | ||||
|                 torch.tensor(q_lens, dtype=torch.int32), dim=0, dtype=torch.int32 | ||||
|             ), | ||||
|             torch.cumsum(q_lens, dim=0, dtype=torch.int32), | ||||
|         ] | ||||
|     ) | ||||
|     q = torch.randn(sum(q_lens), num_qo_heads, head_dim, dtype=dtype) | ||||
|  | ||||
|     kv_lens = [random.randint(0, MAX_SEQ_LEN) for _ in range(num_seqs)] | ||||
|     kv_lens[-1] = MAX_SEQ_LEN | ||||
|  | ||||
|     seq_lens = [q_len + kv_len for q_len, kv_len in zip(q_lens, kv_lens)] | ||||
|     max_seq_len = max(seq_lens) | ||||
|     seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32) | ||||
|  | ||||
|     max_num_blocks_per_seq = (max_seq_len + page_size - 1) // page_size | ||||
|     block_tables = torch.randint( | ||||
|         0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 | ||||
|     # Always using 1.0 scale to reflect the real perf in benchmarking | ||||
|     q_scale = 1.0 | ||||
|     ref_query = torch.randn( | ||||
|         torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype | ||||
|     ) | ||||
|     if q_quant_dtype == FP8_DTYPE: | ||||
|         query, _ = to_float8(ref_query) | ||||
|     else: | ||||
|         query = ref_query | ||||
|  | ||||
|     kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim) | ||||
|     kv_cache = torch.randn(size=kv_cache_shape, dtype=dtype) | ||||
|     kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32) | ||||
|     kv_lens[-1] = max_kv_len | ||||
|  | ||||
|     seq_lens = kv_lens + q_lens | ||||
|     max_seq_len = torch.max(seq_lens).item() | ||||
|  | ||||
|     # Always using 1.0 scale to reflect the real perf in benchmarking | ||||
|     k_scale = v_scale = 1.0 | ||||
|     ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype) | ||||
|     if kv_quant_dtype == FP8_DTYPE: | ||||
|         kv_cache, _ = to_float8(ref_kv_cache) | ||||
|     else: | ||||
|         kv_cache = ref_kv_cache | ||||
|  | ||||
|     if kv_cache_dtype.startswith("fp8"): | ||||
|         kv_cache, _ = to_float8(kv_cache) | ||||
|  | ||||
|     output_trtllm = torch.empty(q.shape, dtype=dtype) | ||||
|  | ||||
|     max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size | ||||
|     block_tables = torch.randint( | ||||
|         0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32 | ||||
|     ) | ||||
|     kv_indptr = [0] | ||||
|     kv_indices = [] | ||||
|     kv_last_page_lens = [] | ||||
|     for i in range(num_seqs): | ||||
|     for i in range(batch_size): | ||||
|         seq_len = seq_lens[i] | ||||
|         assert seq_len > 0 | ||||
|         num_blocks = (seq_len + page_size - 1) // page_size | ||||
|         num_blocks = (seq_len + block_size - 1) // block_size | ||||
|         kv_indices.extend(block_tables[i, :num_blocks]) | ||||
|         kv_indptr.append(kv_indptr[-1] + num_blocks) | ||||
|         kv_last_page_len = seq_len % page_size | ||||
|         kv_last_page_len = seq_len % block_size | ||||
|         if kv_last_page_len == 0: | ||||
|             kv_last_page_len = page_size | ||||
|             kv_last_page_len = block_size | ||||
|         kv_last_page_lens.append(kv_last_page_len) | ||||
|  | ||||
|     kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) | ||||
|     kv_indices = torch.tensor(kv_indices, dtype=torch.int32) | ||||
|     kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) | ||||
|  | ||||
|     output_baseline = torch.empty(q.shape, dtype=dtype) | ||||
|     workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8) | ||||
|  | ||||
|     wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( | ||||
|         workspace_buffer, kv_layout | ||||
| @ -115,12 +132,12 @@ def benchmark_prefill( | ||||
|         kv_last_page_lens, | ||||
|         num_qo_heads, | ||||
|         num_kv_heads, | ||||
|         head_dim, | ||||
|         page_size, | ||||
|         head_size, | ||||
|         block_size, | ||||
|         causal=True, | ||||
|         sm_scale=sm_scale, | ||||
|         q_data_type=dtype, | ||||
|         kv_data_type=kv_cache.dtype, | ||||
|         kv_data_type=dtype, | ||||
|     ) | ||||
|  | ||||
|     def time_fn(fn, warmup=10, trials=20): | ||||
| @ -138,52 +155,76 @@ def benchmark_prefill( | ||||
|             times.append(start.elapsed_time(end))  # ms | ||||
|         return sum(times) / len(times), torch.std(torch.tensor(times)) | ||||
|  | ||||
|     o_scale = 1.0 | ||||
|     o_sf_scale = None | ||||
|     output_baseline = torch.empty(ref_query.shape, dtype=dtype) | ||||
|     if o_quant_dtype == FP4_DTYPE: | ||||
|         o_sf_scale = 500.0 | ||||
|         output_trtllm = flashinfer.utils.FP4Tensor( | ||||
|             torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8), | ||||
|             torch.empty( | ||||
|                 ( | ||||
|                     round_up(query.shape[0], 128), | ||||
|                     round_up(query.shape[1] * query.shape[2] // 16, 4), | ||||
|                 ), | ||||
|                 dtype=torch.float8_e4m3fn, | ||||
|             ), | ||||
|         ) | ||||
|     else: | ||||
|         output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) | ||||
|  | ||||
|     def baseline_prefill(): | ||||
|         return wrapper.run( | ||||
|             q, kv_cache, k_scale=k_scale, v_scale=v_scale, out=output_baseline | ||||
|             ref_query, | ||||
|             ref_kv_cache, | ||||
|             k_scale=k_scale, | ||||
|             v_scale=v_scale, | ||||
|             out=output_baseline, | ||||
|         ) | ||||
|  | ||||
|     def trt_prefill(): | ||||
|     def trtllm_prefill(): | ||||
|         return flashinfer.prefill.trtllm_batch_context_with_kv_cache( | ||||
|             query=q, | ||||
|             query=query, | ||||
|             kv_cache=kv_cache, | ||||
|             workspace_buffer=workspace_buffer, | ||||
|             block_tables=block_tables, | ||||
|             seq_lens=seq_lens_tensor, | ||||
|             seq_lens=seq_lens, | ||||
|             max_q_len=max_q_len, | ||||
|             max_kv_len=max_seq_len, | ||||
|             bmm1_scale=k_scale * sm_scale, | ||||
|             bmm2_scale=v_scale, | ||||
|             batch_size=num_seqs, | ||||
|             bmm1_scale=q_scale * k_scale * sm_scale, | ||||
|             bmm2_scale=v_scale / o_scale, | ||||
|             batch_size=batch_size, | ||||
|             cum_seq_lens_q=q_indptr, | ||||
|             cum_seq_lens_kv=kv_indptr, | ||||
|             o_sf_scale=o_sf_scale, | ||||
|             out=output_trtllm, | ||||
|         ) | ||||
|  | ||||
|     trt_mean, trt_std = time_fn(trt_prefill) | ||||
|     baseline_mean, baseline_std = time_fn(baseline_prefill) | ||||
|     trtllm_mean, trtllm_std = time_fn(trtllm_prefill) | ||||
|  | ||||
|     # Calculate percentage speedup (positive means TRT is faster) | ||||
|     speedup_percent = (baseline_mean - trt_mean) / baseline_mean | ||||
|     speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean | ||||
|  | ||||
|     print( | ||||
|         f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.5f}\t{trt_std.item():.5f}" | ||||
|         f"\t{baseline_mean:.5f}\t{baseline_std.item():.5f}\t{speedup_percent:.5f}" | ||||
|         f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:8.3f}\t{trtllm_std.item():8.3f}" | ||||
|         f"\t{baseline_mean:8.3f}\t{baseline_std.item():8.3f}\t{speedup_percent:8.3f}" | ||||
|     ) | ||||
|  | ||||
|     # Return results for CSV writing | ||||
|     return { | ||||
|         "num_seqs": num_seqs, | ||||
|         "trt_mean": trt_mean, | ||||
|         "trt_std": trt_std.item(), | ||||
|         "batch_size": batch_size, | ||||
|         "trtllm_mean": trtllm_mean, | ||||
|         "trtllm_std": trtllm_std.item(), | ||||
|         "baseline_mean": baseline_mean, | ||||
|         "baseline_std": baseline_std.item(), | ||||
|         "speedup_percent": speedup_percent, | ||||
|         "q_dtype": str(dtype), | ||||
|         "kv_cache_dtype": kv_cache_dtype, | ||||
|         "page_size": page_size, | ||||
|         "q_dtype": str(q_quant_dtype), | ||||
|         "kv_cache_dtype": str(kv_quant_dtype), | ||||
|         "output_dtype": str(o_quant_dtype), | ||||
|         "block_size": block_size, | ||||
|         "num_kv_heads": num_kv_heads, | ||||
|         "head_dim": head_dim, | ||||
|         "head_size": head_size, | ||||
|         "max_seq_len": max_seq_len, | ||||
|     } | ||||
|  | ||||
| @ -195,17 +236,18 @@ def write_results_to_csv(results, filename=None): | ||||
|         filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv" | ||||
|  | ||||
|     fieldnames = [ | ||||
|         "num_seqs", | ||||
|         "trt_mean", | ||||
|         "trt_std", | ||||
|         "batch_size", | ||||
|         "trtllm_mean", | ||||
|         "trtllm_std", | ||||
|         "baseline_mean", | ||||
|         "baseline_std", | ||||
|         "speedup_percent", | ||||
|         "q_dtype", | ||||
|         "kv_cache_dtype", | ||||
|         "page_size", | ||||
|         "output_dtype", | ||||
|         "block_size", | ||||
|         "num_kv_heads", | ||||
|         "head_dim", | ||||
|         "head_size", | ||||
|         "max_seq_len", | ||||
|     ] | ||||
|  | ||||
| @ -224,27 +266,42 @@ def write_results_to_csv(results, filename=None): | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     num_seqs = [1, 4, 8, 16, 32, 64, 128, 256] | ||||
|     batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256] | ||||
|     max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] | ||||
|     all_results = [] | ||||
|  | ||||
|     print( | ||||
|         "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, " | ||||
|         "output_dtype: bfloat16" | ||||
|     ) | ||||
|     print( | ||||
|         "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t" | ||||
|         "baseline_std\tspeedup_percent" | ||||
|     ) | ||||
|     for max_seq_len in max_seq_lens: | ||||
|         for bs in num_seqs: | ||||
|             result = benchmark_prefill( | ||||
|                 bs, | ||||
|                 max_seq_len, | ||||
|                 dtype=torch.bfloat16, | ||||
|                 kv_cache_dtype="auto", | ||||
|             ) | ||||
|             all_results.append(result) | ||||
|     dtype = torch.bfloat16 | ||||
|     quant_dtypes = [ | ||||
|         # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) | ||||
|         (None, None, None), | ||||
|         (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), | ||||
|         (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), | ||||
|     ] | ||||
|  | ||||
|     for quant_dtype in quant_dtypes: | ||||
|         q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype | ||||
|         q_quant_dtype = q_quant_dtype or dtype | ||||
|         kv_quant_dtype = kv_quant_dtype or dtype | ||||
|         o_quant_dtype = o_quant_dtype or dtype | ||||
|  | ||||
|         print( | ||||
|             f"Running benchmark for q_dtype = {q_quant_dtype}, " | ||||
|             f"kv_cache_dtype: {kv_quant_dtype}, " | ||||
|             f"output_dtype: {o_quant_dtype}" | ||||
|         ) | ||||
|         print( | ||||
|             "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t" | ||||
|             "baseline_std\tspeedup_percent" | ||||
|         ) | ||||
|         for max_seq_len in max_seq_lens: | ||||
|             for bs in batch_sizes: | ||||
|                 result = benchmark_prefill( | ||||
|                     dtype=dtype, | ||||
|                     quant_dtypes=quant_dtype, | ||||
|                     batch_size=bs, | ||||
|                     max_seq_len=max_seq_len, | ||||
|                 ) | ||||
|                 all_results.append(result) | ||||
|  | ||||
|     # Write all results to CSV | ||||
|     write_results_to_csv(all_results) | ||||
|  | ||||
| @ -11,8 +11,8 @@ from datetime import datetime | ||||
| from typing import Any | ||||
|  | ||||
| import torch | ||||
| import tqdm | ||||
| import triton | ||||
| from tqdm import tqdm | ||||
|  | ||||
| from vllm.model_executor.layers.quantization.utils.fp8_utils import ( | ||||
|     _w8a8_block_fp8_matmul, | ||||
|  | ||||
| @ -95,4 +95,10 @@ WEIGHT_SHAPES = { | ||||
|         ([2048, 2816], 1), | ||||
|         ([1408, 2048], 0), | ||||
|     ], | ||||
|     "CohereLabs/c4ai-command-a-03-2025": [ | ||||
|         ([12288, 14336], 1), | ||||
|         ([12288, 12288], 0), | ||||
|         ([12288, 73728], 1), | ||||
|         ([36864, 12288], 0), | ||||
|     ], | ||||
| } | ||||
|  | ||||
| @ -5,11 +5,13 @@ The requirements (pip) for `benchmark_serving_multi_turn.py` can be found in `re | ||||
| First start serving your model | ||||
|  | ||||
| ```bash | ||||
| export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/ | ||||
| export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/ | ||||
|  | ||||
| vllm serve $MODEL_NAME --disable-log-requests | ||||
| vllm serve $MODEL_PATH --served-model-name Llama --disable-log-requests | ||||
| ``` | ||||
|  | ||||
| The variable `MODEL_PATH` should be a path to the model files (e.g. downloaded from huggingface). | ||||
|  | ||||
| ## Synthetic Multi-Turn Conversations | ||||
|  | ||||
| Download the following text file (used for generation of synthetic conversations) | ||||
| @ -26,10 +28,10 @@ But you may use other text files if you prefer (using this specific file is not | ||||
| Then run the benchmarking script | ||||
|  | ||||
| ```bash | ||||
| export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/ | ||||
| export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/ | ||||
|  | ||||
| python benchmark_serving_multi_turn.py --model $MODEL_NAME --input-file generate_multi_turn.json \ | ||||
| --num-clients 2 --max-active-conversations 6 | ||||
| python benchmark_serving_multi_turn.py --model $MODEL_PATH --served-model-name Llama \ | ||||
| --input-file generate_multi_turn.json --num-clients 2 --max-active-conversations 6 | ||||
| ``` | ||||
|  | ||||
| You can edit the file `generate_multi_turn.json` to change the conversation parameters (number of turns, etc.). | ||||
|  | ||||
| @ -825,9 +825,11 @@ def get_client_config( | ||||
|  | ||||
|     # Arguments for API requests | ||||
|     chat_url = f"{args.url}/v1/chat/completions" | ||||
|     model_name = args.served_model_name if args.served_model_name else args.model | ||||
|  | ||||
|     req_args = RequestArgs( | ||||
|         chat_url=chat_url, | ||||
|         model=args.model, | ||||
|         model=model_name, | ||||
|         stream=not args.no_stream, | ||||
|         limit_min_tokens=args.limit_min_tokens, | ||||
|         limit_max_tokens=args.limit_max_tokens, | ||||
| @ -1247,9 +1249,19 @@ async def main() -> None: | ||||
|         default=0, | ||||
|         help="Seed for random number generators (default: 0)", | ||||
|     ) | ||||
|  | ||||
|     parser.add_argument( | ||||
|         "-m", "--model", type=str, required=True, help="Path of the LLM model" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--served-model-name", | ||||
|         type=str, | ||||
|         default=None, | ||||
|         help="The model name used in the API. " | ||||
|         "If not specified, the model name will be the " | ||||
|         "same as the ``--model`` argument. ", | ||||
|     ) | ||||
|  | ||||
|     parser.add_argument( | ||||
|         "-u", | ||||
|         "--url", | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| include(FetchContent) | ||||
|  | ||||
| set(CMAKE_CXX_STANDARD_REQUIRED ON) | ||||
| set(CMAKE_CXX_STANDARD 17) | ||||
| set(CMAKE_CXX_EXTENSIONS ON) | ||||
| set(CMAKE_EXPORT_COMPILE_COMMANDS ON) | ||||
|  | ||||
| @ -182,17 +183,17 @@ endif() | ||||
| # | ||||
| # Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms) | ||||
| # Flag to enable ACL kernels for AARCH64 platforms | ||||
| if ( VLLM_BUILD_ACL STREQUAL "ON") | ||||
| if (VLLM_BUILD_ACL STREQUAL "ON") | ||||
|     set(USE_ACL ON) | ||||
| else() | ||||
|     set(USE_ACL OFF) | ||||
| endif() | ||||
|  | ||||
| if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND) | ||||
| if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) | ||||
|     FetchContent_Declare( | ||||
|         oneDNN | ||||
|         GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git | ||||
|         GIT_TAG  v3.8.1 | ||||
|         GIT_TAG v3.9 | ||||
|         GIT_PROGRESS TRUE | ||||
|         GIT_SHALLOW TRUE | ||||
|     ) | ||||
| @ -204,7 +205,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND) | ||||
|         endif() | ||||
|         set(ONEDNN_AARCH64_USE_ACL "ON") | ||||
|         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/") | ||||
|         endif() | ||||
|     endif() | ||||
|  | ||||
|     set(ONEDNN_LIBRARY_TYPE "STATIC") | ||||
|     set(ONEDNN_BUILD_DOC "OFF") | ||||
| @ -217,38 +218,23 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND) | ||||
|     set(ONEDNN_ENABLE_ITT_TASKS "OFF") | ||||
|     set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") | ||||
|     set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") | ||||
|     set(ONEDNN_VERBOSE "OFF") | ||||
|     set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) | ||||
|  | ||||
|     FetchContent_MakeAvailable(oneDNN) | ||||
|      | ||||
|     list(APPEND LIBS dnnl) | ||||
| elseif(POWER10_FOUND) | ||||
|     FetchContent_Declare( | ||||
|         oneDNN | ||||
|         GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git | ||||
|         GIT_TAG v3.7.2 | ||||
|         GIT_PROGRESS TRUE | ||||
|         GIT_SHALLOW TRUE | ||||
|     add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp") | ||||
|     target_include_directories( | ||||
|         dnnl_ext | ||||
|         PUBLIC ${oneDNN_SOURCE_DIR}/include | ||||
|         PUBLIC ${oneDNN_BINARY_DIR}/include | ||||
|         PRIVATE ${oneDNN_SOURCE_DIR}/src | ||||
|     ) | ||||
|  | ||||
|     set(ONEDNN_LIBRARY_TYPE "STATIC") | ||||
|     set(ONEDNN_BUILD_DOC "OFF") | ||||
|     set(ONEDNN_BUILD_EXAMPLES "OFF") | ||||
|     set(ONEDNN_BUILD_TESTS "OFF") | ||||
|     set(ONEDNN_ENABLE_WORKLOAD "INFERENCE") | ||||
|     set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER") | ||||
|     set(ONEDNN_BUILD_GRAPH "OFF") | ||||
|     set(ONEDNN_ENABLE_JIT_PROFILING "OFF") | ||||
|     set(ONEDNN_ENABLE_ITT_TASKS "OFF") | ||||
|     set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") | ||||
|     set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") | ||||
|     set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) | ||||
|  | ||||
|     set(DNNL_CPU_RUNTIME "OMP") | ||||
|  | ||||
|     FetchContent_MakeAvailable(oneDNN) | ||||
|  | ||||
|     list(APPEND LIBS dnnl) | ||||
|     target_link_libraries(dnnl_ext dnnl) | ||||
|     target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC) | ||||
|     list(APPEND LIBS dnnl_ext) | ||||
|     set(USE_ONEDNN ON) | ||||
| else() | ||||
|     set(USE_ONEDNN OFF) | ||||
| endif() | ||||
|  | ||||
| message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") | ||||
| @ -275,7 +261,6 @@ set(VLLM_EXT_SRC | ||||
|  | ||||
| if (AVX512_FOUND AND NOT AVX512_DISABLED) | ||||
|     set(VLLM_EXT_SRC | ||||
|         "csrc/cpu/quant.cpp" | ||||
|         "csrc/cpu/shm.cpp" | ||||
|         ${VLLM_EXT_SRC}) | ||||
|     if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI) | ||||
| @ -289,14 +274,11 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) | ||||
|             ${VLLM_EXT_SRC}) | ||||
|         add_compile_definitions(-DCPU_CAPABILITY_AVX512) | ||||
|     endif() | ||||
| elseif(POWER10_FOUND) | ||||
|     set(VLLM_EXT_SRC | ||||
|         "csrc/cpu/quant.cpp" | ||||
|         ${VLLM_EXT_SRC}) | ||||
| endif() | ||||
| if (ASIMD_FOUND) | ||||
|  | ||||
| if(USE_ONEDNN) | ||||
|     set(VLLM_EXT_SRC | ||||
|         "csrc/cpu/quant.cpp" | ||||
|         "csrc/cpu/dnnl_kernels.cpp" | ||||
|         ${VLLM_EXT_SRC}) | ||||
| endif() | ||||
|  | ||||
|  | ||||
| @ -19,7 +19,7 @@ else() | ||||
|   FetchContent_Declare( | ||||
|         flashmla | ||||
|         GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git | ||||
|         GIT_TAG 0e43e774597682284358ff2c54530757b654b8d1 | ||||
|         GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de | ||||
|         GIT_PROGRESS TRUE | ||||
|         CONFIGURE_COMMAND "" | ||||
|         BUILD_COMMAND "" | ||||
| @ -37,13 +37,14 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}") | ||||
| if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS) | ||||
|     set(FlashMLA_SOURCES | ||||
|         ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp | ||||
|         ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu | ||||
|         ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu | ||||
|         ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu | ||||
|         ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu) | ||||
|         ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu | ||||
|         ${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu) | ||||
|  | ||||
|     set(FlashMLA_INCLUDES | ||||
|         ${flashmla_SOURCE_DIR}/csrc/cutlass/include | ||||
|         ${flashmla_SOURCE_DIR}/csrc/include) | ||||
|         ${flashmla_SOURCE_DIR}/csrc) | ||||
|  | ||||
|     set_gencode_flags_for_srcs( | ||||
|         SRCS "${FlashMLA_SOURCES}" | ||||
|  | ||||
| @ -167,7 +167,7 @@ typename T::Fmha::Arguments args_from_options( | ||||
|       // TODO(trevor-m): Change split_kv back to -1 when | ||||
|       // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will | ||||
|       // perform worse with larger context length and smaller batch sizes. | ||||
|       num_kv_splits, // split_kv | ||||
|       static_cast<int>(num_kv_splits), // split_kv | ||||
|       nullptr,       // is_var_split_kv | ||||
|   }; | ||||
|   // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute | ||||
| @ -264,7 +264,7 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba | ||||
|   // Assumes device 0 when getting sm_count. | ||||
|   arguments.hw_info.sm_count = | ||||
|       sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count; | ||||
|   arguments.split_kv = num_kv_splits; | ||||
|   arguments.split_kv = static_cast<int>(num_kv_splits); | ||||
|   MlaSm100Type::Fmha::set_split_kv(arguments); | ||||
|  | ||||
|   return MlaSm100Type::Fmha::get_workspace_size(arguments); | ||||
|  | ||||
| @ -40,9 +40,11 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe, | ||||
| void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, | ||||
|                  const double scale, const std::string& kv_cache_dtype); | ||||
|  | ||||
| void gather_cache( | ||||
| void gather_and_maybe_dequant_cache( | ||||
|     torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] | ||||
|     torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...] | ||||
|     torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES] | ||||
|     torch::Tensor const& cu_seq_lens,  // [BATCH+1] | ||||
|     int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt); | ||||
|     int64_t batch_size, const std::string& kv_cache_dtype, | ||||
|     torch::Tensor const& scale, | ||||
|     std::optional<torch::Tensor> seq_starts = std::nullopt); | ||||
| @ -624,9 +624,9 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, | ||||
| namespace vllm { | ||||
|  | ||||
| // grid is launched with dimensions (batch, num_splits) | ||||
| template <typename scalar_t> | ||||
| __global__ void gather_cache( | ||||
|     const scalar_t* __restrict__ src_cache,   // [NUM_BLOCKS, BLOCK_SIZE, | ||||
| template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt> | ||||
| __global__ void gather_and_maybe_dequant_cache( | ||||
|     const cache_t* __restrict__ src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, | ||||
|                                               // ENTRIES...] | ||||
|     scalar_t* __restrict__ dst,               // [TOT_TOKENS, ENTRIES...] | ||||
|     const int32_t* __restrict__ block_table,  // [BATCH, BLOCK_INDICES] | ||||
| @ -634,6 +634,7 @@ __global__ void gather_cache( | ||||
|     const int32_t block_size, const int32_t entry_size, | ||||
|     const int64_t block_table_stride, const int64_t cache_block_stride, | ||||
|     const int64_t cache_entry_stride, const int64_t dst_entry_stride, | ||||
|     const float* __restrict__ scale, | ||||
|     const int32_t* __restrict__ seq_starts) {  // Optional: starting offsets per | ||||
|                                                // batch | ||||
|  | ||||
| @ -675,10 +676,16 @@ __global__ void gather_cache( | ||||
|     if (partial_block_size) full_blocks_end -= 1; | ||||
|   } | ||||
|  | ||||
|   auto copy_entry = [&](const scalar_t* __restrict__ _src, | ||||
|   auto copy_entry = [&](const cache_t* __restrict__ _src, | ||||
|                         scalar_t* __restrict__ _dst) { | ||||
|     for (int i = threadIdx.x; i < entry_size; i += blockDim.x) | ||||
|       _dst[i] = _src[i]; | ||||
|     for (int i = threadIdx.x; i < entry_size; i += blockDim.x) { | ||||
|       if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { | ||||
|         _dst[i] = static_cast<scalar_t>(_src[i]); | ||||
|       } else { | ||||
|         _dst[i] = | ||||
|             fp8::scaled_convert<scalar_t, cache_t, kv_dt>(_src[i], *scale); | ||||
|       } | ||||
|     } | ||||
|   }; | ||||
|  | ||||
|   for (int pid = split_start; pid < full_blocks_end; ++pid) { | ||||
| @ -705,25 +712,31 @@ __global__ void gather_cache( | ||||
| }  // namespace vllm | ||||
|  | ||||
| // Macro to dispatch the kernel based on the data type. | ||||
| #define CALL_GATHER_CACHE(CPY_DTYPE)                                    \ | ||||
|   vllm::gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>(            \ | ||||
|       reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()),               \ | ||||
|       reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()),                     \ | ||||
|       block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \ | ||||
|       block_size, entry_size, block_table_stride, cache_block_stride,   \ | ||||
|       cache_entry_stride, dst_entry_stride, seq_starts_ptr); | ||||
| // SCALAR_T is the data type of the destination tensor. | ||||
| // CACHE_T is the stored data type of kv-cache. | ||||
| // KV_DTYPE is the real data type of kv-cache. | ||||
| #define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE)                      \ | ||||
|   vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE>         \ | ||||
|       <<<grid, block, 0, stream>>>(                                         \ | ||||
|           reinterpret_cast<CACHE_T*>(src_cache.data_ptr()),                 \ | ||||
|           reinterpret_cast<SCALAR_T*>(dst.data_ptr()),                      \ | ||||
|           block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \ | ||||
|           block_size, entry_size, block_table_stride, cache_block_stride,   \ | ||||
|           cache_entry_stride, dst_entry_stride,                             \ | ||||
|           reinterpret_cast<const float*>(scale.data_ptr()), seq_starts_ptr); | ||||
|  | ||||
| // Gather sequences from the cache into the destination tensor. | ||||
| //  - cu_seq_lens contains the cumulative sequence lengths for each batch | ||||
| //  - block_table contains the cache block indices for each sequence | ||||
| //  - Optionally, seq_starts (if provided) offsets the starting block index by | ||||
| //  (seq_starts[bid] / page_size) | ||||
| void gather_cache( | ||||
| void gather_and_maybe_dequant_cache( | ||||
|     torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] | ||||
|     torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...] | ||||
|     torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES] | ||||
|     torch::Tensor const& cu_seq_lens,  // [BATCH+1] | ||||
|     int64_t batch_size, | ||||
|     int64_t batch_size, const std::string& kv_cache_dtype, | ||||
|     torch::Tensor const& scale, | ||||
|     std::optional<torch::Tensor> seq_starts = std::nullopt) { | ||||
|   at::cuda::OptionalCUDAGuard device_guard(src_cache.device()); | ||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||
| @ -761,20 +774,8 @@ void gather_cache( | ||||
|   dim3 grid(batch_size, num_splits); | ||||
|   dim3 block(1024); | ||||
|  | ||||
|   TORCH_CHECK(src_cache.dtype() == dst.dtype(), | ||||
|               "src_cache and dst must have the same dtype"); | ||||
|  | ||||
|   const int dtype_bits = src_cache.element_size() * 8; | ||||
|   const int32_t* seq_starts_ptr = | ||||
|       seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr; | ||||
|  | ||||
|   if (dtype_bits == 32) { | ||||
|     CALL_GATHER_CACHE(uint32_t); | ||||
|   } else if (dtype_bits == 16) { | ||||
|     CALL_GATHER_CACHE(uint16_t); | ||||
|   } else if (dtype_bits == 8) { | ||||
|     CALL_GATHER_CACHE(uint8_t); | ||||
|   } else { | ||||
|     TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits); | ||||
|   } | ||||
|   DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE); | ||||
| } | ||||
|  | ||||
| @ -89,7 +89,7 @@ struct FP16Vec16 : public Vec<FP16Vec16> { | ||||
|  | ||||
|   explicit FP16Vec16(const FP32Vec16&); | ||||
|  | ||||
|   void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; } | ||||
|   void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); } | ||||
|  | ||||
|   void save(void* ptr, const int elem_num) const { | ||||
|     constexpr uint32_t M = 0xFFFFFFFF; | ||||
| @ -126,7 +126,7 @@ struct BF16Vec16 : public Vec<BF16Vec16> { | ||||
|  | ||||
|   explicit BF16Vec16(const FP32Vec16&); | ||||
|  | ||||
|   void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; } | ||||
|   void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); } | ||||
|  | ||||
|   void save(void* ptr, const int elem_num) const { | ||||
|     constexpr uint32_t M = 0xFFFFFFFF; | ||||
| @ -180,8 +180,8 @@ struct BF16Vec32 : public Vec<BF16Vec32> { | ||||
|             (__m128i)vec8_data.reg, 1)) {} | ||||
|  | ||||
|   void save(void* ptr) const { | ||||
|     *reinterpret_cast<__m256i*>(ptr) = reg_low; | ||||
|     *reinterpret_cast<__m256i*>((__m256i*)ptr + 1) = reg_high; | ||||
|     _mm256_storeu_si256((__m256i*)ptr, reg_low); | ||||
|     _mm256_storeu_si256((__m256i*)ptr + 1, reg_high); | ||||
|   } | ||||
| }; | ||||
| #endif | ||||
|  | ||||
							
								
								
									
										346
									
								
								csrc/cpu/dnnl_helper.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										346
									
								
								csrc/cpu/dnnl_helper.cpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,346 @@ | ||||
| #include <list> | ||||
| #include <optional> | ||||
|  | ||||
| #include "common/memory_desc.hpp" | ||||
| #include "common/memory.hpp" | ||||
|  | ||||
| #include "dnnl_helper.h" | ||||
|  | ||||
| static dnnl::engine& default_engine() { | ||||
|   static dnnl::engine engine(dnnl::engine::kind::cpu, 0); | ||||
|   return engine; | ||||
| } | ||||
|  | ||||
| static dnnl::stream& default_stream() { | ||||
|   static dnnl::stream stream(default_engine()); | ||||
|   return stream; | ||||
| } | ||||
|  | ||||
| void release_dnnl_matmul_handler(int64_t handler) { | ||||
|   DNNLMatMulPrimitiveHandler* ptr = | ||||
|       reinterpret_cast<DNNLMatMulPrimitiveHandler*>(handler); | ||||
|   delete ptr; | ||||
| } | ||||
|  | ||||
| template <typename KT, typename VT> | ||||
| class DNNLPrimitiveCache { | ||||
|  public: | ||||
|   using cache_value_t = std::pair<KT, VT>; | ||||
|   using result_value_t = VT; | ||||
|   using container_t = std::list<cache_value_t>; | ||||
|   using value_iterator_t = typename container_t::iterator; | ||||
|   using map_t = std::unordered_map<KT, value_iterator_t>; | ||||
|   using creator_t = VT (*)(); | ||||
|  | ||||
|  public: | ||||
|   DNNLPrimitiveCache(size_t capacity) | ||||
|       : capacity_(capacity), | ||||
|         values_(), | ||||
|         key_to_value_(std::min(256lu, capacity)) { | ||||
|     assert(capacity > 0); | ||||
|   } | ||||
|  | ||||
|   template <typename F> | ||||
|   result_value_t get_or_create(const KT& key, F&& creator) { | ||||
|     std::optional<value_iterator_t> value = get_value(key); | ||||
|     if (value.has_value()) { | ||||
|       return value.value()->second; | ||||
|     } else { | ||||
|       return add_value({key, creator()})->second; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   size_t size() const { return values_.size(); } | ||||
|  | ||||
|  private: | ||||
|   void dump_data() { | ||||
|     std::stringstream ss; | ||||
|     ss << "table_id: " << std::hex << reinterpret_cast<size_t>(this) << std::dec | ||||
|        << "\n"; | ||||
|     ss << "container: ["; | ||||
|     for (auto&& iter : values_) { | ||||
|       ss << "(" << iter.first << ", " << std::hex | ||||
|          << reinterpret_cast<size_t>(iter.second.get()) << "), " << std::dec; | ||||
|     } | ||||
|     ss << "]\n"; | ||||
|  | ||||
|     ss << "map: ["; | ||||
|     for (auto&& iter : key_to_value_) { | ||||
|       ss << "(" << iter.first << ", " << iter.second->first << ", " << std::hex | ||||
|          << reinterpret_cast<size_t>(iter.second->second.get()) << std::dec | ||||
|          << "), "; | ||||
|     } | ||||
|     ss << "]\n"; | ||||
|     std::printf("%s\n", ss.str().c_str()); | ||||
|   } | ||||
|  | ||||
|   value_iterator_t add_value(cache_value_t&& new_value) { | ||||
|     if (size() == capacity_) { | ||||
|       cache_value_t& last_item = values_.back(); | ||||
|       key_to_value_.erase(last_item.first); | ||||
|       values_.pop_back(); | ||||
|     } | ||||
|  | ||||
|     auto& added_value_ = values_.emplace_front(std::move(new_value)); | ||||
|     key_to_value_.emplace(added_value_.first, values_.begin()); | ||||
|     return values_.begin(); | ||||
|   } | ||||
|  | ||||
|   std::optional<value_iterator_t> get_value(const KT& key) { | ||||
|     if (key_to_value_.size() > 0 && key == values_.begin()->first) { | ||||
|       return values_.begin(); | ||||
|     } | ||||
|  | ||||
|     auto value_map_iterator = key_to_value_.find(key); | ||||
|     if (value_map_iterator != key_to_value_.end()) { | ||||
|       values_.splice(values_.begin(), values_, value_map_iterator->second); | ||||
|       return value_map_iterator->second; | ||||
|     } else { | ||||
|       return {}; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   const size_t capacity_; | ||||
|   container_t values_; | ||||
|   map_t key_to_value_; | ||||
| }; | ||||
|  | ||||
| DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler( | ||||
|     const Args& args, dnnl::memory::data_type b_type) | ||||
|     : b_n_size_(args.b_n_size), | ||||
|       b_n_stride_(args.b_n_stride), | ||||
|       b_k_size_(args.b_k_size), | ||||
|       b_k_stride_(args.b_k_stride), | ||||
|       b_type_(b_type), | ||||
|       c_type_(args.c_type), | ||||
|       runtime_memory_ptrs_(8), | ||||
|       primitive_cache_size_(args.primitive_cache_size) { | ||||
|   assert(primitive_cache_size_ > 0); | ||||
| } | ||||
|  | ||||
| void DNNLMatMulPrimitiveHandler::prepack_weight( | ||||
|     void* original_b_ptr, dnnl::memory::desc b_target_mem_desc) { | ||||
|   dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_, | ||||
|                                    {b_k_stride_, b_n_stride_}); | ||||
|   dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr); | ||||
|   dnnl::memory packed_weight(b_target_mem_desc, default_engine()); | ||||
|   { | ||||
|     dnnl::reorder(original_weight, packed_weight) | ||||
|         .execute(default_stream(), original_weight, packed_weight); | ||||
|     default_stream().wait(); | ||||
|   } | ||||
|   memory_cache_[DNNL_ARG_WEIGHTS] = packed_weight; | ||||
|   b_target_mem_desc_ = b_target_mem_desc; | ||||
| } | ||||
|  | ||||
| void DNNLMatMulPrimitiveHandler::set_runtime_memory_ptr( | ||||
|     size_t index, dnnl_memory* memory_ptr) { | ||||
|   dnnl::impl::memory_storage_t* mem_storage_ptr = memory_ptr->memory_storage(); | ||||
|   dnnl_memory_desc* mem_desc = const_cast<dnnl_memory_desc*>(memory_ptr->md()); | ||||
|   runtime_memory_ptrs_[index] = {mem_storage_ptr, mem_desc}; | ||||
| } | ||||
|  | ||||
| std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*> | ||||
| DNNLMatMulPrimitiveHandler::get_runtime_memory_ptr(size_t index) { | ||||
|   return runtime_memory_ptrs_[index]; | ||||
| } | ||||
|  | ||||
| namespace std { | ||||
| template <> | ||||
| struct hash<W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey> { | ||||
|   size_t operator()( | ||||
|       const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const { | ||||
|     return hash<dnnl_dim_t>()(val.b_n_size) ^ hash<dnnl_dim_t>()(val.b_k_size) ^ | ||||
|            hash<int>()(static_cast<int>(val.a_qs)) ^ | ||||
|            hash<int>()(static_cast<int>(val.b_qs)) ^ hash<bool>()(val.use_azp) ^ | ||||
|            hash<int>()(static_cast<int>(val.c_type)); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| struct hash<W8A8MatMulPrimitiveHandler::MSizeCacheKey> { | ||||
|   size_t operator()( | ||||
|       const W8A8MatMulPrimitiveHandler::MSizeCacheKey& val) const { | ||||
|     return hash<dnnl_dim_t>()(val.a_m_size) ^ hash<bool>()(val.use_bias) ^ | ||||
|            hash<int>()(static_cast<int>(val.bias_type)); | ||||
|   } | ||||
| }; | ||||
| }  // namespace std | ||||
|  | ||||
| bool operator==(const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& l, | ||||
|                 const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& r) { | ||||
|   return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size && | ||||
|          l.a_qs == r.a_qs && l.b_qs == r.b_qs && l.use_azp == r.use_azp && | ||||
|          l.c_type == r.c_type; | ||||
| } | ||||
|  | ||||
| bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l, | ||||
|                 const W8A8MatMulPrimitiveHandler::MSizeCacheKey& r) { | ||||
|   return l.use_bias == r.use_bias && l.a_m_size == r.a_m_size && | ||||
|          l.bias_type == r.bias_type; | ||||
| } | ||||
|  | ||||
| static std::shared_ptr<W8A8MatMulPrimitiveHandler::MSizeCache> | ||||
| get_w8a8_class_primitive_cache( | ||||
|     const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& key, | ||||
|     int64_t cache_size) { | ||||
|   static W8A8MatMulPrimitiveHandler::ClassMatmulCache cache(128); | ||||
|   assert(cache_size > 0); | ||||
|   return cache.get_or_create(key, [&]() { | ||||
|     return std::make_shared<W8A8MatMulPrimitiveHandler::MSizeCache>(cache_size); | ||||
|   }); | ||||
| } | ||||
|  | ||||
| W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args) | ||||
|     : DNNLMatMulPrimitiveHandler( | ||||
|           static_cast<const DNNLMatMulPrimitiveHandler::Args&>(args), | ||||
|           dnnl::memory::data_type::s8), | ||||
|       use_azp_(args.use_a_zero_point), | ||||
|       a_qs_(args.a_quantization_strategy), | ||||
|       b_qs_(args.b_quantization_strategy), | ||||
|       m_size_cache_(nullptr) { | ||||
|   assert(a_qs_ != QuantizationStrategy::PER_OUTPUT_CHANNEL); | ||||
|   assert(b_qs_ != QuantizationStrategy::PER_TOKEN); | ||||
|   if (a_qs_ == QuantizationStrategy::PER_TOKEN) { | ||||
|     assert(!use_azp_); | ||||
|   }; | ||||
|   prepack_weight(args.b_ptr, | ||||
|                  create_primitive_desc( | ||||
|                      MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL, | ||||
|                                    .use_bias = false, | ||||
|                                    .bias_type = dnnl::memory::data_type::undef}, | ||||
|                      true) | ||||
|                      .weights_desc()); | ||||
|   init_runtime_memory_cache(args); | ||||
| } | ||||
|  | ||||
| void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) { | ||||
|   auto&& [a_storage, a_mem_desc] = get_runtime_memory_ptr(0); | ||||
|   auto&& [c_storage, c_mem_desc] = get_runtime_memory_ptr(1); | ||||
|   a_storage->set_data_handle((void*)args.a_ptr); | ||||
|   a_mem_desc->dims[0] = args.a_m_size; | ||||
|   c_storage->set_data_handle((void*)args.c_ptr); | ||||
|   c_mem_desc->dims[0] = args.a_m_size; | ||||
|  | ||||
|   if (a_qs_ == QuantizationStrategy::PER_TENSOR) { | ||||
|     auto&& [a_scale_storage, a_scale_mem_desc] = get_runtime_memory_ptr(2); | ||||
|     a_scale_storage->set_data_handle((void*)args.a_scales_ptr); | ||||
|   } | ||||
|   if (use_azp_) { | ||||
|     auto&& [a_zero_point_storage, a_zero_point_mem_desc] = | ||||
|         get_runtime_memory_ptr(3); | ||||
|     a_zero_point_storage->set_data_handle((void*)args.a_zero_points_ptr); | ||||
|   } | ||||
|  | ||||
|   if (args.use_bias) { | ||||
|     auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(4); | ||||
|     bias_storage->set_data_handle((void*)args.bias_ptr); | ||||
|   } | ||||
|  | ||||
|   dnnl::matmul matmul = get_matmul_cache(args); | ||||
|   matmul.execute(default_stream(), memory_cache_); | ||||
|   default_stream().wait(); | ||||
| } | ||||
|  | ||||
| dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache( | ||||
|     const MSizeCacheKey& key) { | ||||
|   if (m_size_cache_.get() == nullptr) { | ||||
|     ClassMatmulCacheKey key = {.b_n_size = b_n_size_, | ||||
|                                .b_k_size = b_k_size_, | ||||
|                                .a_qs = a_qs_, | ||||
|                                .b_qs = b_qs_, | ||||
|                                .use_azp = use_azp_, | ||||
|                                .c_type = c_type_}; | ||||
|     m_size_cache_ = get_w8a8_class_primitive_cache(key, primitive_cache_size_); | ||||
|   } | ||||
|  | ||||
|   return m_size_cache_->get_or_create(key, [&]() { | ||||
|     dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false); | ||||
|     return dnnl::matmul(desc); | ||||
|   }); | ||||
| } | ||||
|  | ||||
| void W8A8MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) { | ||||
|   memory_cache_[DNNL_ARG_SRC] = dnnl::memory({{1, b_k_size_}, | ||||
|                                               dnnl::memory::data_type::s8, | ||||
|                                               dnnl::memory::format_tag::ab}, | ||||
|                                              default_engine(), nullptr); | ||||
|   set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get()); | ||||
|   memory_cache_[DNNL_ARG_DST] = | ||||
|       dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab}, | ||||
|                    default_engine(), nullptr); | ||||
|   set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get()); | ||||
|  | ||||
|   // For PER_TOKEN, scales will be applied in outside epilogue | ||||
|   if (a_qs_ == QuantizationStrategy::PER_TENSOR) { | ||||
|     memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC] = dnnl::memory( | ||||
|         {{1}, dnnl::memory::data_type::f32, {1}}, default_engine(), nullptr); | ||||
|     set_runtime_memory_ptr( | ||||
|         2, memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC].get()); | ||||
|     if (use_azp_) { | ||||
|       memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC] = dnnl::memory( | ||||
|           {{1}, dnnl::memory::data_type::s32, {1}}, default_engine(), nullptr); | ||||
|       set_runtime_memory_ptr( | ||||
|           3, memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC].get()); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   if (b_qs_ == QuantizationStrategy::PER_TENSOR) { | ||||
|     memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = | ||||
|         dnnl::memory({{1}, dnnl::memory::data_type::f32, {1}}, default_engine(), | ||||
|                      (void*)args.b_scales_ptr); | ||||
|   } else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) { | ||||
|     memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = | ||||
|         dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}}, | ||||
|                      default_engine(), (void*)args.b_scales_ptr); | ||||
|   } | ||||
|  | ||||
|   memory_cache_[DNNL_ARG_BIAS] = | ||||
|       dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}}, | ||||
|                    default_engine(), nullptr); | ||||
|   set_runtime_memory_ptr(4, memory_cache_[DNNL_ARG_BIAS].get()); | ||||
| } | ||||
|  | ||||
| dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc( | ||||
|     const MSizeCacheKey& key, bool first_time) { | ||||
|   dnnl::memory::desc a_md({key.a_m_size, b_k_size_}, | ||||
|                           dnnl::memory::data_type::s8, | ||||
|                           dnnl::memory::format_tag::ab); | ||||
|   dnnl::memory::desc b_md; | ||||
|   if (first_time) { | ||||
|     b_md = | ||||
|         dnnl::memory::desc({b_k_size_, b_n_size_}, dnnl::memory::data_type::s8, | ||||
|                            dnnl::memory::format_tag::any); | ||||
|   } else { | ||||
|     b_md = b_target_mem_desc_; | ||||
|   } | ||||
|   dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_, | ||||
|                           dnnl::memory::format_tag::ab); | ||||
|  | ||||
|   dnnl::primitive_attr attr; | ||||
|   // For PER_TOKEN, scales will be applied in outside epilogue | ||||
|   if (a_qs_ == QuantizationStrategy::PER_TENSOR) { | ||||
|     attr.set_scales_mask(DNNL_ARG_SRC, 0); | ||||
|     if (use_azp_) { | ||||
|       attr.set_zero_points_mask(DNNL_ARG_SRC, 0); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   if (b_qs_ == QuantizationStrategy::PER_TENSOR) { | ||||
|     attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); | ||||
|   } else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) { | ||||
|     attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2); | ||||
|   } | ||||
|  | ||||
|   if (key.use_bias) { | ||||
|     // For PER_TOKEN, bias will be applied in epilogue | ||||
|     assert(a_qs_ == QuantizationStrategy::PER_TENSOR); | ||||
|     dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1}); | ||||
|     return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md, | ||||
|                                         c_md, attr); | ||||
|   } else { | ||||
|     return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md, | ||||
|                                         attr); | ||||
|   } | ||||
| } | ||||
							
								
								
									
										169
									
								
								csrc/cpu/dnnl_helper.h
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										169
									
								
								csrc/cpu/dnnl_helper.h
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,169 @@ | ||||
| #ifndef DNNL_HELPER_H | ||||
| #define DNNL_HELPER_H | ||||
|  | ||||
| #include <optional> | ||||
| #include <cassert> | ||||
|  | ||||
| #include "oneapi/dnnl/dnnl.hpp" | ||||
|  | ||||
| namespace c10 { | ||||
| struct BFloat16; | ||||
| struct Half; | ||||
| }  // namespace c10 | ||||
|  | ||||
| namespace dnnl { | ||||
| namespace impl { | ||||
| struct memory_storage_t; | ||||
| struct matmul_pd_t; | ||||
| struct matmul_desc_t; | ||||
| }  // namespace impl | ||||
| }  // namespace dnnl | ||||
| struct dnnl_memory_desc; | ||||
|  | ||||
| template <typename KT, typename VT> | ||||
| class DNNLPrimitiveCache; | ||||
|  | ||||
| template <typename T> | ||||
| struct DNNLType { | ||||
|   static constexpr dnnl::memory::data_type type = | ||||
|       dnnl::memory::data_type::undef; | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| struct DNNLType<int8_t> { | ||||
|   static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8; | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| struct DNNLType<int32_t> { | ||||
|   static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32; | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| struct DNNLType<float> { | ||||
|   static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32; | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| struct DNNLType<c10::BFloat16> { | ||||
|   static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16; | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| struct DNNLType<c10::Half> { | ||||
|   static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16; | ||||
| }; | ||||
|  | ||||
| template <typename T> | ||||
| constexpr inline dnnl::memory::data_type get_dnnl_type() { | ||||
|   return DNNLType<std::decay_t<T>>::type; | ||||
| } | ||||
|  | ||||
| class DNNLMatMulPrimitiveHandler { | ||||
|  public: | ||||
|   virtual ~DNNLMatMulPrimitiveHandler() = default; | ||||
|  | ||||
|  protected: | ||||
|   struct Args { | ||||
|     dnnl_dim_t b_n_size; | ||||
|     dnnl_dim_t b_n_stride; | ||||
|     dnnl_dim_t b_k_size; | ||||
|     dnnl_dim_t b_k_stride; | ||||
|     void* b_ptr; | ||||
|     dnnl::memory::data_type c_type; | ||||
|     size_t primitive_cache_size; | ||||
|   }; | ||||
|  | ||||
|  protected: | ||||
|   DNNLMatMulPrimitiveHandler(const Args& args, dnnl::memory::data_type b_type); | ||||
|  | ||||
|   void prepack_weight(void* original_b_ptr, | ||||
|                       dnnl::memory::desc b_target_mem_desc); | ||||
|  | ||||
|   void set_runtime_memory_ptr(size_t index, dnnl_memory* memory_ptr); | ||||
|  | ||||
|   std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*> | ||||
|   get_runtime_memory_ptr(size_t index); | ||||
|  | ||||
|  protected: | ||||
|   const dnnl_dim_t b_n_size_; | ||||
|   const dnnl_dim_t b_n_stride_; | ||||
|   const dnnl_dim_t b_k_size_; | ||||
|   const dnnl_dim_t b_k_stride_; | ||||
|   dnnl::memory::data_type b_type_; | ||||
|   dnnl::memory::data_type c_type_; | ||||
|   std::unordered_map<int, dnnl::memory> memory_cache_; | ||||
|   std::vector<std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*>> | ||||
|       runtime_memory_ptrs_; | ||||
|   dnnl::memory::desc b_target_mem_desc_; | ||||
|   int64_t primitive_cache_size_; | ||||
| }; | ||||
|  | ||||
| class W8A8MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler { | ||||
|  public: | ||||
|   enum class QuantizationStrategy { PER_TOKEN, PER_TENSOR, PER_OUTPUT_CHANNEL }; | ||||
|  | ||||
|   struct Args : public DNNLMatMulPrimitiveHandler::Args { | ||||
|     bool use_a_zero_point; | ||||
|     QuantizationStrategy a_quantization_strategy; | ||||
|     QuantizationStrategy b_quantization_strategy; | ||||
|     float* b_scales_ptr; | ||||
|   }; | ||||
|  | ||||
|   struct ClassMatmulCacheKey { | ||||
|     dnnl_dim_t b_n_size; | ||||
|     dnnl_dim_t b_k_size; | ||||
|     QuantizationStrategy a_qs; | ||||
|     QuantizationStrategy b_qs; | ||||
|     bool use_azp; | ||||
|     dnnl::memory::data_type c_type; | ||||
|  | ||||
|     friend bool operator==(const ClassMatmulCacheKey& l, | ||||
|                            const ClassMatmulCacheKey& r); | ||||
|   }; | ||||
|  | ||||
|   struct MSizeCacheKey { | ||||
|     dnnl_dim_t a_m_size; | ||||
|     bool use_bias; | ||||
|     dnnl::memory::data_type bias_type; | ||||
|  | ||||
|     friend bool operator==(const MSizeCacheKey& l, const MSizeCacheKey& r); | ||||
|   }; | ||||
|  | ||||
|   using MSizeCache = DNNLPrimitiveCache<MSizeCacheKey, dnnl::matmul>; | ||||
|   using ClassMatmulCache = | ||||
|       DNNLPrimitiveCache<ClassMatmulCacheKey, std::shared_ptr<MSizeCache>>; | ||||
|  | ||||
|   struct ExecArgs : public MSizeCacheKey { | ||||
|     const int8_t* a_ptr; | ||||
|     const float* a_scales_ptr; | ||||
|     const int32_t* a_zero_points_ptr; | ||||
|     const void* bias_ptr; | ||||
|     void* c_ptr; | ||||
|   }; | ||||
|  | ||||
|  public: | ||||
|   W8A8MatMulPrimitiveHandler(const Args& args); | ||||
|  | ||||
|   QuantizationStrategy get_input_scale_strategy() const { return a_qs_; } | ||||
|  | ||||
|   bool get_input_use_zero_point() const { return use_azp_; } | ||||
|  | ||||
|   void execute(ExecArgs& args); | ||||
|  | ||||
|  private: | ||||
|   dnnl::matmul::primitive_desc create_primitive_desc(const MSizeCacheKey& key, | ||||
|                                                      bool first_time); | ||||
|  | ||||
|   void init_runtime_memory_cache(const Args& args); | ||||
|  | ||||
|   dnnl::matmul get_matmul_cache(const MSizeCacheKey& key); | ||||
|  | ||||
|  private: | ||||
|   const bool use_azp_; | ||||
|   const QuantizationStrategy a_qs_; | ||||
|   const QuantizationStrategy b_qs_; | ||||
|   std::shared_ptr<MSizeCache> m_size_cache_; | ||||
| }; | ||||
|  | ||||
| #endif | ||||
| @ -1,206 +0,0 @@ | ||||
| #ifndef DNNL_HELPER_HPP | ||||
| #define DNNL_HELPER_HPP | ||||
|  | ||||
| #include <c10/util/BFloat16.h> | ||||
| #include <c10/util/Half.h> | ||||
|  | ||||
| #include "oneapi/dnnl/dnnl.hpp" | ||||
|  | ||||
| namespace { | ||||
| template <typename T> | ||||
| struct DNNLType { | ||||
|   static constexpr dnnl::memory::data_type type = | ||||
|       dnnl::memory::data_type::undef; | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| struct DNNLType<int8_t> { | ||||
|   static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8; | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| struct DNNLType<int32_t> { | ||||
|   static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32; | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| struct DNNLType<float> { | ||||
|   static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32; | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| struct DNNLType<c10::BFloat16> { | ||||
|   static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16; | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| struct DNNLType<c10::Half> { | ||||
|   static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16; | ||||
| }; | ||||
|  | ||||
| template <typename T> | ||||
| constexpr inline dnnl::memory::data_type get_dnnl_type() { | ||||
|   return DNNLType<std::decay_t<T>>::type; | ||||
| } | ||||
| };  // namespace | ||||
|  | ||||
| template <bool InputNoScale> | ||||
| class DNNLPrimitiveHelper { | ||||
|  public: | ||||
|   // I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias) | ||||
|   // A: [M, K], row-major | ||||
|   // B: [K, N], column-major | ||||
|   // C: [M, N], row-major | ||||
|   // bias: [N], row-major, optional | ||||
|   // a_scales: [MS] | ||||
|   // b_scales: [NS] | ||||
|   // Note: Due to the limitation of oneDNN | ||||
|   // (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is | ||||
|   // not supported. | ||||
|  | ||||
|   template <typename OutputT, typename BiasT> | ||||
|   static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c, | ||||
|                             const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N, | ||||
|                             dnnl_dim_t K, const float* a_scales, | ||||
|                             const float* b_scales, dnnl_dim_t MS, | ||||
|                             dnnl_dim_t NS) { | ||||
|     auto&& OutputType = get_dnnl_type<OutputT>(); | ||||
|     auto&& BiasType = get_dnnl_type<BiasT>(); | ||||
|  | ||||
|     dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1}); | ||||
|     dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K}); | ||||
|     dnnl::memory::desc c_md({M, N}, OutputType, {N, 1}); | ||||
|  | ||||
|     dnnl::primitive_attr attr; | ||||
|     if constexpr (!InputNoScale) { | ||||
|       if (MS == 1) { | ||||
|         // per-tensor | ||||
|         attr.set_scales_mask(DNNL_ARG_SRC, 0); | ||||
|       } else { | ||||
|         // per-token | ||||
|         TORCH_CHECK(false, "per-token quantization is unsupported."); | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     if (NS == 1) { | ||||
|       // per-tensor | ||||
|       attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); | ||||
|     } else { | ||||
|       // per-channel | ||||
|       attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2); | ||||
|     } | ||||
|  | ||||
|     dnnl::matmul::primitive_desc matmul_pd; | ||||
| // Create memory descriptors with format_tag::any for the primitive. This | ||||
| // enables the matmul primitive to choose memory layouts for an | ||||
| // optimized primitive implementation, and these layouts may differ from the | ||||
| // ones provided by the user. | ||||
| #ifdef __aarch64__ | ||||
|     auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8, | ||||
|                                          dnnl::memory::format_tag::any); | ||||
|     auto mat_weights_md = dnnl::memory::desc( | ||||
|         {K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any); | ||||
|     auto mat_dst_md = | ||||
|         dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any); | ||||
|     if (bias) { | ||||
|       dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1}); | ||||
|       matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md, | ||||
|                                                mat_weights_md, bias_md, | ||||
|                                                mat_dst_md, attr); | ||||
|     } else { | ||||
|       matmul_pd = dnnl::matmul::primitive_desc( | ||||
|           default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr); | ||||
|     } | ||||
| #else | ||||
|     if (bias) { | ||||
|       dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1}); | ||||
|       matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, | ||||
|                                                bias_md, c_md, attr); | ||||
|     } else { | ||||
|       matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, | ||||
|                                                c_md, attr); | ||||
|     } | ||||
| #endif | ||||
|     dnnl::matmul matmul(matmul_pd); | ||||
|  | ||||
|     auto& engine = default_engine(); | ||||
|  | ||||
|     dnnl::memory a_m(a_md, engine, (void*)a); | ||||
|     dnnl::memory b_m(b_md, engine, (void*)b); | ||||
|     dnnl::memory c_m(c_md, engine, (void*)c); | ||||
|     dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine, | ||||
|                             (void*)a_scales); | ||||
|     dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine, | ||||
|                             (void*)b_scales); | ||||
|  | ||||
|     auto& stream = default_stream(); | ||||
|  | ||||
|     auto mat_src_mem = a_m; | ||||
|     auto mat_weights_mem = b_m; | ||||
|     auto mat_dst_mem = c_m; | ||||
| #ifdef __aarch64__ | ||||
|     if (matmul_pd.weights_desc() != b_m.get_desc()) { | ||||
|       mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine); | ||||
|       dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem); | ||||
|     } | ||||
| #endif | ||||
|     if constexpr (InputNoScale) { | ||||
|       if (bias) { | ||||
|         dnnl::memory::desc bias_md({N}, BiasType, {1}); | ||||
|         dnnl::memory bias_m(bias_md, engine, (void*)bias); | ||||
|         matmul.execute( | ||||
|             stream, { | ||||
|                         {DNNL_ARG_SRC, mat_src_mem}, | ||||
|                         {DNNL_ARG_WEIGHTS, mat_weights_mem}, | ||||
|                         {DNNL_ARG_BIAS, bias_m}, | ||||
|                         {DNNL_ARG_DST, mat_dst_mem}, | ||||
|                         {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, | ||||
|                     }); | ||||
|       } else { | ||||
|         matmul.execute( | ||||
|             stream, { | ||||
|                         {DNNL_ARG_SRC, mat_src_mem}, | ||||
|                         {DNNL_ARG_WEIGHTS, mat_weights_mem}, | ||||
|                         {DNNL_ARG_DST, mat_dst_mem}, | ||||
|                         {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, | ||||
|                     }); | ||||
|       } | ||||
|     } else { | ||||
|       if (bias) { | ||||
|         dnnl::memory::desc bias_md({N}, BiasType, {1}); | ||||
|         dnnl::memory bias_m(bias_md, engine, (void*)bias); | ||||
|         matmul.execute( | ||||
|             stream, { | ||||
|                         {DNNL_ARG_SRC, mat_src_mem}, | ||||
|                         {DNNL_ARG_WEIGHTS, mat_weights_mem}, | ||||
|                         {DNNL_ARG_BIAS, bias_m}, | ||||
|                         {DNNL_ARG_DST, mat_dst_mem}, | ||||
|                         {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, | ||||
|                         {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, | ||||
|                     }); | ||||
|       } else { | ||||
|         matmul.execute( | ||||
|             stream, { | ||||
|                         {DNNL_ARG_SRC, mat_src_mem}, | ||||
|                         {DNNL_ARG_WEIGHTS, mat_weights_mem}, | ||||
|                         {DNNL_ARG_DST, mat_dst_mem}, | ||||
|                         {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, | ||||
|                         {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, | ||||
|                     }); | ||||
|       } | ||||
|     } | ||||
|     stream.wait(); | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   static dnnl::engine& default_engine() { | ||||
|     static dnnl::engine engine(dnnl::engine::kind::cpu, 0); | ||||
|     return engine; | ||||
|   } | ||||
|  | ||||
|   static dnnl::stream& default_stream() { | ||||
|     static dnnl::stream stream(default_engine()); | ||||
|     return stream; | ||||
|   } | ||||
| }; | ||||
| #endif | ||||
							
								
								
									
										494
									
								
								csrc/cpu/dnnl_kernels.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										494
									
								
								csrc/cpu/dnnl_kernels.cpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,494 @@ | ||||
| #include "cpu_types.hpp" | ||||
| #include "dnnl_helper.h" | ||||
|  | ||||
| namespace { | ||||
| template <typename scalar_t> | ||||
| struct KernelVecType { | ||||
|   using load_vec_type = void; | ||||
|   using cvt_vec_type = void; | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| struct KernelVecType<float> { | ||||
|   using load_vec_type = vec_op::FP32Vec16; | ||||
|   using cvt_vec_type = vec_op::FP32Vec16; | ||||
| }; | ||||
|  | ||||
| #if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT) | ||||
| template <> | ||||
| struct KernelVecType<c10::BFloat16> { | ||||
|   using load_vec_type = vec_op::BF16Vec16; | ||||
|   using cvt_vec_type = vec_op::FP32Vec16; | ||||
| }; | ||||
| #endif | ||||
|  | ||||
| template <> | ||||
| struct KernelVecType<c10::Half> { | ||||
| #if defined(__powerpc64__) || defined(__s390x__) | ||||
|   // Power architecture-specific vector type | ||||
|   using load_vec_type = vec_op::FP32Vec16; | ||||
| #else | ||||
|   // Fallback for other architectures | ||||
|   using load_vec_type = vec_op::FP16Vec16; | ||||
| #endif | ||||
|   using cvt_vec_type = vec_op::FP32Vec16; | ||||
| }; | ||||
|  | ||||
| template <bool AZP, typename scalar_t> | ||||
| void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, | ||||
|                                    const float* scale, const int32_t* azp, | ||||
|                                    const int64_t num_tokens, | ||||
|                                    const int64_t input_stride, | ||||
|                                    const int64_t hidden_size) { | ||||
|   using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type; | ||||
|   using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type; | ||||
|   constexpr int64_t vec_elem_num = load_vec_t::VEC_ELEM_NUM; | ||||
|  | ||||
|   constexpr float i8_min = | ||||
|       static_cast<float>(std::numeric_limits<int8_t>::min()); | ||||
|   constexpr float i8_max = | ||||
|       static_cast<float>(std::numeric_limits<int8_t>::max()); | ||||
|   const cvt_vec_t inv_scale(1.0 / *scale); | ||||
|   const cvt_vec_t i8_min_vec(i8_min); | ||||
|   const cvt_vec_t i8_max_vec(i8_max); | ||||
|  | ||||
|   cvt_vec_t zp_vec; | ||||
|   if constexpr (AZP) { | ||||
|     zp_vec = cvt_vec_t(static_cast<float>(*azp)); | ||||
|   } | ||||
|  | ||||
| #pragma omp parallel for | ||||
|   for (int64_t i = 0; i < num_tokens; ++i) { | ||||
|     int64_t j = 0; | ||||
|     const scalar_t* input_ptr = input + i * input_stride; | ||||
|     int8_t* output_ptr = output + i * hidden_size; | ||||
|     for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { | ||||
|       load_vec_t elems(input_ptr + j); | ||||
|       cvt_vec_t elems_fp32(elems); | ||||
|       elems_fp32 = elems_fp32 * inv_scale; | ||||
|  | ||||
|       if constexpr (AZP) { | ||||
|         elems_fp32 = elems_fp32 + zp_vec; | ||||
|       } | ||||
|  | ||||
|       elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); | ||||
|       vec_op::INT8Vec16 elems_int8(elems_fp32); | ||||
|       elems_int8.save(output_ptr + j); | ||||
|     } | ||||
|  | ||||
|     load_vec_t elems(input_ptr + j); | ||||
|     cvt_vec_t elems_fp32(elems); | ||||
|     elems_fp32 = elems_fp32 * inv_scale; | ||||
|  | ||||
|     if constexpr (AZP) { | ||||
|       elems_fp32 = elems_fp32 + zp_vec; | ||||
|     } | ||||
|  | ||||
|     elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); | ||||
|     vec_op::INT8Vec16 elems_int8(elems_fp32); | ||||
|     elems_int8.save(output_ptr + j, hidden_size - j); | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <bool AZP, typename scalar_t> | ||||
| void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, | ||||
|                                     float* scale, int32_t* azp, | ||||
|                                     const int64_t num_tokens, | ||||
|                                     const int64_t input_stride, | ||||
|                                     const int64_t hidden_size) { | ||||
|   using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type; | ||||
|   using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type; | ||||
|   constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; | ||||
|  | ||||
|   constexpr float i8_min = | ||||
|       static_cast<float>(std::numeric_limits<int8_t>::min()); | ||||
|   constexpr float i8_max = | ||||
|       static_cast<float>(std::numeric_limits<int8_t>::max()); | ||||
|   const cvt_vec_t i8_min_vec(i8_min); | ||||
|   const cvt_vec_t i8_max_vec(i8_max); | ||||
|  | ||||
| #pragma omp parallel for | ||||
|   for (int64_t i = 0; i < num_tokens; ++i) { | ||||
|     cvt_vec_t max_value(std::numeric_limits<float>::lowest()); | ||||
|     cvt_vec_t min_value(std::numeric_limits<float>::max()); | ||||
|     { | ||||
|       int64_t j = 0; | ||||
|       const scalar_t* input_ptr = input + i * input_stride; | ||||
|       for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { | ||||
|         load_vec_t elems(input_ptr + j); | ||||
|         cvt_vec_t elems_fp32(elems); | ||||
|         if constexpr (AZP) { | ||||
|           max_value = max_value.max(elems_fp32); | ||||
|           min_value = min_value.min(elems_fp32); | ||||
|         } else { | ||||
|           max_value = max_value.max(elems_fp32.abs()); | ||||
|         } | ||||
|       } | ||||
|  | ||||
|       load_vec_t elems(input_ptr + j); | ||||
|       cvt_vec_t elems_fp32(elems); | ||||
|  | ||||
|       if (j + vec_elem_num == hidden_size) { | ||||
|         if constexpr (AZP) { | ||||
|           max_value = max_value.max(elems_fp32); | ||||
|           min_value = min_value.min(elems_fp32); | ||||
|         } else { | ||||
|           max_value = max_value.max(elems_fp32.abs()); | ||||
|         } | ||||
|       } else { | ||||
|         if constexpr (AZP) { | ||||
|           max_value = max_value.max(elems_fp32, hidden_size - j); | ||||
|           min_value = min_value.min(elems_fp32, hidden_size - j); | ||||
|         } else { | ||||
|           max_value = max_value.max(elems_fp32.abs(), hidden_size - j); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     float scale_val, azp_val; | ||||
|     if constexpr (AZP) { | ||||
|       float max_scalar = max_value.reduce_max(); | ||||
|       float min_scalar = min_value.reduce_min(); | ||||
|       scale_val = (max_scalar - min_scalar) / 255.0f; | ||||
|       azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); | ||||
|       azp[i] = azp_val; | ||||
|       scale[i] = scale_val; | ||||
|     } else { | ||||
|       scale_val = max_value.reduce_max() / 127.0f; | ||||
|       scale[i] = scale_val; | ||||
|     } | ||||
|  | ||||
|     const cvt_vec_t inv_scale(1.0 / scale_val); | ||||
|     const cvt_vec_t azp_vec(azp_val); | ||||
|  | ||||
|     { | ||||
|       int64_t j = 0; | ||||
|       const scalar_t* input_ptr = input + i * input_stride; | ||||
|       int8_t* output_ptr = output + i * hidden_size; | ||||
|       for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { | ||||
|         load_vec_t elems(input_ptr + j); | ||||
|         cvt_vec_t elems_fp32(elems); | ||||
|         elems_fp32 = (elems_fp32 * inv_scale); | ||||
|  | ||||
|         if constexpr (AZP) { | ||||
|           elems_fp32 = elems_fp32 + azp_vec; | ||||
|         } | ||||
|         elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); | ||||
|         vec_op::INT8Vec16 elems_int8(elems_fp32); | ||||
|         elems_int8.save(output_ptr + j); | ||||
|       } | ||||
|  | ||||
|       load_vec_t elems(input_ptr + j); | ||||
|       cvt_vec_t elems_fp32(elems); | ||||
|       elems_fp32 = (elems_fp32 * inv_scale); | ||||
|  | ||||
|       if constexpr (AZP) { | ||||
|         elems_fp32 = elems_fp32 + azp_vec; | ||||
|       } | ||||
|       elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); | ||||
|       vec_op::INT8Vec16 elems_int8(elems_fp32); | ||||
|       elems_int8.save(output_ptr + j, hidden_size - j); | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <bool AZP, bool Bias, typename scalar_t> | ||||
| void dynamic_quant_epilogue(const float* input, scalar_t* output, | ||||
|                             const float* a_scale, const int32_t* azp, | ||||
|                             const float* azp_adj, const scalar_t* bias, | ||||
|                             const int64_t num_tokens, | ||||
|                             const int64_t hidden_size) { | ||||
|   CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) | ||||
|   using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type; | ||||
|   using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type; | ||||
|   constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; | ||||
|  | ||||
|   const int64_t thread_num = omp_get_max_threads(); | ||||
|   if (num_tokens > thread_num) { | ||||
| #pragma omp parallel for | ||||
|     for (int64_t i = 0; i < num_tokens; ++i) { | ||||
|       const float* input_ptr = input + i * hidden_size; | ||||
|       scalar_t* output_ptr = output + i * hidden_size; | ||||
|       int64_t j = 0; | ||||
|       cvt_vec_t token_scale_vec(a_scale[i]); | ||||
|       cvt_vec_t token_zp_scale_vec; | ||||
|       if constexpr (AZP) { | ||||
|         float zp_scale_val = a_scale[i] * static_cast<float>(azp[i]); | ||||
|         token_zp_scale_vec = cvt_vec_t(zp_scale_val); | ||||
|       } | ||||
|       for (; j < hidden_size - vec_elem_num; ++j) { | ||||
|         cvt_vec_t elems_fp32(input_ptr + j); | ||||
|         elems_fp32 = elems_fp32 * token_scale_vec; | ||||
|         if constexpr (AZP) { | ||||
|           cvt_vec_t azp_adj_fp32(azp_adj + j); | ||||
|           elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; | ||||
|         } | ||||
|         if constexpr (Bias) { | ||||
|           load_vec_t bias_vec(bias + j); | ||||
|           cvt_vec_t bias_vec_fp32(bias_vec); | ||||
|           elems_fp32 = elems_fp32 + bias_vec_fp32; | ||||
|         } | ||||
|         load_vec_t elems_out(elems_fp32); | ||||
|         elems_out.save(output_ptr + j); | ||||
|       } | ||||
|       cvt_vec_t elems_fp32(input_ptr + j); | ||||
|       elems_fp32 = elems_fp32 * token_scale_vec; | ||||
|       if constexpr (AZP) { | ||||
|         cvt_vec_t azp_adj_fp32(azp_adj + j); | ||||
|         elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; | ||||
|       } | ||||
|       if constexpr (Bias) { | ||||
|         load_vec_t bias_vec(bias + j); | ||||
|         cvt_vec_t bias_vec_fp32(bias_vec); | ||||
|         elems_fp32 = elems_fp32 + bias_vec_fp32; | ||||
|       } | ||||
|       load_vec_t elems_out(elems_fp32); | ||||
|       elems_out.save(output_ptr + j, hidden_size - j); | ||||
|     } | ||||
|   } else { | ||||
|     const int64_t vec_iteration = | ||||
|         (hidden_size + vec_elem_num - 1) / vec_elem_num; | ||||
|     const int64_t vec_iteration_per_thread = | ||||
|         (vec_iteration + thread_num - 1) / thread_num; | ||||
|     const int64_t elem_num_per_thread = vec_iteration_per_thread * vec_elem_num; | ||||
| #pragma omp parallel for schedule(static, 1) | ||||
|     for (int64_t i = 0; i < thread_num; ++i) { | ||||
|       const int64_t start = elem_num_per_thread * i; | ||||
|       const int64_t end = std::min(hidden_size, elem_num_per_thread + start); | ||||
|       for (int64_t j = 0; j < num_tokens; ++j) { | ||||
|         cvt_vec_t token_scale_vec(a_scale[j]); | ||||
|         cvt_vec_t token_zp_scale_vec; | ||||
|         if constexpr (AZP) { | ||||
|           float zp_scale_val = a_scale[j] * static_cast<float>(azp[j]); | ||||
|           token_zp_scale_vec = cvt_vec_t(zp_scale_val); | ||||
|         } | ||||
|         int64_t k = start; | ||||
|         const float* input_ptr = input + j * hidden_size; | ||||
|         scalar_t* output_ptr = output + j * hidden_size; | ||||
|         for (; k < end - vec_elem_num; k += vec_elem_num) { | ||||
|           cvt_vec_t elems_fp32(input_ptr + k); | ||||
|           elems_fp32 = elems_fp32 * token_scale_vec; | ||||
|           if constexpr (AZP) { | ||||
|             cvt_vec_t azp_adj_fp32(azp_adj + k); | ||||
|             elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; | ||||
|           } | ||||
|           if constexpr (Bias) { | ||||
|             load_vec_t bias_vec(bias + k); | ||||
|             cvt_vec_t bias_vec_fp32(bias_vec); | ||||
|             elems_fp32 = elems_fp32 + bias_vec_fp32; | ||||
|           } | ||||
|           load_vec_t elems_out(elems_fp32); | ||||
|           elems_out.save(output_ptr + k); | ||||
|         } | ||||
|         if (k < end) { | ||||
|           cvt_vec_t elems_fp32(input_ptr + k); | ||||
|           elems_fp32 = elems_fp32 * token_scale_vec; | ||||
|           if constexpr (AZP) { | ||||
|             cvt_vec_t azp_adj_fp32(azp_adj + k); | ||||
|             elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec; | ||||
|           } | ||||
|           if constexpr (Bias) { | ||||
|             load_vec_t bias_vec(bias + k); | ||||
|             cvt_vec_t bias_vec_fp32(bias_vec); | ||||
|             elems_fp32 = elems_fp32 + bias_vec_fp32; | ||||
|           } | ||||
|           load_vec_t elems_out(elems_fp32); | ||||
|           elems_out.save(output_ptr + k, end - k); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| } | ||||
| }  // namespace | ||||
|  | ||||
| int64_t create_onednn_scaled_mm_handler( | ||||
|     const torch::Tensor& b,         // [IC, OC], column-major | ||||
|     const torch::Tensor& b_scales,  // [1] or [OC] | ||||
|     at::ScalarType output_type, bool dynamic_act_quant, bool use_azp, | ||||
|     int64_t primitive_cache_size) { | ||||
|   TORCH_CHECK(b.dim() == 2); | ||||
|   TORCH_CHECK(b.stride(0) == 1);  // Column-major | ||||
|   TORCH_CHECK(b_scales.is_contiguous()); | ||||
|  | ||||
|   W8A8MatMulPrimitiveHandler::Args args; | ||||
|   args.primitive_cache_size = primitive_cache_size; | ||||
|  | ||||
|   if (b_scales.numel() == 1) { | ||||
|     args.b_quantization_strategy = | ||||
|         W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR; | ||||
|   } else { | ||||
|     TORCH_CHECK_EQ(b_scales.numel(), b.size(1)); | ||||
|     args.b_quantization_strategy = | ||||
|         W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_OUTPUT_CHANNEL; | ||||
|   } | ||||
|   args.b_scales_ptr = b_scales.data_ptr<float>(); | ||||
|   args.b_k_size = b.size(0); | ||||
|   args.b_k_stride = b.stride(0); | ||||
|   args.b_n_size = b.size(1); | ||||
|   args.b_n_stride = b.stride(1); | ||||
|   args.b_ptr = b.data_ptr<int8_t>(); | ||||
|  | ||||
|   if (dynamic_act_quant) { | ||||
|     // dynamic per-token, bias, A scales and A zps will be applied in outside. | ||||
|     args.a_quantization_strategy = | ||||
|         W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN; | ||||
|     args.use_a_zero_point = false; | ||||
|   } else { | ||||
|     // static per-tensor | ||||
|     args.a_quantization_strategy = | ||||
|         W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR; | ||||
|     args.use_a_zero_point = use_azp; | ||||
|   } | ||||
|  | ||||
|   VLLM_DISPATCH_FLOATING_TYPES(output_type, "create_onednn_scaled_mm_handler", | ||||
|                                [&] { | ||||
|                                  if (dynamic_act_quant) { | ||||
|                                    args.c_type = get_dnnl_type<float>(); | ||||
|                                  } else { | ||||
|                                    args.c_type = get_dnnl_type<scalar_t>(); | ||||
|                                  } | ||||
|                                }); | ||||
|  | ||||
|   return reinterpret_cast<int64_t>(new W8A8MatMulPrimitiveHandler(args)); | ||||
| } | ||||
|  | ||||
| void onednn_scaled_mm( | ||||
|     torch::Tensor& c,                             // [M, OC], row-major | ||||
|     const torch::Tensor& a,                       // [M, IC], row-major | ||||
|     const torch::Tensor& a_scales,                // [M] or [1] | ||||
|     const std::optional<torch::Tensor>& azp,      // [M] or [1] | ||||
|     const std::optional<torch::Tensor>& azp_adj,  // [M] or [1] | ||||
|     const std::optional<torch::Tensor>& bias,     // [N] | ||||
|     int64_t handler) { | ||||
|   CPU_KERNEL_GUARD_IN(onednn_scaled_mm) | ||||
|   TORCH_CHECK(a.dim() == 2); | ||||
|   TORCH_CHECK(a.is_contiguous()); | ||||
|   TORCH_CHECK(c.is_contiguous()); | ||||
|   W8A8MatMulPrimitiveHandler* ptr = | ||||
|       reinterpret_cast<W8A8MatMulPrimitiveHandler*>(handler); | ||||
|   const int32_t* azp_ptr = nullptr; | ||||
|   if (azp.has_value()) { | ||||
|     azp_ptr = azp->data_ptr<int32_t>(); | ||||
|   } | ||||
|   if (ptr->get_input_scale_strategy() == | ||||
|       W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) { | ||||
|     TORCH_CHECK_EQ(a_scales.numel(), 1); | ||||
|   } | ||||
|  | ||||
|   W8A8MatMulPrimitiveHandler::ExecArgs exec_args; | ||||
|   exec_args.a_ptr = a.data_ptr<int8_t>(); | ||||
|   exec_args.a_m_size = a.size(0); | ||||
|   exec_args.bias_ptr = nullptr; | ||||
|   exec_args.use_bias = false; | ||||
|   exec_args.a_scales_ptr = nullptr; | ||||
|   exec_args.a_zero_points_ptr = nullptr; | ||||
|  | ||||
|   VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "onednn_scaled_mm", [&] { | ||||
|     if (ptr->get_input_scale_strategy() == | ||||
|         W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) { | ||||
|       if (bias.has_value()) { | ||||
|         exec_args.bias_ptr = bias->data_ptr<scalar_t>(); | ||||
|         exec_args.bias_type = get_dnnl_type<scalar_t>(); | ||||
|         exec_args.use_bias = true; | ||||
|       } | ||||
|       exec_args.a_scales_ptr = a_scales.data_ptr<float>(); | ||||
|       exec_args.a_zero_points_ptr = azp_ptr; | ||||
|       exec_args.c_ptr = c.data_ptr<scalar_t>(); | ||||
|       ptr->execute(exec_args); | ||||
|     } else if (ptr->get_input_scale_strategy() == | ||||
|                W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN) { | ||||
|       torch::Tensor tmp_fp32_out = | ||||
|           torch::empty_like(c, ::at::ScalarType::Float); | ||||
|       exec_args.c_ptr = tmp_fp32_out.data_ptr<float>(); | ||||
|       ptr->execute(exec_args); | ||||
|       if (bias.has_value()) { | ||||
|         if (azp.has_value()) { | ||||
|           dynamic_quant_epilogue<true, true>( | ||||
|               tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(), | ||||
|               a_scales.data_ptr<float>(), azp_ptr, azp_adj->data_ptr<float>(), | ||||
|               bias->data_ptr<scalar_t>(), c.size(0), c.size(1)); | ||||
|         } else { | ||||
|           dynamic_quant_epilogue<false, true>( | ||||
|               tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(), | ||||
|               a_scales.data_ptr<float>(), azp_ptr, nullptr, | ||||
|               bias->data_ptr<scalar_t>(), c.size(0), c.size(1)); | ||||
|         } | ||||
|       } else { | ||||
|         if (azp.has_value()) { | ||||
|           dynamic_quant_epilogue<true, false>( | ||||
|               tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(), | ||||
|               a_scales.data_ptr<float>(), azp_ptr, azp_adj->data_ptr<float>(), | ||||
|               (scalar_t*)nullptr, c.size(0), c.size(1)); | ||||
|         } else { | ||||
|           dynamic_quant_epilogue<false, false>( | ||||
|               tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(), | ||||
|               a_scales.data_ptr<float>(), azp_ptr, nullptr, (scalar_t*)nullptr, | ||||
|               c.size(0), c.size(1)); | ||||
|         } | ||||
|       } | ||||
|     } else { | ||||
|       TORCH_CHECK(false, "invalid act quant type."); | ||||
|     } | ||||
|   }); | ||||
| } | ||||
|  | ||||
| // static-per-tensor quantization. | ||||
| void static_scaled_int8_quant( | ||||
|     torch::Tensor& out,          // [batch, hidden_size] | ||||
|     const torch::Tensor& input,  // [batch, hidden_size] | ||||
|     const torch::Tensor& scale, std::optional<torch::Tensor> const& azp) { | ||||
|   CPU_KERNEL_GUARD_IN(static_scaled_int8_quant) | ||||
|   TORCH_CHECK(out.is_contiguous()); | ||||
|   TORCH_CHECK_EQ(input.dim(), 2); | ||||
|   TORCH_CHECK_EQ(input.stride(1), 1); | ||||
|   TORCH_CHECK(scale.numel() == 1); | ||||
|   TORCH_CHECK(!azp.has_value() || azp->numel() == 1); | ||||
|  | ||||
|   const int64_t stride = input.stride(0); | ||||
|   const int64_t hidden_size = input.size(1); | ||||
|   const int64_t num_tokens = input.size(0); | ||||
|   VLLM_DISPATCH_FLOATING_TYPES( | ||||
|       input.scalar_type(), "static_scaled_int8_quant_impl", [&] { | ||||
|         if (azp.has_value()) { | ||||
|           static_scaled_int8_quant_impl<true>( | ||||
|               input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(), | ||||
|               scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens, | ||||
|               stride, hidden_size); | ||||
|         } else { | ||||
|           static_scaled_int8_quant_impl<false>(input.data_ptr<scalar_t>(), | ||||
|                                                out.data_ptr<int8_t>(), | ||||
|                                                scale.data_ptr<float>(), nullptr, | ||||
|                                                num_tokens, stride, hidden_size); | ||||
|         } | ||||
|       }); | ||||
| } | ||||
|  | ||||
| // dynamic-per-token quantization. | ||||
| void dynamic_scaled_int8_quant( | ||||
|     torch::Tensor& out,          // [batch, hidden_size] | ||||
|     const torch::Tensor& input,  // [batch, hidden_size] | ||||
|     torch::Tensor& scale,        // [batch, 1] | ||||
|     std::optional<torch::Tensor> const& azp) { | ||||
|   CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) | ||||
|   TORCH_CHECK(out.is_contiguous()); | ||||
|   TORCH_CHECK_EQ(input.dim(), 2); | ||||
|   TORCH_CHECK_EQ(input.stride(1), 1); | ||||
|  | ||||
|   const int64_t hidden_size = input.size(1); | ||||
|   const int64_t num_tokens = input.size(0); | ||||
|   const int64_t stride = input.stride(0); | ||||
|   VLLM_DISPATCH_FLOATING_TYPES( | ||||
|       input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] { | ||||
|         if (azp.has_value()) { | ||||
|           dynamic_scaled_int8_quant_impl<true>( | ||||
|               input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(), | ||||
|               scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens, | ||||
|               stride, hidden_size); | ||||
|         } else { | ||||
|           dynamic_scaled_int8_quant_impl<false>( | ||||
|               input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(), | ||||
|               scale.data_ptr<float>(), nullptr, num_tokens, stride, | ||||
|               hidden_size); | ||||
|         } | ||||
|       }); | ||||
| } | ||||
| @ -1,951 +0,0 @@ | ||||
| #include "cpu_types.hpp" | ||||
| #include "dnnl_helper.hpp" | ||||
|  | ||||
| namespace { | ||||
| template <typename scalar_t> | ||||
| struct KernelVecType { | ||||
|   using load_vec_type = void; | ||||
|   using azp_adj_load_vec_type = void; | ||||
|   using cvt_vec_type = void; | ||||
| }; | ||||
|  | ||||
| template <> | ||||
| struct KernelVecType<float> { | ||||
|   using load_vec_type = vec_op::FP32Vec16; | ||||
|   using azp_adj_load_vec_type = vec_op::INT32Vec16; | ||||
|   using cvt_vec_type = vec_op::FP32Vec16; | ||||
| }; | ||||
|  | ||||
| #if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT) | ||||
| template <> | ||||
| struct KernelVecType<c10::BFloat16> { | ||||
|   using load_vec_type = vec_op::BF16Vec16; | ||||
|   using azp_adj_load_vec_type = vec_op::INT32Vec16; | ||||
|   using cvt_vec_type = vec_op::FP32Vec16; | ||||
| }; | ||||
| #endif | ||||
|  | ||||
| template <> | ||||
| struct KernelVecType<c10::Half> { | ||||
| #if defined(__powerpc64__) || defined(__s390x__) | ||||
|   // Power architecture-specific vector type | ||||
|   using load_vec_type = vec_op::FP32Vec16; | ||||
| #else | ||||
|   // Fallback for other architectures | ||||
|   using load_vec_type = vec_op::FP16Vec16; | ||||
| #endif | ||||
|   using azp_adj_load_vec_type = vec_op::INT32Vec16; | ||||
|   using cvt_vec_type = vec_op::FP32Vec16; | ||||
| }; | ||||
|  | ||||
| #if defined(__AVX512F__) || defined(__aarch64__) | ||||
| template <bool AZP, typename scalar_t> | ||||
| void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, | ||||
|                                    const float* scale, const int32_t* azp, | ||||
|                                    const int num_tokens, | ||||
|                                    const int hidden_size) { | ||||
|   using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type; | ||||
|   using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type; | ||||
|   constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; | ||||
|  | ||||
|   constexpr float i8_min = | ||||
|       static_cast<float>(std::numeric_limits<int8_t>::min()); | ||||
|   constexpr float i8_max = | ||||
|       static_cast<float>(std::numeric_limits<int8_t>::max()); | ||||
|   const cvt_vec_t inv_scale(1.0 / *scale); | ||||
|   const cvt_vec_t i8_min_vec(i8_min); | ||||
|   const cvt_vec_t i8_max_vec(i8_max); | ||||
|  | ||||
|   cvt_vec_t zp_vec; | ||||
|   if constexpr (AZP) { | ||||
|     zp_vec = cvt_vec_t(static_cast<float>(*azp)); | ||||
|   } | ||||
|  | ||||
|   #pragma omp parallel for | ||||
|   for (int i = 0; i < num_tokens; ++i) { | ||||
|     int j = 0; | ||||
|     for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { | ||||
|       load_vec_t elems(input + i * hidden_size + j); | ||||
|       cvt_vec_t elems_fp32(elems); | ||||
|       elems_fp32 = elems_fp32 * inv_scale; | ||||
|  | ||||
|       if constexpr (AZP) { | ||||
|         elems_fp32 = elems_fp32 + zp_vec; | ||||
|       } | ||||
|  | ||||
|       elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); | ||||
|       vec_op::INT8Vec16 elems_int8(elems_fp32); | ||||
|       elems_int8.save(output + i * hidden_size + j); | ||||
|     } | ||||
|  | ||||
|     load_vec_t elems(input + i * hidden_size + j); | ||||
|     cvt_vec_t elems_fp32(elems); | ||||
|     elems_fp32 = elems_fp32 * inv_scale; | ||||
|  | ||||
|     if constexpr (AZP) { | ||||
|       elems_fp32 = elems_fp32 + zp_vec; | ||||
|     } | ||||
|  | ||||
|     elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); | ||||
|     vec_op::INT8Vec16 elems_int8(elems_fp32); | ||||
|     elems_int8.save(output + i * hidden_size + j, hidden_size - j); | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <bool AZP, typename scalar_t> | ||||
| void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, | ||||
|                                     float* scale, int32_t* azp, | ||||
|                                     const int num_tokens, | ||||
|                                     const int hidden_size) { | ||||
|   using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type; | ||||
|   using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type; | ||||
|   constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; | ||||
|  | ||||
|   constexpr float i8_min = | ||||
|       static_cast<float>(std::numeric_limits<int8_t>::min()); | ||||
|   constexpr float i8_max = | ||||
|       static_cast<float>(std::numeric_limits<int8_t>::max()); | ||||
|   const cvt_vec_t i8_min_vec(i8_min); | ||||
|   const cvt_vec_t i8_max_vec(i8_max); | ||||
|  | ||||
|   #pragma omp parallel for | ||||
|   for (int i = 0; i < num_tokens; ++i) { | ||||
|     cvt_vec_t max_value(std::numeric_limits<float>::lowest()); | ||||
|     cvt_vec_t min_value(std::numeric_limits<float>::max()); | ||||
|     { | ||||
|       int j = 0; | ||||
|       for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { | ||||
|         load_vec_t elems(input + i * hidden_size + j); | ||||
|         cvt_vec_t elems_fp32(elems); | ||||
|         if constexpr (AZP) { | ||||
|           max_value = max_value.max(elems_fp32); | ||||
|           min_value = min_value.min(elems_fp32); | ||||
|         } else { | ||||
|           max_value = max_value.max(elems_fp32.abs()); | ||||
|         } | ||||
|       } | ||||
|  | ||||
|       load_vec_t elems(input + i * hidden_size + j); | ||||
|       cvt_vec_t elems_fp32(elems); | ||||
|  | ||||
|       if (j + vec_elem_num == hidden_size) { | ||||
|         if constexpr (AZP) { | ||||
|           max_value = max_value.max(elems_fp32); | ||||
|           min_value = min_value.min(elems_fp32); | ||||
|         } else { | ||||
|           max_value = max_value.max(elems_fp32.abs()); | ||||
|         } | ||||
|       } else { | ||||
|         if constexpr (AZP) { | ||||
|           max_value = max_value.max(elems_fp32, hidden_size - j); | ||||
|           min_value = min_value.min(elems_fp32, hidden_size - j); | ||||
|         } else { | ||||
|           max_value = max_value.max(elems_fp32.abs(), hidden_size - j); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     float scale_val, azp_val; | ||||
|     if constexpr (AZP) { | ||||
|       float max_scalar = max_value.reduce_max(); | ||||
|       float min_scalar = min_value.reduce_min(); | ||||
|       scale_val = (max_scalar - min_scalar) / 255.0f; | ||||
|       azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); | ||||
|       azp[i] = static_cast<int32_t>(azp_val); | ||||
|       scale[i] = scale_val; | ||||
|     } else { | ||||
|       scale_val = max_value.reduce_max() / 127.0f; | ||||
|       scale[i] = scale_val; | ||||
|     } | ||||
|  | ||||
|     const cvt_vec_t inv_scale(1.0 / scale_val); | ||||
|     const cvt_vec_t azp_vec(azp_val); | ||||
|  | ||||
|     { | ||||
|       int j = 0; | ||||
|       for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { | ||||
|         load_vec_t elems(input + i * hidden_size + j); | ||||
|         cvt_vec_t elems_fp32(elems); | ||||
|         elems_fp32 = (elems_fp32 * inv_scale); | ||||
|  | ||||
|         if constexpr (AZP) { | ||||
|           elems_fp32 = elems_fp32 + azp_vec; | ||||
|         } | ||||
|         elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); | ||||
|         vec_op::INT8Vec16 elems_int8(elems_fp32); | ||||
|         elems_int8.save(output + i * hidden_size + j); | ||||
|       } | ||||
|  | ||||
|       load_vec_t elems(input + i * hidden_size + j); | ||||
|       cvt_vec_t elems_fp32(elems); | ||||
|       elems_fp32 = (elems_fp32 * inv_scale); | ||||
|  | ||||
|       if constexpr (AZP) { | ||||
|         elems_fp32 = elems_fp32 + azp_vec; | ||||
|       } | ||||
|       elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); | ||||
|       vec_op::INT8Vec16 elems_int8(elems_fp32); | ||||
|       elems_int8.save(output + i * hidden_size + j, hidden_size - j); | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <bool PerChannel, typename scalar_t> | ||||
| void static_quant_epilogue(const float* input, scalar_t* output, | ||||
|                            const float a_scale, const float* b_scale, | ||||
|                            const int32_t* azp_with_adj, const int num_tokens, | ||||
|                            const int hidden_size) { | ||||
|   CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl) | ||||
|   using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type; | ||||
|   using azp_adj_load_vec_t = | ||||
|       typename KernelVecType<scalar_t>::azp_adj_load_vec_type; | ||||
|   using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type; | ||||
|   constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; | ||||
|  | ||||
|   #pragma omp parallel for | ||||
|   for (int i = 0; i < num_tokens; ++i) { | ||||
|     cvt_vec_t a_scale_vec(a_scale); | ||||
|     cvt_vec_t b_scale_vec(*b_scale); | ||||
|     cvt_vec_t scale_vec = a_scale_vec * b_scale_vec; | ||||
|  | ||||
|     int j = 0; | ||||
|     for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { | ||||
|       cvt_vec_t elems_fp32(input + i * hidden_size + j); | ||||
|       azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); | ||||
|       cvt_vec_t azp_adj_fp32(azp_adj_vec); | ||||
|  | ||||
|       if constexpr (PerChannel) { | ||||
|         b_scale_vec = cvt_vec_t(b_scale + j); | ||||
|         scale_vec = b_scale_vec * a_scale_vec; | ||||
|       } | ||||
|  | ||||
|       elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; | ||||
|  | ||||
|       load_vec_t elems_out(elems_fp32); | ||||
|       elems_out.save(output + i * hidden_size + j); | ||||
|     } | ||||
|  | ||||
|     cvt_vec_t elems_fp32(input + i * hidden_size + j); | ||||
|     azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); | ||||
|     cvt_vec_t azp_adj_fp32(azp_adj_vec); | ||||
|  | ||||
|     if constexpr (PerChannel) { | ||||
|       b_scale_vec = cvt_vec_t(b_scale + j); | ||||
|       scale_vec = b_scale_vec * a_scale_vec; | ||||
|     } | ||||
|  | ||||
|     elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; | ||||
|  | ||||
|     load_vec_t elems_out(elems_fp32); | ||||
|     elems_out.save(output + i * hidden_size + j, hidden_size - j); | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <bool AZP, bool PerChannel, bool Bias, typename scalar_t> | ||||
| void dynamic_quant_epilogue(const float* input, scalar_t* output, | ||||
|                             const float* a_scale, const float* b_scale, | ||||
|                             const int32_t* azp, const int32_t* azp_adj, | ||||
|                             const scalar_t* bias, const int num_tokens, | ||||
|                             const int hidden_size) { | ||||
|   CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) | ||||
|   using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type; | ||||
|   using azp_adj_load_vec_t = | ||||
|       typename KernelVecType<scalar_t>::azp_adj_load_vec_type; | ||||
|   using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type; | ||||
|   constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; | ||||
|  | ||||
|   #pragma omp parallel for | ||||
|   for (int i = 0; i < num_tokens; ++i) { | ||||
|     int j = 0; | ||||
|     cvt_vec_t token_scale_vec(a_scale[i]); | ||||
|     cvt_vec_t token_zp_scale_vec; | ||||
|     if constexpr (AZP) { | ||||
|       float zp_scale_val = a_scale[i] * static_cast<float>(azp[i]); | ||||
|       if constexpr (!PerChannel) { | ||||
|         zp_scale_val *= *b_scale; | ||||
|       } | ||||
|       token_zp_scale_vec = cvt_vec_t(zp_scale_val); | ||||
|     } | ||||
|  | ||||
|     for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { | ||||
|       cvt_vec_t elems_fp32(input + i * hidden_size + j); | ||||
|       elems_fp32 = elems_fp32 * token_scale_vec; | ||||
|  | ||||
|       if constexpr (AZP) { | ||||
|         azp_adj_load_vec_t azp_adj_vec(azp_adj + j); | ||||
|         cvt_vec_t azp_adj_fp32(azp_adj_vec); | ||||
|         azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; | ||||
|  | ||||
|         if constexpr (PerChannel) { | ||||
|           cvt_vec_t b_scale_vec(b_scale + j); | ||||
|           azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; | ||||
|         } | ||||
|  | ||||
|         elems_fp32 = elems_fp32 - azp_adj_fp32; | ||||
|       } | ||||
|  | ||||
|       if constexpr (Bias) { | ||||
|         load_vec_t bias_vec(bias + j); | ||||
|         cvt_vec_t bias_vec_fp32(bias_vec); | ||||
|         elems_fp32 = elems_fp32 + bias_vec_fp32; | ||||
|       } | ||||
|  | ||||
|       load_vec_t elems_out(elems_fp32); | ||||
|       elems_out.save(output + i * hidden_size + j); | ||||
|     } | ||||
|  | ||||
|     cvt_vec_t elems_fp32(input + i * hidden_size + j); | ||||
|     elems_fp32 = elems_fp32 * token_scale_vec; | ||||
|  | ||||
|     if constexpr (AZP) { | ||||
|       azp_adj_load_vec_t azp_adj_vec(azp_adj + j); | ||||
|       cvt_vec_t azp_adj_fp32(azp_adj_vec); | ||||
|       azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; | ||||
|  | ||||
|       if constexpr (PerChannel) { | ||||
|         cvt_vec_t b_scale_vec(b_scale + j); | ||||
|         azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; | ||||
|       } | ||||
|  | ||||
|       elems_fp32 = elems_fp32 - azp_adj_fp32; | ||||
|     } | ||||
|  | ||||
|     if constexpr (Bias) { | ||||
|       load_vec_t bias_vec(bias + j); | ||||
|       cvt_vec_t bias_vec_fp32(bias_vec); | ||||
|       elems_fp32 = elems_fp32 + bias_vec_fp32; | ||||
|     } | ||||
|  | ||||
|     load_vec_t elems_out(elems_fp32); | ||||
|     elems_out.save(output + i * hidden_size + j, hidden_size - j); | ||||
|   } | ||||
| } | ||||
| #elif defined(__powerpc64__) | ||||
| template <bool AZP, typename scalar_t> | ||||
| void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, | ||||
|                                    const float* scale, const int32_t* azp, | ||||
|                                    const int num_tokens, | ||||
|                                    const int hidden_size) { | ||||
|   using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type; | ||||
|   using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type; | ||||
|   constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; | ||||
|  | ||||
|   constexpr float i8_min = | ||||
|       static_cast<float>(std::numeric_limits<int8_t>::min()); | ||||
|   constexpr float i8_max = | ||||
|       static_cast<float>(std::numeric_limits<int8_t>::max()); | ||||
|  | ||||
|   const cvt_vec_t inv_scale(1.0 / *scale); | ||||
|   const cvt_vec_t i8_min_vec(i8_min); | ||||
|   const cvt_vec_t i8_max_vec(i8_max); | ||||
|  | ||||
|   cvt_vec_t zp_vec; | ||||
|   if constexpr (AZP) { | ||||
|     zp_vec = cvt_vec_t(static_cast<float>(*azp)); | ||||
|   } | ||||
|   #pragma omp parallel for | ||||
|   for (int i = 0; i < num_tokens; ++i) { | ||||
|     int j = 0; | ||||
|     for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { | ||||
|       load_vec_t elems(input + i * hidden_size + j); | ||||
|       cvt_vec_t elems_fp32(elems); | ||||
|       elems_fp32 = elems_fp32 * inv_scale; | ||||
|       if constexpr (AZP) { | ||||
|         elems_fp32 = elems_fp32 + zp_vec; | ||||
|       } | ||||
|       elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); | ||||
|       vec_op::INT8Vec16 elems_int8(elems_fp32); | ||||
|       elems_int8.save(output + i * hidden_size + j); | ||||
|     } | ||||
|     load_vec_t elems(input + i * hidden_size + j); | ||||
|     cvt_vec_t elems_fp32(elems); | ||||
|     elems_fp32 = elems_fp32 * inv_scale; | ||||
|  | ||||
|     if constexpr (AZP) { | ||||
|       elems_fp32 = elems_fp32 + zp_vec; | ||||
|     } | ||||
|  | ||||
|     elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); | ||||
|     vec_op::INT8Vec16 elems_int8(elems_fp32); | ||||
|     elems_int8.save(output + i * hidden_size + j, hidden_size - j); | ||||
|   } | ||||
| } | ||||
| template <bool AZP, typename scalar_t> | ||||
| void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, | ||||
|                                     float* scale, int32_t* azp, | ||||
|                                     const int num_tokens, | ||||
|                                     const int hidden_size) { | ||||
|   using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type; | ||||
|   using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type; | ||||
|   constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; | ||||
|  | ||||
|   constexpr float i8_min = | ||||
|       static_cast<float>(std::numeric_limits<int8_t>::min()); | ||||
|   constexpr float i8_max = | ||||
|       static_cast<float>(std::numeric_limits<int8_t>::max()); | ||||
|   const cvt_vec_t i8_min_vec(i8_min); | ||||
|   const cvt_vec_t i8_max_vec(i8_max); | ||||
|  | ||||
|   #pragma omp parallel for | ||||
|   for (int i = 0; i < num_tokens; ++i) { | ||||
|     cvt_vec_t max_value(std::numeric_limits<float>::lowest()); | ||||
|     cvt_vec_t min_value(std::numeric_limits<float>::max()); | ||||
|     { | ||||
|       int j = 0; | ||||
|       for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { | ||||
|         load_vec_t elems(input + i * hidden_size + j); | ||||
|         cvt_vec_t elems_fp32(elems); | ||||
|         if constexpr (AZP) { | ||||
|           max_value = max_value.max(elems_fp32); | ||||
|           min_value = min_value.min(elems_fp32); | ||||
|         } else { | ||||
|           max_value = max_value.max(elems_fp32.abs()); | ||||
|         } | ||||
|       } | ||||
|  | ||||
|       load_vec_t elems(input + i * hidden_size + j); | ||||
|       cvt_vec_t elems_fp32(elems); | ||||
|  | ||||
|       if (j + vec_elem_num == hidden_size) { | ||||
|         if constexpr (AZP) { | ||||
|           max_value = max_value.max(elems_fp32); | ||||
|           min_value = min_value.min(elems_fp32); | ||||
|         } else { | ||||
|           max_value = max_value.max(elems_fp32.abs()); | ||||
|         } | ||||
|       } else { | ||||
|         if constexpr (AZP) { | ||||
|           max_value = max_value.max(elems_fp32, hidden_size - j); | ||||
|           min_value = min_value.min(elems_fp32, hidden_size - j); | ||||
|         } else { | ||||
|           max_value = max_value.max(elems_fp32.abs(), hidden_size - j); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     float scale_val, azp_val; | ||||
|     if constexpr (AZP) { | ||||
|       float max_scalar = max_value.reduce_max(); | ||||
|       float min_scalar = min_value.reduce_min(); | ||||
|       scale_val = (max_scalar - min_scalar) / 255.0f; | ||||
|       azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); | ||||
|       azp[i] = static_cast<int32_t>(azp_val); | ||||
|       scale[i] = scale_val; | ||||
|     } else { | ||||
|       scale_val = max_value.reduce_max() / 127.0f; | ||||
|       scale[i] = scale_val; | ||||
|     } | ||||
|  | ||||
|     const cvt_vec_t inv_scale(1.0 / scale_val); | ||||
|     const cvt_vec_t azp_vec(azp_val); | ||||
|  | ||||
|     { | ||||
|       int j = 0; | ||||
|       for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { | ||||
|         load_vec_t elems(input + i * hidden_size + j); | ||||
|         cvt_vec_t elems_fp32(elems); | ||||
|         elems_fp32 = (elems_fp32 * inv_scale); | ||||
|  | ||||
|         if constexpr (AZP) { | ||||
|           elems_fp32 = elems_fp32 + azp_vec; | ||||
|         } | ||||
|         elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); | ||||
|         vec_op::INT8Vec16 elems_int8(elems_fp32); | ||||
|         elems_int8.save(output + i * hidden_size + j); | ||||
|       } | ||||
|  | ||||
|       load_vec_t elems(input + i * hidden_size + j); | ||||
|       cvt_vec_t elems_fp32(elems); | ||||
|       elems_fp32 = (elems_fp32 * inv_scale); | ||||
|  | ||||
|       if constexpr (AZP) { | ||||
|         elems_fp32 = elems_fp32 + azp_vec; | ||||
|       } | ||||
|       elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); | ||||
|       vec_op::INT8Vec16 elems_int8(elems_fp32); | ||||
|       elems_int8.save(output + i * hidden_size + j, hidden_size - j); | ||||
|     } | ||||
|   } | ||||
| } | ||||
| template <bool PerChannel, typename scalar_t> | ||||
| void static_quant_epilogue(const float* input, scalar_t* output, | ||||
|                            const float a_scale, const float* b_scale, | ||||
|                            const int32_t* azp_with_adj, const int num_tokens, | ||||
|                            const int hidden_size) { | ||||
|   CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl) | ||||
|   using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type; | ||||
|   using azp_adj_load_vec_t = | ||||
|       typename KernelVecType<scalar_t>::azp_adj_load_vec_type; | ||||
|   using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type; | ||||
|   constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; | ||||
|  | ||||
|   #pragma omp parallel for | ||||
|   for (int i = 0; i < num_tokens; ++i) { | ||||
|     cvt_vec_t a_scale_vec(a_scale); | ||||
|     cvt_vec_t b_scale_vec(*b_scale); | ||||
|     cvt_vec_t scale_vec = a_scale_vec * b_scale_vec; | ||||
|  | ||||
|     int j = 0; | ||||
|     for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { | ||||
|       cvt_vec_t elems_fp32(input + i * hidden_size + j); | ||||
|       azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); | ||||
|       cvt_vec_t azp_adj_fp32(azp_adj_vec); | ||||
|  | ||||
|       if constexpr (PerChannel) { | ||||
|         b_scale_vec = cvt_vec_t(b_scale + j); | ||||
|         scale_vec = b_scale_vec * a_scale_vec; | ||||
|       } | ||||
|       elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; | ||||
|       load_vec_t elems_out(elems_fp32); | ||||
|       elems_out.save(output + i * hidden_size + j); | ||||
|     } | ||||
|  | ||||
|     cvt_vec_t elems_fp32(input + i * hidden_size + j); | ||||
|     azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); | ||||
|     cvt_vec_t azp_adj_fp32(azp_adj_vec); | ||||
|  | ||||
|     if constexpr (PerChannel) { | ||||
|       b_scale_vec = cvt_vec_t(b_scale + j); | ||||
|       scale_vec = b_scale_vec * a_scale_vec; | ||||
|     } | ||||
|  | ||||
|     elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; | ||||
|  | ||||
|     load_vec_t elems_out(elems_fp32); | ||||
|     elems_out.save(output + i * hidden_size + j, hidden_size - j); | ||||
|   } | ||||
| } | ||||
| template <bool AZP, bool PerChannel, bool Bias, typename scalar_t> | ||||
| void dynamic_quant_epilogue(const float* input, scalar_t* output, | ||||
|                             const float* a_scale, const float* b_scale, | ||||
|                             const int32_t* azp, const int32_t* azp_adj, | ||||
|                             const scalar_t* bias, const int num_tokens, | ||||
|                             const int hidden_size) { | ||||
|   CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) | ||||
|   using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type; | ||||
|   using azp_adj_load_vec_t = | ||||
|       typename KernelVecType<scalar_t>::azp_adj_load_vec_type; | ||||
|   using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type; | ||||
|   constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; | ||||
|  | ||||
|   #pragma omp parallel for | ||||
|   for (int i = 0; i < num_tokens; ++i) { | ||||
|     int j = 0; | ||||
|     cvt_vec_t token_scale_vec(a_scale[i]); | ||||
|     cvt_vec_t token_zp_scale_vec; | ||||
|     if constexpr (AZP) { | ||||
|       float zp_scale_val = a_scale[i] * static_cast<float>(azp[i]); | ||||
|       if constexpr (!PerChannel) { | ||||
|         zp_scale_val *= *b_scale; | ||||
|       } | ||||
|       token_zp_scale_vec = cvt_vec_t(zp_scale_val); | ||||
|     } | ||||
|  | ||||
|     for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { | ||||
|       cvt_vec_t elems_fp32(input + i * hidden_size + j); | ||||
|       elems_fp32 = elems_fp32 * token_scale_vec; | ||||
|  | ||||
|       if constexpr (AZP) { | ||||
|         azp_adj_load_vec_t azp_adj_vec(azp_adj + j); | ||||
|         cvt_vec_t azp_adj_fp32(azp_adj_vec); | ||||
|         azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; | ||||
|  | ||||
|         if constexpr (PerChannel) { | ||||
|           cvt_vec_t b_scale_vec(b_scale + j); | ||||
|           azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; | ||||
|         } | ||||
|  | ||||
|         elems_fp32 = elems_fp32 - azp_adj_fp32; | ||||
|       } | ||||
|  | ||||
|       if constexpr (Bias) { | ||||
|         load_vec_t bias_vec(bias + j); | ||||
|         cvt_vec_t bias_vec_fp32(bias_vec); | ||||
|         elems_fp32 = elems_fp32 + bias_vec_fp32; | ||||
|       } | ||||
|  | ||||
|       load_vec_t elems_out(elems_fp32); | ||||
|       elems_out.save(output + i * hidden_size + j); | ||||
|     } | ||||
|  | ||||
|     cvt_vec_t elems_fp32(input + i * hidden_size + j); | ||||
|     elems_fp32 = elems_fp32 * token_scale_vec; | ||||
|  | ||||
|     if constexpr (AZP) { | ||||
|       azp_adj_load_vec_t azp_adj_vec(azp_adj + j); | ||||
|       cvt_vec_t azp_adj_fp32(azp_adj_vec); | ||||
|       azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; | ||||
|  | ||||
|       if constexpr (PerChannel) { | ||||
|         cvt_vec_t b_scale_vec(b_scale + j); | ||||
|         azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; | ||||
|       } | ||||
|  | ||||
|       elems_fp32 = elems_fp32 - azp_adj_fp32; | ||||
|     } | ||||
|  | ||||
|     if constexpr (Bias) { | ||||
|       load_vec_t bias_vec(bias + j); | ||||
|       cvt_vec_t bias_vec_fp32(bias_vec); | ||||
|       elems_fp32 = elems_fp32 + bias_vec_fp32; | ||||
|     } | ||||
|  | ||||
|     load_vec_t elems_out(elems_fp32); | ||||
|     elems_out.save(output + i * hidden_size + j, hidden_size - j); | ||||
|   } | ||||
| } | ||||
| #else | ||||
| template <typename scalar_t> | ||||
| void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, | ||||
|                                    const float* scale, const int32_t* azp, | ||||
|                                    const int num_tokens, | ||||
|                                    const int hidden_size) { | ||||
|   TORCH_CHECK(false, | ||||
|               "static_scaled_int8_quant_impl requires AVX512/powerpc64/AArch64 " | ||||
|               "support.") | ||||
| } | ||||
|  | ||||
| template <typename scalar_t> | ||||
| void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, | ||||
|                                     float* scale, int32_t* azp, | ||||
|                                     const int num_tokens, | ||||
|                                     const int hidden_size) { | ||||
|   TORCH_CHECK(false, | ||||
|               "dynamic_scaled_int8_quant_impl requires " | ||||
|               "AVX512/powerpc64/AArch64 support.") | ||||
| } | ||||
|  | ||||
| template <bool PerChannel, typename scalar_t> | ||||
| void static_quant_epilogue(const float* input, scalar_t* output, | ||||
|                            const float a_scale, const float* b_scale, | ||||
|                            const int32_t* azp_with_adj, const int num_tokens, | ||||
|                            const int hidden_size) { | ||||
|   TORCH_CHECK( | ||||
|       false, "static_quant_epilogue requires AVX512/powerpc64/AArch64 support.") | ||||
| } | ||||
|  | ||||
| template <typename scalar_t> | ||||
| void dynamic_quant_epilogue(const float* input, scalar_t* output, | ||||
|                             const float* a_scale, const float* b_scale, | ||||
|                             const int32_t* azp, const int32_t* azp_with_adj, | ||||
|                             const scalar_t* bias, const int num_tokens, | ||||
|                             const int hidden_size) { | ||||
|   TORCH_CHECK( | ||||
|       false, | ||||
|       "dynamic_quant_epilogue requires AVX512/powerpc64/AArch64 support.") | ||||
| } | ||||
| #endif | ||||
| }  // namespace | ||||
|  | ||||
| void int8_scaled_mm(torch::Tensor& c,               // [M, OC], row-major | ||||
|                     const torch::Tensor& a,         // [M, IC], row-major | ||||
|                     const torch::Tensor& b,         // [IC, OC], column-major | ||||
|                     const torch::Tensor& a_scales,  // [1] or [M] | ||||
|                     const torch::Tensor& b_scales,  // [1] or [OC] | ||||
|                     const std::optional<torch::Tensor>& bias  // [OC] | ||||
| ) { | ||||
|   CPU_KERNEL_GUARD_IN(cutlass_scaled_mm) | ||||
|   // Checks for conformality | ||||
|   TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, | ||||
|               "int8_scaled_mm only supports INT8 inputs.") | ||||
|   TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); | ||||
|   TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && | ||||
|               b.size(1) == c.size(1)); | ||||
|   TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); | ||||
|   TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); | ||||
|  | ||||
|   // Check for strides and alignment | ||||
|   TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major | ||||
|   TORCH_CHECK(b.stride(0) == 1);                      // Column-major | ||||
|   TORCH_CHECK(c.stride(0) % 16 == 0 && | ||||
|               b.stride(1) % 16 == 0);  // 16 Byte Alignment | ||||
|   TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); | ||||
|  | ||||
|   if (bias) { | ||||
|     TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && | ||||
|                 bias->dim() == 1); | ||||
|   } | ||||
|  | ||||
|   VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm", [&] { | ||||
|     if (a_scales.numel() != 1) { | ||||
|       // per-token | ||||
|       // Note: oneDNN doesn't support per-token activation quantization | ||||
|       // Ideally we want to fuse the GEMM and the scale procedure with oneDNN | ||||
|       // JIT, the intermediate data is cached in registers or L1. But for now | ||||
|       // the oneDNN GEMM code generation only supports two quantization | ||||
|       // patterns: per-tensor or per-output-channel of weight. | ||||
|       // So we have to apply the per-token scale with a 'epilogue'. In C=s_a * | ||||
|       // s_b * (A@B) + bias, the C_inter = s_b * (A@B) is computed by oneDNN | ||||
|       // GEMM, then the per-token scale (and bias) is applied with the epilogue | ||||
|       // C=s_a * C_inter + bias. | ||||
|       torch::Tensor tmp_fp32_out = | ||||
|           torch::empty_like(c, ::at::ScalarType::Float); | ||||
|       // Compute C_inter=s_b * (A@B) | ||||
|       DNNLPrimitiveHelper<true>::gemm_s8s8_jit<float, void>( | ||||
|           a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), | ||||
|           tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1), | ||||
|           a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel()); | ||||
|       if (bias.has_value()) { | ||||
|         // Compute C=s_a * C_inter + bias | ||||
|         dynamic_quant_epilogue<false, true, true>( | ||||
|             tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(), | ||||
|             a_scales.data_ptr<float>(), nullptr, nullptr, nullptr, | ||||
|             bias->data_ptr<scalar_t>(), c.size(0), c.size(1)); | ||||
|       } else { | ||||
|         // Compute C=s_a * C_inter | ||||
|         dynamic_quant_epilogue<false, true, false, scalar_t>( | ||||
|             tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(), | ||||
|             a_scales.data_ptr<float>(), nullptr, nullptr, nullptr, nullptr, | ||||
|             c.size(0), c.size(1)); | ||||
|       } | ||||
|     } else { | ||||
|       // per-tensor | ||||
|       if (bias.has_value()) { | ||||
|         // Compute C=s_a * s_b * (A@B) + bias | ||||
|         DNNLPrimitiveHelper<false>::gemm_s8s8_jit( | ||||
|             a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), c.data_ptr<scalar_t>(), | ||||
|             bias->data_ptr<scalar_t>(), a.size(0), b.size(1), a.size(1), | ||||
|             a_scales.data_ptr<float>(), b_scales.data_ptr<float>(), | ||||
|             a_scales.numel(), b_scales.numel()); | ||||
|       } else { | ||||
|         // Compute C=s_a * s_b * (A@B) | ||||
|         DNNLPrimitiveHelper<false>::gemm_s8s8_jit<scalar_t, void>( | ||||
|             a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), c.data_ptr<scalar_t>(), | ||||
|             nullptr, a.size(0), b.size(1), a.size(1), | ||||
|             a_scales.data_ptr<float>(), b_scales.data_ptr<float>(), | ||||
|             a_scales.numel(), b_scales.numel()); | ||||
|       } | ||||
|     } | ||||
|   }); | ||||
| } | ||||
|  | ||||
| void int8_scaled_mm_azp(torch::Tensor& c,        // [M, OC], row-major | ||||
|                         const torch::Tensor& a,  // [M, IC], row-major | ||||
|                         const torch::Tensor& b,  // [IC, OC], column-major | ||||
|                         const torch::Tensor& a_scales,            // [1] or [M] | ||||
|                         const torch::Tensor& b_scales,            // [1] or [OC] | ||||
|                         const torch::Tensor& azp_adj,             // [OC] | ||||
|                         const std::optional<torch::Tensor>& azp,  // [1] or [M] | ||||
|                         const std::optional<torch::Tensor>& bias  // [OC] | ||||
| ) { | ||||
|   CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp) | ||||
|   // Checks for conformality | ||||
|   TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, | ||||
|               "int8_scaled_mm_azp only supports INT8 inputs.") | ||||
|   TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); | ||||
|   TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && | ||||
|               b.size(1) == c.size(1)); | ||||
|   TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); | ||||
|   TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); | ||||
|  | ||||
|   // Check for strides and alignment | ||||
|   TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major | ||||
|   TORCH_CHECK(b.stride(0) == 1);                      // Column-major | ||||
|   TORCH_CHECK(c.stride(0) % 16 == 0 && | ||||
|               b.stride(1) % 16 == 0);  // 16 Byte Alignment | ||||
|   TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); | ||||
|  | ||||
|   if (bias) { | ||||
|     TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous()); | ||||
|   } | ||||
|   if (azp) { | ||||
|     TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous()); | ||||
|   } | ||||
|   TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous()); | ||||
|  | ||||
|   // azp & bias types | ||||
|   TORCH_CHECK(azp_adj.dtype() == torch::kInt32); | ||||
|   TORCH_CHECK(!azp || azp->dtype() == torch::kInt32); | ||||
|   TORCH_CHECK(!bias || bias->dtype() == c.dtype(), | ||||
|               "currently bias dtype must match output dtype ", c.dtype()); | ||||
|  | ||||
|   VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_azp", [&] { | ||||
|     torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float); | ||||
|     if (a_scales.numel() != 1) { | ||||
|       // per-token | ||||
|       // Note: oneDNN doesn't support per-token activation quantization | ||||
|       // Compute C_inter=s_b * (A@B) | ||||
|       DNNLPrimitiveHelper<true>::gemm_s8s8_jit<float, void>( | ||||
|           a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), | ||||
|           tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1), | ||||
|           a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel()); | ||||
|       if (bias.has_value()) { | ||||
|         // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj + bias | ||||
|         if (b_scales.numel() != 1) { | ||||
|           // Per-Channel | ||||
|           dynamic_quant_epilogue<true, true, true>( | ||||
|               tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(), | ||||
|               a_scales.data_ptr<float>(), b_scales.data_ptr<float>(), | ||||
|               azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(), | ||||
|               bias->data_ptr<scalar_t>(), c.size(0), c.size(1)); | ||||
|         } else { | ||||
|           // Per-Tensor | ||||
|           dynamic_quant_epilogue<true, false, true>( | ||||
|               tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(), | ||||
|               a_scales.data_ptr<float>(), b_scales.data_ptr<float>(), | ||||
|               azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(), | ||||
|               bias->data_ptr<scalar_t>(), c.size(0), c.size(1)); | ||||
|         } | ||||
|       } else { | ||||
|         // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj | ||||
|         if (b_scales.numel() != 1) { | ||||
|           // Per-Channel | ||||
|           dynamic_quant_epilogue<true, true, false, scalar_t>( | ||||
|               tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(), | ||||
|               a_scales.data_ptr<float>(), b_scales.data_ptr<float>(), | ||||
|               azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(), nullptr, | ||||
|               c.size(0), c.size(1)); | ||||
|         } else { | ||||
|           // Per-Tensor | ||||
|           dynamic_quant_epilogue<true, false, false, scalar_t>( | ||||
|               tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(), | ||||
|               a_scales.data_ptr<float>(), b_scales.data_ptr<float>(), | ||||
|               azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(), nullptr, | ||||
|               c.size(0), c.size(1)); | ||||
|         } | ||||
|       } | ||||
|     } else { | ||||
|       // per-tensor | ||||
|       if (bias.has_value()) { | ||||
|         // Compute C_inter=s_a * s_b * (A@B) + bias | ||||
|         DNNLPrimitiveHelper<false>::gemm_s8s8_jit( | ||||
|             a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), | ||||
|             tmp_fp32_out.data_ptr<float>(), bias->data_ptr<scalar_t>(), | ||||
|             a.size(0), b.size(1), a.size(1), a_scales.data_ptr<float>(), | ||||
|             b_scales.data_ptr<float>(), a_scales.numel(), b_scales.numel()); | ||||
|       } else { | ||||
|         // Compute C_inter=s_a * s_b * (A@B) | ||||
|         DNNLPrimitiveHelper<false>::gemm_s8s8_jit<float, void>( | ||||
|             a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), | ||||
|             tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1), | ||||
|             a.size(1), a_scales.data_ptr<float>(), b_scales.data_ptr<float>(), | ||||
|             a_scales.numel(), b_scales.numel()); | ||||
|       } | ||||
|  | ||||
|       // Compute C=C_inter - s_a * s_b * azp_adj | ||||
|       if (b_scales.numel() != 1) { | ||||
|         // Per-Channel | ||||
|         static_quant_epilogue<true>( | ||||
|             tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(), | ||||
|             *a_scales.data_ptr<float>(), b_scales.data_ptr<float>(), | ||||
|             azp_adj.data_ptr<int32_t>(), a.size(0), b.size(1)); | ||||
|       } else { | ||||
|         // Per-Tensor | ||||
|         static_quant_epilogue<false>( | ||||
|             tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(), | ||||
|             *a_scales.data_ptr<float>(), b_scales.data_ptr<float>(), | ||||
|             azp_adj.data_ptr<int32_t>(), a.size(0), b.size(1)); | ||||
|       } | ||||
|     } | ||||
|   }); | ||||
| } | ||||
|  | ||||
| // static-per-tensor quantization. | ||||
| void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size] | ||||
|                               const torch::Tensor& input,  // [..., hidden_size] | ||||
|                               const torch::Tensor& scale, | ||||
|                               std::optional<torch::Tensor> const& azp) { | ||||
|   CPU_KERNEL_GUARD_IN(static_scaled_int8_quant) | ||||
|   TORCH_CHECK(input.is_contiguous()); | ||||
|   TORCH_CHECK(out.is_contiguous()); | ||||
|   TORCH_CHECK(scale.numel() == 1); | ||||
|   TORCH_CHECK(!azp.has_value() || azp->numel() == 1); | ||||
|  | ||||
|   const int hidden_size = input.size(-1); | ||||
|   const int num_tokens = input.numel() / hidden_size; | ||||
|   VLLM_DISPATCH_FLOATING_TYPES( | ||||
|       input.scalar_type(), "static_scaled_int8_quant_impl", [&] { | ||||
|         if (azp.has_value()) { | ||||
|           static_scaled_int8_quant_impl<true>( | ||||
|               input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(), | ||||
|               scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens, | ||||
|               hidden_size); | ||||
|         } else { | ||||
|           static_scaled_int8_quant_impl<false>( | ||||
|               input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(), | ||||
|               scale.data_ptr<float>(), nullptr, num_tokens, hidden_size); | ||||
|         } | ||||
|       }); | ||||
| } | ||||
|  | ||||
| // dynamic-per-token quantization. | ||||
| void dynamic_scaled_int8_quant( | ||||
|     torch::Tensor& out,          // [..., hidden_size] | ||||
|     const torch::Tensor& input,  // [..., hidden_size] | ||||
|     torch::Tensor& scale,        // [..., 1] | ||||
|     std::optional<torch::Tensor> const& azp) { | ||||
|   CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) | ||||
|   TORCH_CHECK(input.is_contiguous()); | ||||
|   TORCH_CHECK(out.is_contiguous()); | ||||
|  | ||||
|   int const hidden_size = input.size(-1); | ||||
|   int const num_tokens = input.numel() / hidden_size; | ||||
|   VLLM_DISPATCH_FLOATING_TYPES( | ||||
|       input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] { | ||||
|         if (azp.has_value()) { | ||||
|           dynamic_scaled_int8_quant_impl<true>( | ||||
|               input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(), | ||||
|               scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens, | ||||
|               hidden_size); | ||||
|         } else { | ||||
|           dynamic_scaled_int8_quant_impl<false>( | ||||
|               input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(), | ||||
|               scale.data_ptr<float>(), nullptr, num_tokens, hidden_size); | ||||
|         } | ||||
|       }); | ||||
| } | ||||
|  | ||||
| #if defined(__powerpc64__) | ||||
| void int8_scaled_mm_ppc64le(torch::Tensor& c,        // [M, OC], row-major | ||||
|                             const torch::Tensor& a,  // [M, IC], row-major | ||||
|                             const torch::Tensor& b,  // [IC, OC], column-major | ||||
|                             const torch::Tensor& a_scales, | ||||
|                             const torch::Tensor& b_scales, | ||||
|                             const std::optional<torch::Tensor>& bias  // [OC] | ||||
| ) { | ||||
|   CPU_KERNEL_GUARD_IN(cutlass_scaled_mm) | ||||
|   // Checks for conformality | ||||
|   TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, | ||||
|               "int8_scaled_mm_ppc64le only supports INT8 inputs."); | ||||
|   TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); | ||||
|   TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && | ||||
|               b.size(1) == c.size(1)); | ||||
|   // We dont need this | ||||
|   TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); | ||||
|   TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); | ||||
|  | ||||
|   // Check for strides and alignment | ||||
|   TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major | ||||
|   TORCH_CHECK(b.stride(0) == 1);                      // Column-major | ||||
|   TORCH_CHECK(c.stride(0) % 16 == 0 && | ||||
|               b.stride(1) % 16 == 0);  // 16 Byte Alignment | ||||
|   TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); | ||||
|  | ||||
|   if (bias) { | ||||
|     TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && | ||||
|                 bias->dim() == 1); | ||||
|   } | ||||
|   VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_ppc64le", [&] { | ||||
|     torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float); | ||||
|     // Compute C_inter=s_b * (A@B) | ||||
|     DNNLPrimitiveHelper<true>::gemm_s8s8_jit<float, void>( | ||||
|         a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), | ||||
|         tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1), | ||||
|         a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel()); | ||||
|     if (bias.has_value()) { | ||||
|       // Compute C=s_a * C_inter + bias | ||||
|       dynamic_quant_epilogue<false, true, true>( | ||||
|           tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(), | ||||
|           a_scales.data_ptr<float>(), nullptr, nullptr, nullptr, | ||||
|           bias->data_ptr<scalar_t>(), c.size(0), c.size(1)); | ||||
|     } else { | ||||
|       // Compute C=s_a * C_inter | ||||
|       dynamic_quant_epilogue<false, true, false, scalar_t>( | ||||
|           tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(), | ||||
|           a_scales.data_ptr<float>(), nullptr, nullptr, nullptr, nullptr, | ||||
|           c.size(0), c.size(1)); | ||||
|     } | ||||
|   }); | ||||
| } | ||||
|  | ||||
| #endif | ||||
| @ -6,25 +6,20 @@ | ||||
|  | ||||
| std::string init_cpu_threads_env(const std::string& cpu_ids); | ||||
|  | ||||
| void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a, | ||||
|                     const torch::Tensor& b, const torch::Tensor& a_scales, | ||||
|                     const torch::Tensor& b_scales, | ||||
|                     const std::optional<torch::Tensor>& bias); | ||||
| void release_dnnl_matmul_handler(int64_t handler); | ||||
|  | ||||
| void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a, | ||||
|                         const torch::Tensor& b, const torch::Tensor& a_scales, | ||||
|                         const torch::Tensor& b_scales, | ||||
|                         const torch::Tensor& azp_adj, | ||||
|                         const std::optional<torch::Tensor>& azp, | ||||
|                         const std::optional<torch::Tensor>& bias); | ||||
| int64_t create_onednn_scaled_mm_handler(const torch::Tensor& b, | ||||
|                                         const torch::Tensor& b_scales, | ||||
|                                         at::ScalarType output_type, | ||||
|                                         bool dynamic_act_quant, bool use_azp, | ||||
|                                         int64_t primitive_cache_size); | ||||
|  | ||||
| #if defined(__powerpc64__) | ||||
| void int8_scaled_mm_ppc64le(torch::Tensor& c, const torch::Tensor& a, | ||||
|                             const torch::Tensor& b, | ||||
|                             const torch::Tensor& a_scales, | ||||
|                             const torch::Tensor& b_scales, | ||||
|                             const std::optional<torch::Tensor>& bias); | ||||
| #endif | ||||
| void onednn_scaled_mm(torch::Tensor& c, const torch::Tensor& a, | ||||
|                       const torch::Tensor& a_scales, | ||||
|                       const std::optional<torch::Tensor>& azp, | ||||
|                       const std::optional<torch::Tensor>& azp_adj, | ||||
|                       const std::optional<torch::Tensor>& bias, | ||||
|                       int64_t handler); | ||||
|  | ||||
| void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query, | ||||
|                         torch::Tensor& kv_cache, double scale, | ||||
| @ -151,8 +146,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { | ||||
|   ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); | ||||
|  | ||||
|   // Quantization | ||||
| #if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) | ||||
| #if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) || \ | ||||
|     defined(__powerpc64__) | ||||
|   at::Tag stride_tag = at::Tag::needs_fixed_stride_order; | ||||
|   // Helper function to release oneDNN handlers | ||||
|   ops.def("release_dnnl_matmul_handler(int handler) -> ()", | ||||
|           &release_dnnl_matmul_handler); | ||||
|  | ||||
|   // Create oneDNN W8A8 handler | ||||
|   ops.def( | ||||
|       "create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType " | ||||
|       "output_type, bool dynamic_act_quant, bool use_azp, int " | ||||
|       "primitive_cache_size) -> int", | ||||
|       &create_onednn_scaled_mm_handler); | ||||
|  | ||||
|   // oneDNN scaled_mm for W8A8 with static per-tensor activation quantization | ||||
|   ops.def( | ||||
|       "onednn_scaled_mm(Tensor! c, Tensor a, Tensor a_scales, Tensor? azp, " | ||||
|       "Tensor? azp_adj, Tensor? bias, int handler) -> ()"); | ||||
|   ops.impl("onednn_scaled_mm", torch::kCPU, &onednn_scaled_mm); | ||||
|  | ||||
|   // Compute int8 quantized tensor for given scaling factor. | ||||
|   ops.def( | ||||
| @ -168,50 +180,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { | ||||
|       {stride_tag}); | ||||
|   ops.impl("dynamic_scaled_int8_quant", torch::kCPU, | ||||
|            &dynamic_scaled_int8_quant); | ||||
|   // W8A8 GEMM, supporting symmetric per-tensor or per-row/column | ||||
|   // quantization. | ||||
|   ops.def( | ||||
|       "cutlass_scaled_mm(Tensor! out, Tensor a," | ||||
|       "                  Tensor b, Tensor a_scales," | ||||
|       "                  Tensor b_scales, Tensor? bias) -> ()", | ||||
|       {stride_tag}); | ||||
|   ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); | ||||
|   // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column | ||||
|   // quantization. | ||||
|   ops.def( | ||||
|       "cutlass_scaled_mm_azp(Tensor! out, Tensor a," | ||||
|       "                  Tensor b, Tensor a_scales," | ||||
|       "                  Tensor b_scales, Tensor azp_adj," | ||||
|       "                  Tensor? azp, Tensor? bias) -> ()", | ||||
|       {stride_tag}); | ||||
|   ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); | ||||
| #elif defined(__powerpc64__) | ||||
|   // Compute int8 quantized tensor for given scaling factor. | ||||
|   ops.def( | ||||
|       "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," | ||||
|       "Tensor? azp) -> ()"); | ||||
|   ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); | ||||
|  | ||||
|   // Compute int8 quantized tensor and scaling factor | ||||
|   ops.def( | ||||
|       "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " | ||||
|       "Tensor!? azp) -> ()"); | ||||
|   ops.impl("dynamic_scaled_int8_quant", torch::kCPU, | ||||
|            &dynamic_scaled_int8_quant); | ||||
|   // W8A8 GEMM, supporting symmetric quantization. | ||||
|   ops.def( | ||||
|       "cutlass_scaled_mm(Tensor! out, Tensor a," | ||||
|       "                  Tensor b, Tensor a_scales," | ||||
|       "                  Tensor b_scales, Tensor? bias) -> ()"); | ||||
|   ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm_ppc64le); | ||||
|   // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column | ||||
|   // quantization. | ||||
|   ops.def( | ||||
|       "cutlass_scaled_mm_azp(Tensor! out, Tensor a," | ||||
|       "                  Tensor b, Tensor a_scales," | ||||
|       "                  Tensor b_scales, Tensor azp_adj," | ||||
|       "                  Tensor? azp, Tensor? bias) -> ()"); | ||||
|   ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); | ||||
| #endif | ||||
|  | ||||
| // SHM CCL | ||||
|  | ||||
							
								
								
									
										757
									
								
								csrc/moe/grouped_topk_kernels.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										757
									
								
								csrc/moe/grouped_topk_kernels.cu
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,757 @@ | ||||
| /* | ||||
|  * Adapted from | ||||
|  * https://github.com/NVIDIA/TensorRT-LLM/blob/v0.21.0/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu | ||||
|  * Copyright (c) 2025, The vLLM team. | ||||
|  * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & | ||||
|  * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  * http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| #include <c10/cuda/CUDAStream.h> | ||||
| #include <torch/all.h> | ||||
| #include <cuda_fp16.h> | ||||
| #include <cuda_bf16.h> | ||||
| #include <cooperative_groups.h> | ||||
| #include <cooperative_groups/reduce.h> | ||||
| namespace cg = cooperative_groups; | ||||
|  | ||||
| namespace vllm { | ||||
| namespace moe { | ||||
|  | ||||
| constexpr unsigned FULL_WARP_MASK = 0xffffffff; | ||||
| constexpr int32_t WARP_SIZE = 32; | ||||
| constexpr int32_t BLOCK_SIZE = 512; | ||||
| constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE; | ||||
|  | ||||
| namespace warp_topk { | ||||
|  | ||||
| template <int size, typename T> | ||||
| __host__ __device__ constexpr T round_up_to_multiple_of(T len) { | ||||
|   if (len == 0) { | ||||
|     return 0; | ||||
|   } | ||||
|   return ((len - 1) / size + 1) * size; | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| constexpr __host__ __device__ bool isPowerOf2(T v) { | ||||
|   return (v && !(v & (v - 1))); | ||||
| } | ||||
|  | ||||
| template <bool greater, typename T> | ||||
| __forceinline__ __device__ bool is_better_than(T val, T baseline) { | ||||
|   return (val > baseline && greater) || (val < baseline && !greater); | ||||
| } | ||||
|  | ||||
| template <bool greater, typename T, typename idxT> | ||||
| __forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index, | ||||
|                                                idxT baseline_index) { | ||||
|   bool res = (val > baseline && greater) || (val < baseline && !greater); | ||||
|   if (val == baseline) { | ||||
|     res = (index < baseline_index && greater) || | ||||
|           (index < baseline_index && !greater); | ||||
|   } | ||||
|   return res; | ||||
| } | ||||
|  | ||||
| template <typename T, typename idxT> | ||||
| int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) { | ||||
|   int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k; | ||||
|   int64_t n = std::max<int>(num_of_warp / 2 * k, num_of_warp * WARP_SIZE); | ||||
|   return max(cache_topk, | ||||
|              round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT)); | ||||
| } | ||||
|  | ||||
| template <int size, bool ascending, bool reverse, typename T, typename idxT, | ||||
|           bool is_stable> | ||||
| struct BitonicMerge { | ||||
|   // input should be a bitonic sequence, and sort it to be a monotonic sequence | ||||
|   __device__ static void merge(T* __restrict__ val_arr, | ||||
|                                idxT* __restrict__ idx_arr) { | ||||
|     static_assert(isPowerOf2(size)); | ||||
|     static_assert(size >= 2 * WARP_SIZE); | ||||
|     constexpr int arr_len = size / WARP_SIZE; | ||||
|  | ||||
|     constexpr int stride = arr_len / 2; | ||||
|     for (int i = 0; i < stride; ++i) { | ||||
|       int const other_i = i + stride; | ||||
|       T& val = val_arr[i]; | ||||
|       T& other_val = val_arr[other_i]; | ||||
|       bool is_better; | ||||
|       if constexpr (is_stable) { | ||||
|         is_better = is_better_than<ascending>(val, other_val, idx_arr[i], | ||||
|                                               idx_arr[other_i]); | ||||
|       } else { | ||||
|         is_better = is_better_than<ascending>(val, other_val); | ||||
|       } | ||||
|  | ||||
|       if (is_better) { | ||||
|         T tmp = val; | ||||
|         val = other_val; | ||||
|         other_val = tmp; | ||||
|  | ||||
|         idxT tmp2 = idx_arr[i]; | ||||
|         idx_arr[i] = idx_arr[other_i]; | ||||
|         idx_arr[other_i] = tmp2; | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge( | ||||
|         val_arr, idx_arr); | ||||
|     BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge( | ||||
|         val_arr + arr_len / 2, idx_arr + arr_len / 2); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <int size, bool ascending, typename T, typename idxT, bool is_stable> | ||||
| struct BitonicSort { | ||||
|   __device__ static void sort(T* __restrict__ val_arr, | ||||
|                               idxT* __restrict__ idx_arr) { | ||||
|     static_assert(isPowerOf2(size)); | ||||
|     static_assert(size >= 2 * WARP_SIZE); | ||||
|     constexpr int arr_len = size / WARP_SIZE; | ||||
|  | ||||
|     BitonicSort<size / 2, true, T, idxT, is_stable>::sort(val_arr, idx_arr); | ||||
|     BitonicSort<size / 2, false, T, idxT, is_stable>::sort( | ||||
|         val_arr + arr_len / 2, idx_arr + arr_len / 2); | ||||
|     BitonicMerge<size, ascending, ascending, T, idxT, is_stable>::merge( | ||||
|         val_arr, idx_arr); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <bool ascending, typename T, typename idxT, bool is_stable> | ||||
| struct BitonicSort<32, ascending, T, idxT, is_stable> { | ||||
|   __device__ static void sort(T* __restrict__ val_arr, | ||||
|                               idxT* __restrict__ idx_arr) { | ||||
|     int const lane = threadIdx.x % WARP_SIZE; | ||||
|  | ||||
|     // ascending doesn't matter before merging since all we need is a bitonic | ||||
|     // sequence | ||||
|     for (int stage = 0; stage < 4; ++stage) { | ||||
|       for (int stride = (1 << stage); stride > 0; stride /= 2) { | ||||
|         bool reverse = (lane >> stage) & 2; | ||||
|         bool is_second = lane & stride; | ||||
|  | ||||
|         T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride); | ||||
|         idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride); | ||||
|  | ||||
|         bool is_better; | ||||
|         if constexpr (is_stable) { | ||||
|           if constexpr (ascending) { | ||||
|             is_better = ((*val_arr > other) || | ||||
|                          ((*val_arr == other) && (*idx_arr < other_idx))) != | ||||
|                         (reverse != is_second); | ||||
|           } else { | ||||
|             is_better = ((*val_arr > other) || | ||||
|                          ((*val_arr == other) && (*idx_arr > other_idx))) != | ||||
|                         (reverse != is_second); | ||||
|           } | ||||
|         } else { | ||||
|           is_better = (*val_arr != other && | ||||
|                        (*val_arr > other) != (reverse != is_second)); | ||||
|         } | ||||
|         if (is_better) { | ||||
|           *val_arr = other; | ||||
|           *idx_arr = other_idx; | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr, | ||||
|                                                                       idx_arr); | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <bool ascending, bool reverse, typename T, typename idxT, | ||||
|           bool is_stable> | ||||
| struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> { | ||||
|   __device__ static void merge(T* __restrict__ val_arr, | ||||
|                                idxT* __restrict__ idx_arr) { | ||||
|     int const lane = threadIdx.x % WARP_SIZE; | ||||
|     for (int stride = WARP_SIZE / 2; stride > 0; stride /= 2) { | ||||
|       bool is_second = lane & stride; | ||||
|       T& val = *val_arr; | ||||
|       T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride); | ||||
|       idxT& idx = *idx_arr; | ||||
|       idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride); | ||||
|  | ||||
|       bool is_better; | ||||
|       if constexpr (is_stable) { | ||||
|         if constexpr (ascending) { | ||||
|           is_better = ((*val_arr > other) || | ||||
|                        ((*val_arr == other) && (*idx_arr < other_idx))) == | ||||
|                       (reverse != is_second);  // for min | ||||
|         } else { | ||||
|           is_better = ((*val_arr > other) || | ||||
|                        ((*val_arr == other) && (*idx_arr > other_idx))) == | ||||
|                       (reverse != is_second);  // for max | ||||
|         } | ||||
|       } else { | ||||
|         is_better = | ||||
|             (val != other && ((val > other) == (ascending != is_second))); | ||||
|       } | ||||
|  | ||||
|       if (is_better) { | ||||
|         val = other; | ||||
|         idx = other_idx; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| }; | ||||
|  | ||||
| template <int capacity, bool greater, typename T, typename idxT, bool is_stable> | ||||
| class WarpSort { | ||||
|  public: | ||||
|   __device__ WarpSort(idxT k, T dummy) | ||||
|       : lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) { | ||||
|     static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity)); | ||||
|  | ||||
|     for (int i = 0; i < max_arr_len_; ++i) { | ||||
|       val_arr_[i] = dummy_; | ||||
|       idx_arr_[i] = 0; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   // load and merge k sorted values | ||||
|   __device__ void load_sorted(T const* __restrict__ in, | ||||
|                               idxT const* __restrict__ in_idx, idxT start) { | ||||
|     idxT idx = start + WARP_SIZE - 1 - lane_; | ||||
|     for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) { | ||||
|       if (idx < start + k_) { | ||||
|         T t = in[idx]; | ||||
|         bool is_better; | ||||
|         if constexpr (is_stable) { | ||||
|           is_better = | ||||
|               is_better_than<greater>(t, val_arr_[i], in_idx[idx], idx_arr_[i]); | ||||
|         } else { | ||||
|           is_better = is_better_than<greater>(t, val_arr_[i]); | ||||
|         } | ||||
|         if (is_better) { | ||||
|           val_arr_[i] = t; | ||||
|           idx_arr_[i] = in_idx[idx]; | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge( | ||||
|         val_arr_, idx_arr_); | ||||
|   } | ||||
|  | ||||
|   __device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const { | ||||
|     for (int i = 0; i < max_arr_len_; ++i) { | ||||
|       idxT out_i = i * WARP_SIZE + lane_; | ||||
|       if (out_i < k_) { | ||||
|         out[out_i] = val_arr_[i]; | ||||
|         out_idx[out_i] = idx_arr_[i]; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   __device__ void dumpIdx(idxT* __restrict__ out_idx) const { | ||||
|     for (int i = 0; i < max_arr_len_; ++i) { | ||||
|       idxT out_i = i * WARP_SIZE + lane_; | ||||
|       if (out_i < k_) { | ||||
|         out_idx[out_i] = idx_arr_[i]; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|  | ||||
|  protected: | ||||
|   static constexpr int max_arr_len_ = capacity / WARP_SIZE; | ||||
|  | ||||
|   T val_arr_[max_arr_len_]; | ||||
|   idxT idx_arr_[max_arr_len_]; | ||||
|  | ||||
|   int const lane_; | ||||
|   idxT const k_; | ||||
|   T const dummy_; | ||||
|  | ||||
| };  // end class WarpSort | ||||
|  | ||||
| template <int capacity, bool greater, typename T, typename idxT, bool is_stable> | ||||
| class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> { | ||||
|  public: | ||||
|   __device__ WarpSelect(idxT k, T dummy) | ||||
|       : WarpSort<capacity, greater, T, idxT, is_stable>(k, dummy), | ||||
|         k_th_(dummy), | ||||
|         k_th_lane_((k - 1) % WARP_SIZE) { | ||||
|     extern __shared__ char smem_buf[];  // extern __shared__ T smem_buf[]; | ||||
|  | ||||
|     int const num_of_warp = blockDim.x / WARP_SIZE; | ||||
|     int const warp_id = threadIdx.x / WARP_SIZE; | ||||
|     val_smem_ = reinterpret_cast<T*>(smem_buf); | ||||
|     val_smem_ += warp_id * WARP_SIZE; | ||||
|     idx_smem_ = reinterpret_cast<idxT*>( | ||||
|         smem_buf + | ||||
|         round_up_to_multiple_of<256>(num_of_warp * sizeof(T) * WARP_SIZE)); | ||||
|     idx_smem_ += warp_id * WARP_SIZE; | ||||
|   } | ||||
|  | ||||
|   __device__ void add(T const* in, idxT start, idxT end) { | ||||
|     idxT const end_for_fullwarp = | ||||
|         round_up_to_multiple_of<WARP_SIZE>(end - start) + start; | ||||
|     for (idxT i = start + lane_; i < end_for_fullwarp; i += WARP_SIZE) { | ||||
|       T val = (i < end) ? in[i] : dummy_; | ||||
|       add(val, i); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   __device__ void add(T val, idxT idx) { | ||||
|     bool do_add; | ||||
|     if constexpr (is_stable) { | ||||
|       do_add = is_better_than<greater>(val, k_th_, idx, k_th_idx_); | ||||
|     } else { | ||||
|       do_add = is_better_than<greater>(val, k_th_); | ||||
|     } | ||||
|  | ||||
|     uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add); | ||||
|     if (mask == 0) { | ||||
|       return; | ||||
|     } | ||||
|  | ||||
|     int pos = smem_buf_len_ + __popc(mask & ((0x1u << lane_) - 1)); | ||||
|     if (do_add && pos < WARP_SIZE) { | ||||
|       val_smem_[pos] = val; | ||||
|       idx_smem_[pos] = idx; | ||||
|       do_add = false; | ||||
|     } | ||||
|     smem_buf_len_ += __popc(mask); | ||||
|     if (smem_buf_len_ >= WARP_SIZE) { | ||||
|       __syncwarp(); | ||||
|       merge_buf_(val_smem_[lane_], idx_smem_[lane_]); | ||||
|       smem_buf_len_ -= WARP_SIZE; | ||||
|     } | ||||
|     if (do_add) { | ||||
|       pos -= WARP_SIZE; | ||||
|       val_smem_[pos] = val; | ||||
|       idx_smem_[pos] = idx; | ||||
|     } | ||||
|     __syncwarp(); | ||||
|   } | ||||
|  | ||||
|   __device__ void done() { | ||||
|     if (smem_buf_len_) { | ||||
|       T val = (lane_ < smem_buf_len_) ? val_smem_[lane_] : dummy_; | ||||
|       idxT idx = (lane_ < smem_buf_len_) ? idx_smem_[lane_] : 0; | ||||
|       merge_buf_(val, idx); | ||||
|     } | ||||
|  | ||||
|     // after done(), smem is used for merging results among warps | ||||
|     __syncthreads(); | ||||
|   } | ||||
|  | ||||
|  private: | ||||
|   __device__ void set_k_th_() { | ||||
|     k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_); | ||||
|     if constexpr (is_stable) { | ||||
|       k_th_idx_ = | ||||
|           __shfl_sync(FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   __device__ void merge_buf_(T val, idxT idx) { | ||||
|     BitonicSort<WARP_SIZE, greater, T, idxT, is_stable>::sort(&val, &idx); | ||||
|  | ||||
|     T& old = val_arr_[max_arr_len_ - 1]; | ||||
|  | ||||
|     bool is_better; | ||||
|     if constexpr (is_stable) { | ||||
|       is_better = | ||||
|           is_better_than<greater>(val, old, idx, idx_arr_[max_arr_len_ - 1]); | ||||
|     } else { | ||||
|       is_better = is_better_than<greater>(val, old); | ||||
|     } | ||||
|  | ||||
|     if (is_better) { | ||||
|       old = val; | ||||
|       idx_arr_[max_arr_len_ - 1] = idx; | ||||
|     } | ||||
|  | ||||
|     BitonicMerge<capacity, greater, !greater, T, idxT, is_stable>::merge( | ||||
|         val_arr_, idx_arr_); | ||||
|  | ||||
|     set_k_th_(); | ||||
|   } | ||||
|  | ||||
|   using WarpSort<capacity, greater, T, idxT, is_stable>::max_arr_len_; | ||||
|   using WarpSort<capacity, greater, T, idxT, is_stable>::val_arr_; | ||||
|   using WarpSort<capacity, greater, T, idxT, is_stable>::idx_arr_; | ||||
|   using WarpSort<capacity, greater, T, idxT, is_stable>::lane_; | ||||
|   using WarpSort<capacity, greater, T, idxT, is_stable>::k_; | ||||
|   using WarpSort<capacity, greater, T, idxT, is_stable>::dummy_; | ||||
|  | ||||
|   T* val_smem_; | ||||
|   idxT* idx_smem_; | ||||
|   int smem_buf_len_ = 0; | ||||
|  | ||||
|   T k_th_; | ||||
|   idxT k_th_idx_; | ||||
|   int const k_th_lane_; | ||||
| };  // end class WarpSelect | ||||
| }  // namespace warp_topk | ||||
|  | ||||
| template <typename T_OUT, typename T_IN> | ||||
| __device__ inline T_OUT cuda_cast(T_IN val) { | ||||
|   return val; | ||||
| } | ||||
|  | ||||
| template <> | ||||
| __device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) { | ||||
|   return __bfloat162float(val); | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| __device__ void topk_with_k2(T* output, T const* input, | ||||
|                              cg::thread_block_tile<32> const& tile, | ||||
|                              int32_t const lane_id, | ||||
|                              int const num_experts_per_group) { | ||||
|   // Get the top2 per thread | ||||
|   T largest = -INFINITY; | ||||
|   T second_largest = -INFINITY; | ||||
|  | ||||
|   if (num_experts_per_group > WARP_SIZE) { | ||||
|     for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { | ||||
|       T value = input[i]; | ||||
|       if (value > largest) { | ||||
|         second_largest = largest; | ||||
|         largest = value; | ||||
|       } else if (value > second_largest) { | ||||
|         second_largest = value; | ||||
|       } | ||||
|     } | ||||
|   } else { | ||||
|     for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { | ||||
|       largest = input[i]; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   __syncwarp();  // Ensure all threads have valid data before reduction | ||||
|   // Get the top2 warpwise | ||||
|   T max1 = cg::reduce(tile, largest, cg::greater<T>()); | ||||
|  | ||||
|   T max2 = max1; | ||||
|   bool equal_to_max1 = (max1 == largest); | ||||
|  | ||||
|   int count_max1 = __popc(__ballot_sync(FULL_WARP_MASK, equal_to_max1)); | ||||
|  | ||||
|   if (count_max1 == 1) { | ||||
|     largest = (largest == max1) ? second_largest : largest; | ||||
|     max2 = cg::reduce(tile, largest, cg::greater<T>()); | ||||
|   } | ||||
|  | ||||
|   if (lane_id == 0) { | ||||
|     *output = max1 + max2; | ||||
|   } | ||||
| } | ||||
|  | ||||
| template <typename T> | ||||
| __global__ void topk_with_k2_kernel(T* output, T* input, | ||||
|                                     int64_t const num_tokens, | ||||
|                                     int64_t const num_cases, | ||||
|                                     int64_t const n_group, | ||||
|                                     int64_t const num_experts_per_group) { | ||||
|   int32_t warp_id = threadIdx.x / WARP_SIZE; | ||||
|   int32_t lane_id = threadIdx.x % WARP_SIZE; | ||||
|  | ||||
|   int32_t case_id = blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id; | ||||
|   if (case_id < num_cases) { | ||||
|     input += case_id * num_experts_per_group; | ||||
|     output += case_id; | ||||
|  | ||||
|     cg::thread_block block = cg::this_thread_block(); | ||||
|     cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); | ||||
|  | ||||
| #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) | ||||
|     asm volatile("griddepcontrol.wait;"); | ||||
| #endif | ||||
|     topk_with_k2(output, input, tile, lane_id, num_experts_per_group); | ||||
|   } | ||||
| #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) | ||||
|   asm volatile("griddepcontrol.launch_dependents;"); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| template <typename T, typename IdxT> | ||||
| __global__ void group_idx_and_topk_idx_kernel( | ||||
|     T* scores, T const* group_scores, T* topk_values, IdxT* topk_indices, | ||||
|     T* scores_with_bias, int64_t const num_tokens, int64_t const n_group, | ||||
|     int64_t const topk_group, int64_t const topk, int64_t const num_experts, | ||||
|     int64_t const num_experts_per_group, bool renormalize, | ||||
|     double routed_scaling_factor) { | ||||
|   int32_t warp_id = threadIdx.x / WARP_SIZE; | ||||
|   int32_t lane_id = threadIdx.x % WARP_SIZE; | ||||
|   int32_t case_id = | ||||
|       blockIdx.x * NUM_WARPS_PER_BLOCK + warp_id;  // one per token | ||||
|   scores_with_bias += case_id * num_experts; | ||||
|   scores += case_id * num_experts; | ||||
|   group_scores += case_id * n_group; | ||||
|   topk_values += case_id * topk; | ||||
|   topk_indices += case_id * topk; | ||||
|  | ||||
|   int32_t align_num_experts_per_group = | ||||
|       warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group); | ||||
|  | ||||
|   cg::thread_block block = cg::this_thread_block(); | ||||
|   cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block); | ||||
|  | ||||
|   extern __shared__ char smem_buf[];  // NOTE: reuse the shared memory here to | ||||
|                                       // store the target topk idx | ||||
|   int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf); | ||||
|   T* s_topk_value = | ||||
|       reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) + | ||||
|       warp_id * topk; | ||||
|   s_topk_idx += warp_id * topk; | ||||
|  | ||||
|   T value = cuda::std::numeric_limits<T>::min(); | ||||
|   T topk_group_value = cuda::std::numeric_limits<T>::min(); | ||||
|   int32_t num_equalto_topkth_group; | ||||
|  | ||||
| #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) | ||||
|   asm volatile("griddepcontrol.wait;");  // I think all prolog can be put before | ||||
|                                          // acqbulk because it's ptr arithmetic | ||||
| #endif | ||||
|  | ||||
|   if (case_id < num_tokens) { | ||||
|     // calculate group_idx | ||||
|     int32_t target_num_min = WARP_SIZE - n_group + topk_group; | ||||
|     if (lane_id < n_group && | ||||
|         (isfinite(cuda_cast<float, T>( | ||||
|             group_scores[lane_id]))))  // The check is necessary to avoid | ||||
|                                        // abnormal input | ||||
|     { | ||||
|       value = group_scores[lane_id]; | ||||
|     } | ||||
|  | ||||
|     int count_equal_to_top_value = WARP_SIZE - n_group; | ||||
|     int pre_count_equal_to_top_value = 0; | ||||
|     // Use loop to find the largset top_group | ||||
|     while (count_equal_to_top_value < target_num_min) { | ||||
|       __syncwarp();  // Ensure all threads have valid data before reduction | ||||
|       topk_group_value = cg::reduce(tile, value, cg::greater<T>()); | ||||
|       if (value == topk_group_value) { | ||||
|         value = cuda::std::numeric_limits<T>::min(); | ||||
|       } | ||||
|       pre_count_equal_to_top_value = count_equal_to_top_value; | ||||
|       count_equal_to_top_value = __popc(__ballot_sync( | ||||
|           FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min()))); | ||||
|     } | ||||
|     num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value; | ||||
|   } | ||||
|   __syncthreads(); | ||||
|  | ||||
|   warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t, | ||||
|                         /* is_stable */ true> | ||||
|       queue((int32_t)topk, -INFINITY); | ||||
|  | ||||
|   int count_equalto_topkth_group = 0; | ||||
|   bool if_proceed_next_topk = | ||||
|       (topk_group_value != cuda::std::numeric_limits<T>::min()); | ||||
|   if (case_id < num_tokens && if_proceed_next_topk) { | ||||
|     for (int i_group = 0; i_group < n_group; i_group++) { | ||||
|       if ((group_scores[i_group] > topk_group_value) || | ||||
|           ((group_scores[i_group] == topk_group_value) && | ||||
|            (count_equalto_topkth_group < num_equalto_topkth_group))) { | ||||
|         int32_t offset = i_group * num_experts_per_group; | ||||
|         for (int32_t i = lane_id; i < align_num_experts_per_group; | ||||
|              i += WARP_SIZE) { | ||||
|           T candidates = | ||||
|               (i < num_experts_per_group) && isfinite(cuda_cast<float, T>( | ||||
|                                                  scores_with_bias[offset + i])) | ||||
|                   ? scores_with_bias[offset + i] | ||||
|                   : cuda::std::numeric_limits<T>::min(); | ||||
|           queue.add(candidates, offset + i); | ||||
|         } | ||||
|         if (group_scores[i_group] == topk_group_value) { | ||||
|           count_equalto_topkth_group++; | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|     queue.done(); | ||||
|     __syncwarp(); | ||||
|     // Get the topk_idx | ||||
|     queue.dumpIdx(s_topk_idx); | ||||
|     __syncwarp(); | ||||
|   } | ||||
|  | ||||
|   // Load the valid score value | ||||
|   // Calculate the summation | ||||
|   float topk_sum = 1e-20; | ||||
|   if (case_id < num_tokens && if_proceed_next_topk) { | ||||
|     for (int i = lane_id; | ||||
|          i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk); | ||||
|          i += WARP_SIZE) { | ||||
|       T value = | ||||
|           i < topk | ||||
|               ? scores[s_topk_idx[i]] | ||||
|               : cuda_cast<T, float>(0.0f);  // Load the valid value of expert | ||||
|       if (i < topk) { | ||||
|         s_topk_value[i] = value; | ||||
|       } | ||||
|       topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>()); | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   __syncthreads(); | ||||
|  | ||||
|   if (case_id < num_tokens) { | ||||
|     if (if_proceed_next_topk) { | ||||
|       for (int i = lane_id; i < topk; i += WARP_SIZE) { | ||||
|         float value; | ||||
|         if (renormalize) { | ||||
|           value = cuda_cast<float, T>(s_topk_value[i]) / topk_sum * | ||||
|                   routed_scaling_factor; | ||||
|         } else { | ||||
|           value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor; | ||||
|         } | ||||
|         topk_indices[i] = s_topk_idx[i]; | ||||
|         topk_values[i] = cuda_cast<T, float>(value); | ||||
|       } | ||||
|     } else { | ||||
|       for (int i = lane_id; i < topk; i += WARP_SIZE) { | ||||
|         topk_indices[i] = i; | ||||
|         topk_values[i] = cuda_cast<T, float>(1.0f / topk); | ||||
|       } | ||||
|     } | ||||
|     // Note: when if_proceed_next_topk==false, choose the first 8 experts as the | ||||
|     // default result. | ||||
|   } | ||||
| #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) | ||||
|   asm volatile("griddepcontrol.launch_dependents;"); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| template <typename T, typename IdxT> | ||||
| void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values, | ||||
|                    IdxT* topk_indices, T* scores_with_bias, | ||||
|                    int64_t const num_tokens, int64_t const num_experts, | ||||
|                    int64_t const n_group, int64_t const topk_group, | ||||
|                    int64_t const topk, bool const renormalize, | ||||
|                    double const routed_scaling_factor, bool enable_pdl = false, | ||||
|                    cudaStream_t const stream = 0) { | ||||
|   int64_t num_cases = num_tokens * n_group; | ||||
|   int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1; | ||||
|   auto* kernel_instance1 = &topk_with_k2_kernel<T>; | ||||
|   cudaLaunchConfig_t config; | ||||
|   config.gridDim = topk_with_k2_num_blocks; | ||||
|   config.blockDim = BLOCK_SIZE; | ||||
|   config.dynamicSmemBytes = 0; | ||||
|   config.stream = stream; | ||||
|   cudaLaunchAttribute attrs[1]; | ||||
|   attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; | ||||
|   attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; | ||||
|   config.numAttrs = 1; | ||||
|   config.attrs = attrs; | ||||
|   cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias, | ||||
|                      num_tokens, num_cases, n_group, num_experts / n_group); | ||||
|  | ||||
|   int64_t topk_with_k_group_num_blocks = | ||||
|       (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1; | ||||
|   size_t dynamic_smem_in_bytes = | ||||
|       warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK, | ||||
|                                                            topk); | ||||
|   auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>; | ||||
|   config.gridDim = topk_with_k_group_num_blocks; | ||||
|   config.blockDim = BLOCK_SIZE; | ||||
|   config.dynamicSmemBytes = dynamic_smem_in_bytes; | ||||
|   config.stream = stream; | ||||
|   attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; | ||||
|   attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; | ||||
|   config.numAttrs = 1; | ||||
|   config.attrs = attrs; | ||||
|   cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, | ||||
|                      topk_values, topk_indices, scores_with_bias, num_tokens, | ||||
|                      n_group, topk_group, topk, num_experts, | ||||
|                      num_experts / n_group, renormalize, routed_scaling_factor); | ||||
| } | ||||
|  | ||||
| #define INSTANTIATE_NOAUX_TC(T, IdxT)                                       \ | ||||
|   template void invokeNoAuxTc<T, IdxT>(                                     \ | ||||
|       T * scores, T * group_scores, T * topk_values, IdxT * topk_indices,   \ | ||||
|       T * scores_with_bias, int64_t const num_tokens,                       \ | ||||
|       int64_t const num_experts, int64_t const n_group,                     \ | ||||
|       int64_t const topk_group, int64_t const topk, bool const renormalize, \ | ||||
|       double const routed_scaling_factor, bool enable_pdl,                  \ | ||||
|       cudaStream_t const stream); | ||||
|  | ||||
| INSTANTIATE_NOAUX_TC(float, int32_t); | ||||
| INSTANTIATE_NOAUX_TC(half, int32_t); | ||||
| INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t); | ||||
| }  // end namespace moe | ||||
| }  // namespace vllm | ||||
|  | ||||
| std::tuple<torch::Tensor, torch::Tensor> grouped_topk( | ||||
|     torch::Tensor const& scores, torch::Tensor const& scores_with_bias, | ||||
|     int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize, | ||||
|     double routed_scaling_factor) { | ||||
|   auto data_type = scores_with_bias.scalar_type(); | ||||
|   auto input_size = scores_with_bias.sizes(); | ||||
|   int64_t num_tokens = input_size[0]; | ||||
|   int64_t num_experts = input_size[1]; | ||||
|   TORCH_CHECK(input_size.size() == 2, "scores_with_bias must be a 2D Tensor"); | ||||
|   TORCH_CHECK(num_experts % n_group == 0, | ||||
|               "num_experts should be divisible by n_group"); | ||||
|   TORCH_CHECK(n_group <= 32, | ||||
|               "n_group should be smaller than or equal to 32 for now"); | ||||
|   TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now"); | ||||
|  | ||||
|   torch::Tensor group_scores = torch::empty( | ||||
|       {num_tokens, n_group}, torch::dtype(data_type).device(torch::kCUDA)); | ||||
|   torch::Tensor topk_values = torch::empty( | ||||
|       {num_tokens, topk}, torch::dtype(data_type).device(torch::kCUDA)); | ||||
|   torch::Tensor topk_indices = torch::empty( | ||||
|       {num_tokens, topk}, torch::dtype(torch::kInt32).device(torch::kCUDA)); | ||||
|  | ||||
|   auto stream = c10::cuda::getCurrentCUDAStream(scores_with_bias.get_device()); | ||||
|  | ||||
|   switch (data_type) { | ||||
|     case torch::kFloat16: | ||||
|       // Handle Float16 | ||||
|       vllm::moe::invokeNoAuxTc<half, int32_t>( | ||||
|           reinterpret_cast<half*>(scores.mutable_data_ptr()), | ||||
|           reinterpret_cast<half*>(group_scores.mutable_data_ptr()), | ||||
|           reinterpret_cast<half*>(topk_values.mutable_data_ptr()), | ||||
|           reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()), | ||||
|           reinterpret_cast<half*>(scores_with_bias.data_ptr()), num_tokens, | ||||
|           num_experts, n_group, topk_group, topk, renormalize, | ||||
|           routed_scaling_factor, false, stream); | ||||
|       break; | ||||
|     case torch::kFloat32: | ||||
|       // Handle Float32 | ||||
|       vllm::moe::invokeNoAuxTc<float, int32_t>( | ||||
|           reinterpret_cast<float*>(scores.mutable_data_ptr()), | ||||
|           reinterpret_cast<float*>(group_scores.mutable_data_ptr()), | ||||
|           reinterpret_cast<float*>(topk_values.mutable_data_ptr()), | ||||
|           reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()), | ||||
|           reinterpret_cast<float*>(scores_with_bias.data_ptr()), num_tokens, | ||||
|           num_experts, n_group, topk_group, topk, renormalize, | ||||
|           routed_scaling_factor, false, stream); | ||||
|       break; | ||||
|     case torch::kBFloat16: | ||||
|       // Handle BFloat16 | ||||
|       vllm::moe::invokeNoAuxTc<__nv_bfloat16, int32_t>( | ||||
|           reinterpret_cast<__nv_bfloat16*>(scores.mutable_data_ptr()), | ||||
|           reinterpret_cast<__nv_bfloat16*>(group_scores.mutable_data_ptr()), | ||||
|           reinterpret_cast<__nv_bfloat16*>(topk_values.mutable_data_ptr()), | ||||
|           reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()), | ||||
|           reinterpret_cast<__nv_bfloat16*>(scores_with_bias.data_ptr()), | ||||
|           num_tokens, num_experts, n_group, topk_group, topk, renormalize, | ||||
|           routed_scaling_factor, false, stream); | ||||
|       break; | ||||
|     default: | ||||
|       // Handle other data types | ||||
|       throw std::invalid_argument( | ||||
|           "Invalid dtype, only supports float16, float32, and bfloat16"); | ||||
|       break; | ||||
|   } | ||||
|   return {topk_values, topk_indices}; | ||||
| } | ||||
| @ -22,6 +22,11 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, | ||||
|                              torch::Tensor num_tokens_post_pad, int64_t top_k, | ||||
|                              int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, | ||||
|                              int64_t BLOCK_SIZE_K, int64_t bit); | ||||
|  | ||||
| std::tuple<torch::Tensor, torch::Tensor> grouped_topk( | ||||
|     torch::Tensor const& scores, torch::Tensor const& scores_with_bias, | ||||
|     int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize, | ||||
|     double routed_scaling_factor); | ||||
| #endif | ||||
|  | ||||
| bool moe_permute_unpermute_supported(); | ||||
|  | ||||
| @ -45,8 +45,6 @@ void moe_permute( | ||||
|   auto copy_topk_ids = topk_ids.clone();  // copy topk_ids for preprocess | ||||
|   auto permuted_experts_id = torch::empty_like(topk_ids); | ||||
|   auto sorted_row_idx = torch::empty_like(inv_permuted_idx); | ||||
|   auto align_expert_first_token_offset = | ||||
|       torch::zeros_like(expert_first_token_offset); | ||||
|  | ||||
|   CubKeyValueSorter sorter{}; | ||||
|   int64_t* valid_num_ptr = nullptr; | ||||
| @ -85,12 +83,14 @@ void moe_permute( | ||||
|   }); | ||||
|  | ||||
|   // get m_indices and update expert_first_token_offset with align block | ||||
|   getMIndices(get_ptr<int64_t>(expert_first_token_offset), | ||||
|               get_ptr<int64_t>(align_expert_first_token_offset), | ||||
|               get_ptr<int>(m_indices), n_local_expert, align_block_size_value, | ||||
|               stream); | ||||
|   // this is only required for DeepGemm and not required for CUTLASS group gemm | ||||
|   if (align_block_size.has_value()) { | ||||
|     // update align_expert_first_token_offset | ||||
|     auto align_expert_first_token_offset = | ||||
|         torch::zeros_like(expert_first_token_offset); | ||||
|     getMIndices(get_ptr<int64_t>(expert_first_token_offset), | ||||
|                 get_ptr<int64_t>(align_expert_first_token_offset), | ||||
|                 get_ptr<int>(m_indices), n_local_expert, align_block_size_value, | ||||
|                 stream); | ||||
|     expert_first_token_offset.copy_(align_expert_first_token_offset); | ||||
|   } | ||||
| } | ||||
| @ -195,19 +195,14 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights, | ||||
|                  torch::Tensor& expert_first_token_offset, | ||||
|                  torch::Tensor& src_row_id2dst_row_id_map, | ||||
|                  torch::Tensor& m_indices) { | ||||
|   TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); | ||||
|   TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0"); | ||||
| } | ||||
|  | ||||
| void moe_unpermute(const torch::Tensor& input, | ||||
|                    const torch::Tensor& topk_weights, torch::Tensor& topk_ids, | ||||
|                    const torch::Tensor& token_expert_indices, | ||||
|                    const std::optional<torch::Tensor>& expert_map, | ||||
|                    int64_t n_expert, int64_t n_local_expert, int64_t topk, | ||||
|                    const std::optional<int64_t>& align_block_size, | ||||
|                    torch::Tensor& permuted_input, | ||||
|                    torch::Tensor& expert_first_token_offset, | ||||
|                    torch::Tensor& src_row_id2dst_row_id_map, | ||||
|                    torch::Tensor& m_indices) { | ||||
| void moe_unpermute( | ||||
|     const torch::Tensor& permuted_hidden_states, | ||||
|     const torch::Tensor& topk_weights, const torch::Tensor& inv_permuted_idx, | ||||
|     const std::optional<torch::Tensor>& expert_first_token_offset, int64_t topk, | ||||
|     torch::Tensor& hidden_states) { | ||||
|   TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0"); | ||||
| } | ||||
|  | ||||
| @ -224,4 +219,4 @@ bool moe_permute_unpermute_supported() { | ||||
| TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { | ||||
|   m.impl("moe_permute", &moe_permute); | ||||
|   m.impl("moe_unpermute", &moe_unpermute); | ||||
| } | ||||
| } | ||||
| @ -78,6 +78,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { | ||||
|       "output_tensor) -> ()"); | ||||
|   m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows); | ||||
|  | ||||
|   // Apply grouped topk routing to select experts. | ||||
|   m.def( | ||||
|       "grouped_topk(Tensor scores, Tensor scores_with_bias, int n_group, int " | ||||
|       "topk_group, int topk, bool renormalize, float " | ||||
|       "routed_scaling_factor) -> (Tensor, Tensor)"); | ||||
|   m.impl("grouped_topk", torch::kCUDA, &grouped_topk); | ||||
| #endif | ||||
| } | ||||
|  | ||||
|  | ||||
| @ -229,6 +229,11 @@ void get_cutlass_moe_mm_data( | ||||
|     const int64_t num_experts, const int64_t n, const int64_t k, | ||||
|     const std::optional<torch::Tensor>& blockscale_offsets); | ||||
|  | ||||
| void get_cutlass_moe_mm_problem_sizes( | ||||
|     const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, | ||||
|     torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, | ||||
|     const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets); | ||||
|  | ||||
| void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, | ||||
|                                   torch::Tensor& problem_sizes1, | ||||
|                                   torch::Tensor& problem_sizes2, | ||||
|  | ||||
							
								
								
									
										418
									
								
								csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										418
									
								
								csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,418 @@ | ||||
| // | ||||
| // Based off of: | ||||
| //   https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu | ||||
| // | ||||
|  | ||||
| #include <ATen/cuda/CUDAContext.h> | ||||
| #include <c10/cuda/CUDAGuard.h> | ||||
| #include <torch/all.h> | ||||
| #include "cutlass_extensions/torch_utils.hpp" | ||||
|  | ||||
| #include "core/registration.h" | ||||
|  | ||||
| #include "cutlass/cutlass.h" | ||||
|  | ||||
| #include "cute/tensor.hpp" | ||||
| #include "cutlass/gemm/collective/collective_builder.hpp" | ||||
| #include "cutlass/epilogue/collective/collective_builder.hpp" | ||||
| #include "cutlass/gemm/device/gemm_universal_adapter.h" | ||||
| #include "cutlass/gemm/kernel/gemm_universal.hpp" | ||||
|  | ||||
| #include "cutlass/util/packed_stride.hpp" | ||||
| #include "cutlass/util/mixed_dtype_utils.hpp" | ||||
|  | ||||
| #include "cutlass_extensions/common.hpp" | ||||
| #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" | ||||
|  | ||||
| namespace vllm::cutlass_w4a8 { | ||||
|  | ||||
| using namespace cute; | ||||
|  | ||||
| // ------------------------------------------------------------------------------------- | ||||
| // Static configuration shared across all instantiations | ||||
| // ------------------------------------------------------------------------------------- | ||||
| using MmaType = cutlass::float_e4m3_t;  // A/scale element type | ||||
| using QuantType = cutlass::int4b_t;     // B element type (packed int4) | ||||
|  | ||||
| static int constexpr TileShapeK = 128 * 8 / sizeof_bits<MmaType>::value; | ||||
| static int constexpr ScalePackSize = 8;  // pack 8 scale elements together | ||||
| static int constexpr PackFactor = 8;     // 8 4-bit packed into int32 | ||||
|  | ||||
| // A matrix configuration | ||||
| using ElementA = MmaType;                   // Element type for A matrix operand | ||||
| using LayoutA = cutlass::layout::RowMajor;  // Layout type for A matrix operand | ||||
| using LayoutA_Transpose = | ||||
|     typename cutlass::layout::LayoutTranspose<LayoutA>::type; | ||||
| constexpr int AlignmentA = | ||||
|     128 / cutlass::sizeof_bits< | ||||
|               ElementA>::value;  // Memory access granularity/alignment of A | ||||
|                                  // matrix in units of elements (up to 16 bytes) | ||||
| using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>; | ||||
|  | ||||
| // B matrix configuration | ||||
| using ElementB = QuantType;  // Element type for B matrix operand | ||||
| using LayoutB = | ||||
|     cutlass::layout::ColumnMajor;  // Layout type for B matrix operand | ||||
| using LayoutB_Transpose = | ||||
|     typename cutlass::layout::LayoutTranspose<LayoutB>::type; | ||||
| constexpr int AlignmentB = | ||||
|     128 / cutlass::sizeof_bits< | ||||
|               ElementB>::value;  // Memory access granularity/alignment of B | ||||
|                                  // matrix in units of elements (up to 16 bytes) | ||||
| using StrideB = cutlass::detail::TagToStrideB_t<LayoutB>; | ||||
|  | ||||
| // Define the CuTe layout for reordered quantized tensor B | ||||
| // LayoutAtomQuant places values that will be read by the same thread in | ||||
| // contiguous locations in global memory. It specifies the reordering within a | ||||
| // single warp's fragment | ||||
| using LayoutAtomQuant = | ||||
|     decltype(cutlass::compute_memory_reordering_atom<MmaType>()); | ||||
| using LayoutB_Reordered = decltype(cute::tile_to_shape( | ||||
|     LayoutAtomQuant{}, Layout<Shape<int, int, int>, StrideB>{})); | ||||
|  | ||||
| // Group-wise scales | ||||
| using ElementScale = MmaType; | ||||
| using LayoutScale = cutlass::layout::RowMajor; | ||||
|  | ||||
| // Per-tok, per-chan scales | ||||
| using ElementSChannel = float; | ||||
|  | ||||
| // C/D matrix configuration | ||||
| using ElementC = | ||||
|     cutlass::bfloat16_t;  // Element type for C and D matrix operands | ||||
| using LayoutC = | ||||
|     cutlass::layout::RowMajor;  // Layout type for C and D matrix operands | ||||
| constexpr int AlignmentC = | ||||
|     128 / cutlass::sizeof_bits< | ||||
|               ElementC>::value;  // Memory access granularity/alignment of C | ||||
|                                  // matrix in units of elements (up to 16 bytes) | ||||
|  | ||||
| using ElementD = ElementC; | ||||
| using LayoutD = LayoutC; | ||||
| constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; | ||||
|  | ||||
| // Core kernel configurations | ||||
| using ElementAccumulator = float;     // Element type for internal accumulation | ||||
| using ElementCompute = float;         // Element type for epilogue computation | ||||
| using ArchTag = cutlass::arch::Sm90;  // Tag indicating the minimum SM that | ||||
|                                       // supports the intended feature | ||||
| using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag | ||||
| using KernelSchedule = | ||||
|     cutlass::gemm::KernelTmaWarpSpecializedCooperative;  // Kernel to launch | ||||
|                                                          // based on the default | ||||
|                                                          // setting in the | ||||
|                                                          // Collective Builder | ||||
| using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; | ||||
| using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
| // Kernel template — Tile/Cluster shapes | ||||
| // ---------------------------------------------------------------------------- | ||||
| template <class TileShape_MN, class ClusterShape_MNK> | ||||
| struct W4A8GemmKernel { | ||||
|   using TileShape = | ||||
|       decltype(cute::append(TileShape_MN{}, cute::Int<TileShapeK>{})); | ||||
|   using ClusterShape = ClusterShape_MNK; | ||||
|  | ||||
|   // Epilogue per-tok, per-chan scales | ||||
|   using ChTokScalesEpilogue = | ||||
|       typename vllm::c3x::ScaledEpilogue<ElementAccumulator, ElementD, | ||||
|                                          TileShape>; | ||||
|   using EVTCompute = typename ChTokScalesEpilogue::EVTCompute; | ||||
|   using CollectiveEpilogue = | ||||
|       typename cutlass::epilogue::collective::CollectiveBuilder< | ||||
|           ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, | ||||
|           ElementAccumulator, ElementSChannel, | ||||
|           // Transpose layout of D here since we use explicit swap + transpose | ||||
|           // the void type for C tells the builder to allocate 0 smem for the C | ||||
|           // matrix. We can enable this if beta == 0 by changing ElementC to | ||||
|           // void below. | ||||
|           ElementC, typename cutlass::layout::LayoutTranspose<LayoutC>::type, | ||||
|           AlignmentC, ElementD, | ||||
|           typename cutlass::layout::LayoutTranspose<LayoutD>::type, AlignmentD, | ||||
|           EpilogueSchedule,  // This is the only epi supporting the required | ||||
|                              // swap + transpose. | ||||
|           EVTCompute>::CollectiveOp; | ||||
|  | ||||
|   // The Scale information must get paired with the operand that will be scaled. | ||||
|   // In this example, B is scaled so we make a tuple of B's information and the | ||||
|   // scale information. | ||||
|   using CollectiveMainloopShuffled = | ||||
|       typename cutlass::gemm::collective::CollectiveBuilder< | ||||
|           ArchTag, OperatorClass, | ||||
|           cute::tuple<ElementB, cutlass::Array<ElementScale, ScalePackSize>>, | ||||
|           LayoutB_Reordered, AlignmentB, ElementA, LayoutA_Transpose, | ||||
|           AlignmentA, ElementAccumulator, TileShape, ClusterShape, | ||||
|           cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>( | ||||
|               sizeof(typename CollectiveEpilogue::SharedStorage))>, | ||||
|           KernelSchedule>::CollectiveOp; | ||||
|  | ||||
|   using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal< | ||||
|       Shape<int, int, int, int>,  // Indicates ProblemShape | ||||
|       CollectiveMainloopShuffled, CollectiveEpilogue>; | ||||
|   using GemmShuffled = | ||||
|       cutlass::gemm::device::GemmUniversalAdapter<GemmKernelShuffled>; | ||||
|  | ||||
|   using StrideC = typename GemmKernelShuffled::StrideC; | ||||
|   using StrideD = typename GemmKernelShuffled::StrideD; | ||||
|   using StrideS = typename CollectiveMainloopShuffled::StrideScale; | ||||
|  | ||||
|   static torch::Tensor mm(torch::Tensor const& A, | ||||
|                           torch::Tensor const& B,             // already packed | ||||
|                           torch::Tensor const& group_scales,  // already packed | ||||
|                           int64_t group_size, | ||||
|                           torch::Tensor const& channel_scales, | ||||
|                           torch::Tensor const& token_scales, | ||||
|                           std::optional<at::ScalarType> const& maybe_out_type) { | ||||
|     // TODO: param validation | ||||
|     int m = A.size(0); | ||||
|     int k = A.size(1); | ||||
|     int n = B.size(1); | ||||
|  | ||||
|     // Allocate output | ||||
|     const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); | ||||
|     auto device = A.device(); | ||||
|     auto stream = at::cuda::getCurrentCUDAStream(device.index()); | ||||
|     torch::Tensor D = | ||||
|         torch::empty({m, n}, torch::TensorOptions() | ||||
|                                  .dtype(equivalent_scalar_type_v<ElementD>) | ||||
|                                  .device(device)); | ||||
|     // prepare arg pointers | ||||
|     auto A_ptr = static_cast<MmaType const*>(A.const_data_ptr()); | ||||
|     auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr()); | ||||
|     auto D_ptr = static_cast<ElementD*>(D.data_ptr()); | ||||
|     // can we avoid harcode the 8 here | ||||
|     auto S_ptr = | ||||
|         static_cast<cutlass::Array<ElementScale, ScalePackSize> const*>( | ||||
|             group_scales.const_data_ptr()); | ||||
|  | ||||
|     // runtime layout for B | ||||
|     auto shape_B = cute::make_shape(n, k, 1); | ||||
|     LayoutB_Reordered layout_B_reordered = | ||||
|         cute::tile_to_shape(LayoutAtomQuant{}, shape_B); | ||||
|  | ||||
|     // strides | ||||
|     int const scale_k = cutlass::ceil_div(k, group_size); | ||||
|     StrideA stride_A = | ||||
|         cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); | ||||
|     // Reverse stride here due to swap and transpose | ||||
|     StrideD stride_D = | ||||
|         cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(n, m, 1)); | ||||
|     StrideS stride_S = cutlass::make_cute_packed_stride( | ||||
|         StrideS{}, cute::make_shape(n, scale_k, 1)); | ||||
|  | ||||
|     // Create a structure of gemm kernel arguments suitable for invoking an | ||||
|     // instance of Gemm auto arguments = | ||||
|     // args_from_options<GemmShuffled>(options); | ||||
|     /// Populates a Gemm::Arguments structure from the given arguments | ||||
|     /// Swap the A and B tensors, as well as problem shapes here. | ||||
|     using Args = typename GemmShuffled::Arguments; | ||||
|     using MainloopArguments = typename GemmKernelShuffled::MainloopArguments; | ||||
|     using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments; | ||||
|  | ||||
|     MainloopArguments mainloop_arguments{ | ||||
|         B_ptr, layout_B_reordered, A_ptr,     stride_A, | ||||
|         S_ptr, stride_S,           group_size}; | ||||
|  | ||||
|     EpilogueArguments epilogue_arguments{ | ||||
|         ChTokScalesEpilogue::prepare_args(channel_scales, token_scales), | ||||
|         nullptr, | ||||
|         {},  // no C | ||||
|         D_ptr, | ||||
|         stride_D}; | ||||
|  | ||||
|     Args arguments{cutlass::gemm::GemmUniversalMode::kGemm, | ||||
|                    {n, m, k, 1},  // shape | ||||
|                    mainloop_arguments, | ||||
|                    epilogue_arguments}; | ||||
|  | ||||
|     // Workspace | ||||
|     size_t workspace_size = GemmShuffled::get_workspace_size(arguments); | ||||
|     torch::Tensor workspace = | ||||
|         torch::empty(workspace_size, | ||||
|                      torch::TensorOptions().dtype(torch::kU8).device(device)); | ||||
|  | ||||
|     // Run GEMM | ||||
|     GemmShuffled gemm; | ||||
|     CUTLASS_CHECK(gemm.can_implement(arguments)); | ||||
|     CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream)); | ||||
|     CUTLASS_CHECK(gemm.run(stream)); | ||||
|  | ||||
|     return D; | ||||
|   } | ||||
| }; | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
| // Kernel instantiations and dispatch logic | ||||
| // ---------------------------------------------------------------------------- | ||||
| using Kernel_256x128_1x1x1 = | ||||
|     W4A8GemmKernel<Shape<_256, _128>, Shape<_1, _1, _1>>; | ||||
| using Kernel_256x64_1x1x1 = W4A8GemmKernel<Shape<_256, _64>, Shape<_1, _1, _1>>; | ||||
| using Kernel_256x32_1x1x1 = W4A8GemmKernel<Shape<_256, _32>, Shape<_1, _1, _1>>; | ||||
| using Kernel_256x16_1x1x1 = W4A8GemmKernel<Shape<_256, _16>, Shape<_1, _1, _1>>; | ||||
| using Kernel_128x256_2x1x1 = | ||||
|     W4A8GemmKernel<Shape<_128, _256>, Shape<_2, _1, _1>>; | ||||
| using Kernel_128x256_1x1x1 = | ||||
|     W4A8GemmKernel<Shape<_128, _256>, Shape<_1, _1, _1>>; | ||||
| using Kernel_128x128_1x1x1 = | ||||
|     W4A8GemmKernel<Shape<_128, _128>, Shape<_1, _1, _1>>; | ||||
| using Kernel_128x64_1x1x1 = W4A8GemmKernel<Shape<_128, _64>, Shape<_1, _1, _1>>; | ||||
| using Kernel_128x32_1x1x1 = W4A8GemmKernel<Shape<_128, _32>, Shape<_1, _1, _1>>; | ||||
| using Kernel_128x16_1x1x1 = W4A8GemmKernel<Shape<_128, _16>, Shape<_1, _1, _1>>; | ||||
|  | ||||
| torch::Tensor mm_dispatch(torch::Tensor const& A, | ||||
|                           torch::Tensor const& B,             // already packed | ||||
|                           torch::Tensor const& group_scales,  // already packed | ||||
|                           int64_t group_size, | ||||
|                           torch::Tensor const& channel_scales, | ||||
|                           torch::Tensor const& token_scales, | ||||
|                           std::optional<at::ScalarType> const& maybe_out_type, | ||||
|                           const std::string& schedule) { | ||||
|   if (schedule == "256x128_1x1x1") { | ||||
|     return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size, | ||||
|                                     channel_scales, token_scales, | ||||
|                                     maybe_out_type); | ||||
|   } else if (schedule == "256x64_1x1x1") { | ||||
|     return Kernel_256x64_1x1x1::mm(A, B, group_scales, group_size, | ||||
|                                    channel_scales, token_scales, | ||||
|                                    maybe_out_type); | ||||
|   } else if (schedule == "256x32_1x1x1") { | ||||
|     return Kernel_256x32_1x1x1::mm(A, B, group_scales, group_size, | ||||
|                                    channel_scales, token_scales, | ||||
|                                    maybe_out_type); | ||||
|   } else if (schedule == "256x16_1x1x1") { | ||||
|     return Kernel_256x16_1x1x1::mm(A, B, group_scales, group_size, | ||||
|                                    channel_scales, token_scales, | ||||
|                                    maybe_out_type); | ||||
|   } else if (schedule == "128x256_2x1x1") { | ||||
|     return Kernel_128x256_2x1x1::mm(A, B, group_scales, group_size, | ||||
|                                     channel_scales, token_scales, | ||||
|                                     maybe_out_type); | ||||
|   } else if (schedule == "128x256_1x1x1") { | ||||
|     return Kernel_128x256_1x1x1::mm(A, B, group_scales, group_size, | ||||
|                                     channel_scales, token_scales, | ||||
|                                     maybe_out_type); | ||||
|   } else if (schedule == "128x128_1x1x1") { | ||||
|     return Kernel_128x128_1x1x1::mm(A, B, group_scales, group_size, | ||||
|                                     channel_scales, token_scales, | ||||
|                                     maybe_out_type); | ||||
|   } else if (schedule == "128x64_1x1x1") { | ||||
|     return Kernel_128x64_1x1x1::mm(A, B, group_scales, group_size, | ||||
|                                    channel_scales, token_scales, | ||||
|                                    maybe_out_type); | ||||
|   } else if (schedule == "128x32_1x1x1") { | ||||
|     return Kernel_128x32_1x1x1::mm(A, B, group_scales, group_size, | ||||
|                                    channel_scales, token_scales, | ||||
|                                    maybe_out_type); | ||||
|   } else if (schedule == "128x16_1x1x1") { | ||||
|     return Kernel_128x16_1x1x1::mm(A, B, group_scales, group_size, | ||||
|                                    channel_scales, token_scales, | ||||
|                                    maybe_out_type); | ||||
|   } | ||||
|   TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule); | ||||
|   return {}; | ||||
| } | ||||
|  | ||||
| torch::Tensor mm(torch::Tensor const& A, | ||||
|                  torch::Tensor const& B,             // already packed | ||||
|                  torch::Tensor const& group_scales,  // already packed | ||||
|                  int64_t group_size, torch::Tensor const& channel_scales, | ||||
|                  torch::Tensor const& token_scales, | ||||
|                  std::optional<at::ScalarType> const& maybe_out_type, | ||||
|                  std::optional<std::string> maybe_schedule) { | ||||
|   // requested a specific schedule | ||||
|   if (maybe_schedule) { | ||||
|     return mm_dispatch(A, B, group_scales, group_size, channel_scales, | ||||
|                        token_scales, maybe_out_type, *maybe_schedule); | ||||
|   } | ||||
|   std::string schedule; | ||||
|   int M = A.size(0); | ||||
|   int K = A.size(1); | ||||
|   int N = B.size(1); | ||||
|   // heuristic | ||||
|   if (M <= 16) { | ||||
|     schedule = (K == 16384 && N == 18432) ? "256x16_1x1x1" : "128x16_1x1x1"; | ||||
|   } else if (M <= 32) { | ||||
|     schedule = (K == 16384 && N == 18432) ? "256x32_1x1x1" : "128x32_1x1x1"; | ||||
|   } else if (M <= 64) { | ||||
|     if (K == 16384 && N == 18432) | ||||
|       schedule = "256x64_1x1x1"; | ||||
|     else if (N <= 8192 && K <= 8192) | ||||
|       schedule = "128x32_1x1x1"; | ||||
|     else | ||||
|       schedule = "128x64_1x1x1"; | ||||
|   } else if (M <= 128) { | ||||
|     if (K == 16384 && N == 18432) | ||||
|       schedule = "256x128_1x1x1"; | ||||
|     else if (N <= 8192) | ||||
|       schedule = "128x64_1x1x1"; | ||||
|     else | ||||
|       schedule = "128x128_1x1x1"; | ||||
|   } else if (M <= 256) { | ||||
|     if (N <= 4096) | ||||
|       schedule = "128x64_1x1x1"; | ||||
|     else if (N <= 8192) | ||||
|       schedule = "128x128_1x1x1"; | ||||
|     else | ||||
|       schedule = "128x256_1x1x1"; | ||||
|   } else if (M <= 512 && N <= 4096) { | ||||
|     schedule = "128x128_1x1x1"; | ||||
|   } else if (M <= 1024) { | ||||
|     schedule = "128x256_1x1x1"; | ||||
|   } else { | ||||
|     schedule = "128x256_2x1x1"; | ||||
|   } | ||||
|   return mm_dispatch(A, B, group_scales, group_size, channel_scales, | ||||
|                      token_scales, maybe_out_type, schedule); | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
| // Pre-processing utils | ||||
| // ---------------------------------------------------------------------------- | ||||
| torch::Tensor pack_scale_fp8(torch::Tensor const& scales) { | ||||
|   TORCH_CHECK(scales.dtype() == torch::kFloat8_e4m3fn); | ||||
|   TORCH_CHECK(scales.is_contiguous()); | ||||
|   TORCH_CHECK(scales.is_cuda()); | ||||
|  | ||||
|   auto packed_scales = torch::empty( | ||||
|       {scales.numel() * ScalePackSize}, | ||||
|       torch::TensorOptions().dtype(scales.dtype()).device(scales.device())); | ||||
|   auto scales_ptr = static_cast<MmaType const*>(scales.const_data_ptr()); | ||||
|   auto packed_scales_ptr = | ||||
|       static_cast<cutlass::Array<ElementScale, ScalePackSize>*>( | ||||
|           packed_scales.data_ptr()); | ||||
|  | ||||
|   cutlass::pack_scale_fp8(scales_ptr, packed_scales_ptr, scales.numel()); | ||||
|  | ||||
|   return packed_scales; | ||||
| } | ||||
|  | ||||
| torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) { | ||||
|   TORCH_CHECK(B.dtype() == torch::kInt32); | ||||
|   TORCH_CHECK(B.dim() == 2); | ||||
|  | ||||
|   torch::Tensor B_packed = torch::empty_like(B); | ||||
|  | ||||
|   int k = B.size(0) * PackFactor;  // logical k | ||||
|   int n = B.size(1); | ||||
|  | ||||
|   auto B_ptr = static_cast<QuantType const*>(B.const_data_ptr()); | ||||
|   auto B_packed_ptr = static_cast<QuantType*>(B_packed.data_ptr()); | ||||
|   auto shape_B = cute::make_shape(n, k, 1); | ||||
|   auto layout_B = make_layout(shape_B, LayoutRight{});  // row major | ||||
|   LayoutB_Reordered layout_B_reordered = | ||||
|       cute::tile_to_shape(LayoutAtomQuant{}, shape_B); | ||||
|  | ||||
|   cutlass::unified_encode_int4b(B_ptr, B_packed_ptr, n * k); | ||||
|   cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered); | ||||
|  | ||||
|   return B_packed; | ||||
| } | ||||
|  | ||||
| TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { | ||||
|   m.impl("cutlass_w4a8_mm", &mm); | ||||
|   m.impl("cutlass_pack_scale_fp8", &pack_scale_fp8); | ||||
|   m.impl("cutlass_encode_and_reorder_int4b", &encode_and_reorder_int4b); | ||||
| } | ||||
|  | ||||
| }  // namespace vllm::cutlass_w4a8 | ||||
| @ -10,7 +10,7 @@ | ||||
|  | ||||
| template <typename ElementAB, typename ElementC, typename ElementAccumulator> | ||||
| __global__ void get_group_gemm_starts( | ||||
|     int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets, | ||||
|     int64_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets, | ||||
|     ElementC** out_offsets, ElementAccumulator** a_scales_offsets, | ||||
|     ElementAccumulator** b_scales_offsets, ElementAB* a_base_as_int, | ||||
|     ElementAB* b_base_as_int, ElementC* out_base_as_int, | ||||
| @ -34,7 +34,7 @@ __global__ void get_group_gemm_starts( | ||||
|   else if (out_tensors.dtype() == TENSOR_C_TYPE) {                         \ | ||||
|     get_group_gemm_starts<cutlass::float_e4m3_t, C_TYPE, float>            \ | ||||
|         <<<1, num_experts, 0, stream>>>(                                   \ | ||||
|             static_cast<int32_t*>(expert_offsets.data_ptr()),              \ | ||||
|             static_cast<int64_t*>(expert_offsets.data_ptr()),              \ | ||||
|             static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()),       \ | ||||
|             static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()),       \ | ||||
|             static_cast<C_TYPE**>(out_ptrs.data_ptr()),                    \ | ||||
| @ -61,6 +61,8 @@ void run_get_group_gemm_starts( | ||||
|   TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); | ||||
|   TORCH_CHECK(a_scales.dtype() == torch::kFloat32); | ||||
|   TORCH_CHECK(b_scales.dtype() == torch::kFloat32); | ||||
|   // expect int64_t to avoid overflow during offset calculations | ||||
|   TORCH_CHECK(expert_offsets.dtype() == torch::kInt64); | ||||
|  | ||||
|   int num_experts = static_cast<int>(expert_offsets.size(0)); | ||||
|   bool per_act_token = a_scales.numel() != 1; | ||||
|  | ||||
| @ -104,6 +104,53 @@ __global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids, | ||||
|   } | ||||
| } | ||||
|  | ||||
| namespace { | ||||
| inline void launch_compute_problem_sizes(const torch::Tensor& topk_ids, | ||||
|                                          torch::Tensor& problem_sizes1, | ||||
|                                          torch::Tensor& problem_sizes2, | ||||
|                                          torch::Tensor& atomic_buffer, | ||||
|                                          int64_t num_experts, int64_t n, | ||||
|                                          int64_t k, cudaStream_t stream, | ||||
|                                          const bool swap_ab) { | ||||
|   int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); | ||||
|  | ||||
|   const int32_t* topk_ptr = static_cast<const int32_t*>(topk_ids.data_ptr()); | ||||
|   int32_t* ps1_ptr = static_cast<int32_t*>(problem_sizes1.data_ptr()); | ||||
|   int32_t* ps2_ptr = static_cast<int32_t*>(problem_sizes2.data_ptr()); | ||||
|   int32_t* atomic_ptr = static_cast<int32_t*>(atomic_buffer.data_ptr()); | ||||
|  | ||||
|   if (swap_ab) { | ||||
|     compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>( | ||||
|         topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr, | ||||
|         static_cast<int>(topk_ids.numel()), static_cast<int>(n), | ||||
|         static_cast<int>(k)); | ||||
|   } else { | ||||
|     compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>( | ||||
|         topk_ptr, ps1_ptr, ps2_ptr, atomic_ptr, | ||||
|         static_cast<int>(topk_ids.numel()), static_cast<int>(n), | ||||
|         static_cast<int>(k)); | ||||
|   } | ||||
| } | ||||
| }  // namespace | ||||
|  | ||||
| void get_cutlass_moe_mm_problem_sizes_caller( | ||||
|     const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, | ||||
|     torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, | ||||
|     const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) { | ||||
|   auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index()); | ||||
|   auto options_int32 = | ||||
|       torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); | ||||
|   torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); | ||||
|  | ||||
|   // Swap-AB should be disabled for FP4 path | ||||
|   bool may_swap_ab = (!blockscale_offsets.has_value()) && | ||||
|                      (topk_ids.numel() <= SWAP_AB_THRESHOLD); | ||||
|  | ||||
|   launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2, | ||||
|                                atomic_buffer, num_experts, n, k, stream, | ||||
|                                may_swap_ab); | ||||
| } | ||||
|  | ||||
| void get_cutlass_moe_mm_data_caller( | ||||
|     const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, | ||||
|     torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, | ||||
| @ -121,21 +168,9 @@ void get_cutlass_moe_mm_data_caller( | ||||
|   bool may_swap_ab = (!blockscale_offsets.has_value()) && | ||||
|                      (topk_ids.numel() <= SWAP_AB_THRESHOLD); | ||||
|  | ||||
|   if (may_swap_ab) { | ||||
|     compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>( | ||||
|         static_cast<const int32_t*>(topk_ids.data_ptr()), | ||||
|         static_cast<int32_t*>(problem_sizes1.data_ptr()), | ||||
|         static_cast<int32_t*>(problem_sizes2.data_ptr()), | ||||
|         static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, | ||||
|         k); | ||||
|   } else { | ||||
|     compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>( | ||||
|         static_cast<const int32_t*>(topk_ids.data_ptr()), | ||||
|         static_cast<int32_t*>(problem_sizes1.data_ptr()), | ||||
|         static_cast<int32_t*>(problem_sizes2.data_ptr()), | ||||
|         static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, | ||||
|         k); | ||||
|   } | ||||
|   launch_compute_problem_sizes(topk_ids, problem_sizes1, problem_sizes2, | ||||
|                                atomic_buffer, num_experts, n, k, stream, | ||||
|                                may_swap_ab); | ||||
|  | ||||
|   if (blockscale_offsets.has_value()) { | ||||
|     // fp4 path | ||||
|  | ||||
| @ -76,6 +76,11 @@ void get_cutlass_moe_mm_data_caller( | ||||
|     const int64_t num_experts, const int64_t n, const int64_t k, | ||||
|     const std::optional<torch::Tensor>& blockscale_offsets); | ||||
|  | ||||
| void get_cutlass_moe_mm_problem_sizes_caller( | ||||
|     const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, | ||||
|     torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, | ||||
|     const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets); | ||||
|  | ||||
| void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, | ||||
|                                          torch::Tensor& problem_sizes1, | ||||
|                                          torch::Tensor& problem_sizes2, | ||||
| @ -293,6 +298,25 @@ void get_cutlass_moe_mm_data( | ||||
|       version_num, ". Required capability: 90 or 100"); | ||||
| } | ||||
|  | ||||
| void get_cutlass_moe_mm_problem_sizes( | ||||
|     const torch::Tensor& topk_ids, torch::Tensor& problem_sizes1, | ||||
|     torch::Tensor& problem_sizes2, const int64_t num_experts, const int64_t n, | ||||
|     const int64_t k, const std::optional<torch::Tensor>& blockscale_offsets) { | ||||
|   int32_t version_num = get_sm_version_num(); | ||||
| #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ | ||||
|     (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) | ||||
|   get_cutlass_moe_mm_problem_sizes_caller(topk_ids, problem_sizes1, | ||||
|                                           problem_sizes2, num_experts, n, k, | ||||
|                                           blockscale_offsets); | ||||
|   return; | ||||
| #endif | ||||
|   TORCH_CHECK_NOT_IMPLEMENTED( | ||||
|       false, | ||||
|       "No compiled get_cutlass_moe_mm_problem_sizes: no cutlass_scaled_mm " | ||||
|       "kernel for CUDA device capability: ", | ||||
|       version_num, ". Required capability: 90 or 100"); | ||||
| } | ||||
|  | ||||
| void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, | ||||
|                                   torch::Tensor& problem_sizes1, | ||||
|                                   torch::Tensor& problem_sizes2, | ||||
|  | ||||
| @ -349,9 +349,12 @@ def to_cute_constant(value: list[int]): | ||||
|  | ||||
|  | ||||
| def unique_schedules(impl_configs: list[ImplConfig]): | ||||
|     return list( | ||||
|         set(sch for impl_config in impl_configs | ||||
|             for sch in impl_config.schedules)) | ||||
|     # Use dict over set for deterministic ordering | ||||
|     return list({ | ||||
|         sch: None | ||||
|         for impl_config in impl_configs | ||||
|         for sch in impl_config.schedules | ||||
|     }.keys()) | ||||
|  | ||||
|  | ||||
| def unsigned_type_with_bitwidth(num_bits): | ||||
| @ -568,78 +571,79 @@ def generate(): | ||||
|                      itertools.repeat(default_heuristic)) | ||||
|     ] | ||||
|  | ||||
|     # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)) | ||||
|     # TODO (LucasWilkinson): Further tuning required | ||||
|     qqq_tile_heuristic_config = { | ||||
|         #### M = 257+ | ||||
|         # ((128, 256), (2, 1, 1)) Broken for QQQ types | ||||
|         # TODO (LucasWilkinson): Investigate further | ||||
|         # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)), | ||||
|         # "M > 256": ((128, 256), (2, 1, 1)), | ||||
|         "M > 256": ((128, 128), (2, 1, 1)), | ||||
|         #### M = 129-256 | ||||
|         "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)), | ||||
|         "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)), | ||||
|         # ((128, 256), (2, 1, 1)) Broken for QQQ types | ||||
|         # TODO (LucasWilkinson): Investigate further | ||||
|         # "M > 128": ((128, 256), (2, 1, 1)), | ||||
|         "M > 128": ((128, 128), (2, 1, 1)), | ||||
|         #### M = 65-128 | ||||
|         "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)), | ||||
|         "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)), | ||||
|         "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)), | ||||
|         "M > 64": ((128, 128), (2, 1, 1)), | ||||
|         #### M = 33-64 | ||||
|         "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)), | ||||
|         # Broken for QQQ types | ||||
|         # TODO (LucasWilkinson): Investigate further | ||||
|         #"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), | ||||
|         "M > 32": ((128, 64), (2, 1, 1)), | ||||
|         #### M = 17-32 | ||||
|         "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), | ||||
|         "M > 16": ((256, 32), (2, 1, 1)), | ||||
|         #### M = 1-16 | ||||
|         "N >= 26624": ((256, 16), (1, 1, 1)), | ||||
|         None: ((128, 16), (1, 1, 1)), | ||||
|     } | ||||
|     # TODO: Support W4A8 when ready | ||||
|     # # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)) | ||||
|     # # TODO (LucasWilkinson): Further tuning required | ||||
|     # qqq_tile_heuristic_config = { | ||||
|     #     #### M = 257+ | ||||
|     #     # ((128, 256), (2, 1, 1)) Broken for QQQ types | ||||
|     #     # TODO (LucasWilkinson): Investigate further | ||||
|     #     # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)), | ||||
|     #     # "M > 256": ((128, 256), (2, 1, 1)), | ||||
|     #     "M > 256": ((128, 128), (2, 1, 1)), | ||||
|     #     #### M = 129-256 | ||||
|     #     "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)), | ||||
|     #     "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)), | ||||
|     #     # ((128, 256), (2, 1, 1)) Broken for QQQ types | ||||
|     #     # TODO (LucasWilkinson): Investigate further | ||||
|     #     # "M > 128": ((128, 256), (2, 1, 1)), | ||||
|     #     "M > 128": ((128, 128), (2, 1, 1)), | ||||
|     #     #### M = 65-128 | ||||
|     #     "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)), | ||||
|     #     "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)), | ||||
|     #     "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)), | ||||
|     #     "M > 64": ((128, 128), (2, 1, 1)), | ||||
|     #     #### M = 33-64 | ||||
|     #     "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)), | ||||
|     #     # Broken for QQQ types | ||||
|     #     # TODO (LucasWilkinson): Investigate further | ||||
|     #     #"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), | ||||
|     #     "M > 32": ((128, 64), (2, 1, 1)), | ||||
|     #     #### M = 17-32 | ||||
|     #     "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), | ||||
|     #     "M > 16": ((256, 32), (2, 1, 1)), | ||||
|     #     #### M = 1-16 | ||||
|     #     "N >= 26624": ((256, 16), (1, 1, 1)), | ||||
|     #     None: ((128, 16), (1, 1, 1)), | ||||
|     # } | ||||
|  | ||||
|     # For now we use the same heuristic for all types | ||||
|     # Heuristic is currently tuned for H100s | ||||
|     qqq_heuristic = [ | ||||
|         (cond, ScheduleConfig(*tile_config, | ||||
|                               **sch_common_params))  # type: ignore | ||||
|         for cond, tile_config in qqq_tile_heuristic_config.items() | ||||
|     ] | ||||
|     # # For now we use the same heuristic for all types | ||||
|     # # Heuristic is currently tuned for H100s | ||||
|     # qqq_heuristic = [ | ||||
|     #     (cond, ScheduleConfig(*tile_config, | ||||
|     #                           **sch_common_params))  # type: ignore | ||||
|     #     for cond, tile_config in qqq_tile_heuristic_config.items() | ||||
|     # ] | ||||
|  | ||||
|     QQQ_kernel_types = [ | ||||
|         *(TypeConfig( | ||||
|             a=DataType.s8, | ||||
|             b=VLLMDataType.u4b8, | ||||
|             b_group_scale=b_group_scale, | ||||
|             b_group_zeropoint=DataType.void, | ||||
|             b_channel_scale=DataType.f32, | ||||
|             a_token_scale=DataType.f32, | ||||
|             out=DataType.f16, | ||||
|             accumulator=DataType.s32, | ||||
|         ) for b_group_scale in (DataType.f16, DataType.void)), | ||||
|         *(TypeConfig( | ||||
|             a=DataType.e4m3, | ||||
|             b=VLLMDataType.u4b8, | ||||
|             b_group_scale=b_group_scale, | ||||
|             b_group_zeropoint=DataType.void, | ||||
|             b_channel_scale=DataType.f32, | ||||
|             a_token_scale=DataType.f32, | ||||
|             out=DataType.f16, | ||||
|             accumulator=DataType.f32, | ||||
|         ) for b_group_scale in (DataType.f16, DataType.void)), | ||||
|     ] | ||||
|     # QQQ_kernel_types = [ | ||||
|     #     *(TypeConfig( | ||||
|     #         a=DataType.s8, | ||||
|     #         b=VLLMDataType.u4b8, | ||||
|     #         b_group_scale=b_group_scale, | ||||
|     #         b_group_zeropoint=DataType.void, | ||||
|     #         b_channel_scale=DataType.f32, | ||||
|     #         a_token_scale=DataType.f32, | ||||
|     #         out=DataType.f16, | ||||
|     #         accumulator=DataType.s32, | ||||
|     #     ) for b_group_scale in (DataType.f16, DataType.void)), | ||||
|     #     *(TypeConfig( | ||||
|     #         a=DataType.e4m3, | ||||
|     #         b=VLLMDataType.u4b8, | ||||
|     #         b_group_scale=b_group_scale, | ||||
|     #         b_group_zeropoint=DataType.void, | ||||
|     #         b_channel_scale=DataType.f32, | ||||
|     #         a_token_scale=DataType.f32, | ||||
|     #         out=DataType.f16, | ||||
|     #         accumulator=DataType.f32, | ||||
|     #     ) for b_group_scale in (DataType.f16, DataType.void)), | ||||
|     # ] | ||||
|  | ||||
|     impl_configs += [ | ||||
|         ImplConfig(x[0], x[1], x[2]) | ||||
|         for x in zip(QQQ_kernel_types, | ||||
|                      itertools.repeat(get_unique_schedules(qqq_heuristic)), | ||||
|                      itertools.repeat(qqq_heuristic)) | ||||
|     ] | ||||
|     # impl_configs += [ | ||||
|     #     ImplConfig(x[0], x[1], x[2]) | ||||
|     #     for x in zip(QQQ_kernel_types, | ||||
|     #                  itertools.repeat(get_unique_schedules(qqq_heuristic)), | ||||
|     #                  itertools.repeat(qqq_heuristic)) | ||||
|     # ] | ||||
|  | ||||
|     output_dir = os.path.join(SCRIPT_DIR, "generated") | ||||
|  | ||||
|  | ||||
| @ -1,209 +0,0 @@ | ||||
| Contains code from https://github.com/IST-DASLab/marlin | ||||
|  | ||||
|                                  Apache License | ||||
|                            Version 2.0, January 2004 | ||||
|                         http://www.apache.org/licenses/ | ||||
|  | ||||
|    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION | ||||
|  | ||||
|    1. Definitions. | ||||
|  | ||||
|       "License" shall mean the terms and conditions for use, reproduction, | ||||
|       and distribution as defined by Sections 1 through 9 of this document. | ||||
|  | ||||
|       "Licensor" shall mean the copyright owner or entity authorized by | ||||
|       the copyright owner that is granting the License. | ||||
|  | ||||
|       "Legal Entity" shall mean the union of the acting entity and all | ||||
|       other entities that control, are controlled by, or are under common | ||||
|       control with that entity. For the purposes of this definition, | ||||
|       "control" means (i) the power, direct or indirect, to cause the | ||||
|       direction or management of such entity, whether by contract or | ||||
|       otherwise, or (ii) ownership of fifty percent (50%) or more of the | ||||
|       outstanding shares, or (iii) beneficial ownership of such entity. | ||||
|  | ||||
|       "You" (or "Your") shall mean an individual or Legal Entity | ||||
|       exercising permissions granted by this License. | ||||
|  | ||||
|       "Source" form shall mean the preferred form for making modifications, | ||||
|       including but not limited to software source code, documentation | ||||
|       source, and configuration files. | ||||
|  | ||||
|       "Object" form shall mean any form resulting from mechanical | ||||
|       transformation or translation of a Source form, including but | ||||
|       not limited to compiled object code, generated documentation, | ||||
|       and conversions to other media types. | ||||
|  | ||||
|       "Work" shall mean the work of authorship, whether in Source or | ||||
|       Object form, made available under the License, as indicated by a | ||||
|       copyright notice that is included in or attached to the work | ||||
|       (an example is provided in the Appendix below). | ||||
|  | ||||
|       "Derivative Works" shall mean any work, whether in Source or Object | ||||
|       form, that is based on (or derived from) the Work and for which the | ||||
|       editorial revisions, annotations, elaborations, or other modifications | ||||
|       represent, as a whole, an original work of authorship. For the purposes | ||||
|       of this License, Derivative Works shall not include works that remain | ||||
|       separable from, or merely link (or bind by name) to the interfaces of, | ||||
|       the Work and Derivative Works thereof. | ||||
|  | ||||
|       "Contribution" shall mean any work of authorship, including | ||||
|       the original version of the Work and any modifications or additions | ||||
|       to that Work or Derivative Works thereof, that is intentionally | ||||
|       submitted to Licensor for inclusion in the Work by the copyright owner | ||||
|       or by an individual or Legal Entity authorized to submit on behalf of | ||||
|       the copyright owner. For the purposes of this definition, "submitted" | ||||
|       means any form of electronic, verbal, or written communication sent | ||||
|       to the Licensor or its representatives, including but not limited to | ||||
|       communication on electronic mailing lists, source code control systems, | ||||
|       and issue tracking systems that are managed by, or on behalf of, the | ||||
|       Licensor for the purpose of discussing and improving the Work, but | ||||
|       excluding communication that is conspicuously marked or otherwise | ||||
|       designated in writing by the copyright owner as "Not a Contribution." | ||||
|  | ||||
|       "Contributor" shall mean Licensor and any individual or Legal Entity | ||||
|       on behalf of whom a Contribution has been received by Licensor and | ||||
|       subsequently incorporated within the Work. | ||||
|  | ||||
|    2. Grant of Copyright License. Subject to the terms and conditions of | ||||
|       this License, each Contributor hereby grants to You a perpetual, | ||||
|       worldwide, non-exclusive, no-charge, royalty-free, irrevocable | ||||
|       copyright license to reproduce, prepare Derivative Works of, | ||||
|       publicly display, publicly perform, sublicense, and distribute the | ||||
|       Work and such Derivative Works in Source or Object form. | ||||
|  | ||||
|    3. Grant of Patent License. Subject to the terms and conditions of | ||||
|       this License, each Contributor hereby grants to You a perpetual, | ||||
|       worldwide, non-exclusive, no-charge, royalty-free, irrevocable | ||||
|       (except as stated in this section) patent license to make, have made, | ||||
|       use, offer to sell, sell, import, and otherwise transfer the Work, | ||||
|       where such license applies only to those patent claims licensable | ||||
|       by such Contributor that are necessarily infringed by their | ||||
|       Contribution(s) alone or by combination of their Contribution(s) | ||||
|       with the Work to which such Contribution(s) was submitted. If You | ||||
|       institute patent litigation against any entity (including a | ||||
|       cross-claim or counterclaim in a lawsuit) alleging that the Work | ||||
|       or a Contribution incorporated within the Work constitutes direct | ||||
|       or contributory patent infringement, then any patent licenses | ||||
|       granted to You under this License for that Work shall terminate | ||||
|       as of the date such litigation is filed. | ||||
|  | ||||
|    4. Redistribution. You may reproduce and distribute copies of the | ||||
|       Work or Derivative Works thereof in any medium, with or without | ||||
|       modifications, and in Source or Object form, provided that You | ||||
|       meet the following conditions: | ||||
|  | ||||
|       (a) You must give any other recipients of the Work or | ||||
|           Derivative Works a copy of this License; and | ||||
|  | ||||
|       (b) You must cause any modified files to carry prominent notices | ||||
|           stating that You changed the files; and | ||||
|  | ||||
|       (c) You must retain, in the Source form of any Derivative Works | ||||
|           that You distribute, all copyright, patent, trademark, and | ||||
|           attribution notices from the Source form of the Work, | ||||
|           excluding those notices that do not pertain to any part of | ||||
|           the Derivative Works; and | ||||
|  | ||||
|       (d) If the Work includes a "NOTICE" text file as part of its | ||||
|           distribution, then any Derivative Works that You distribute must | ||||
|           include a readable copy of the attribution notices contained | ||||
|           within such NOTICE file, excluding those notices that do not | ||||
|           pertain to any part of the Derivative Works, in at least one | ||||
|           of the following places: within a NOTICE text file distributed | ||||
|           as part of the Derivative Works; within the Source form or | ||||
|           documentation, if provided along with the Derivative Works; or, | ||||
|           within a display generated by the Derivative Works, if and | ||||
|           wherever such third-party notices normally appear. The contents | ||||
|           of the NOTICE file are for informational purposes only and | ||||
|           do not modify the License. You may add Your own attribution | ||||
|           notices within Derivative Works that You distribute, alongside | ||||
|           or as an addendum to the NOTICE text from the Work, provided | ||||
|           that such additional attribution notices cannot be construed | ||||
|           as modifying the License. | ||||
|  | ||||
|       You may add Your own copyright statement to Your modifications and | ||||
|       may provide additional or different license terms and conditions | ||||
|       for use, reproduction, or distribution of Your modifications, or | ||||
|       for any such Derivative Works as a whole, provided Your use, | ||||
|       reproduction, and distribution of the Work otherwise complies with | ||||
|       the conditions stated in this License. | ||||
|  | ||||
|    5. Submission of Contributions. Unless You explicitly state otherwise, | ||||
|       any Contribution intentionally submitted for inclusion in the Work | ||||
|       by You to the Licensor shall be under the terms and conditions of | ||||
|       this License, without any additional terms or conditions. | ||||
|       Notwithstanding the above, nothing herein shall supersede or modify | ||||
|       the terms of any separate license agreement you may have executed | ||||
|       with Licensor regarding such Contributions. | ||||
|  | ||||
|    6. Trademarks. This License does not grant permission to use the trade | ||||
|       names, trademarks, service marks, or product names of the Licensor, | ||||
|       except as required for reasonable and customary use in describing the | ||||
|       origin of the Work and reproducing the content of the NOTICE file. | ||||
|  | ||||
|    7. Disclaimer of Warranty. Unless required by applicable law or | ||||
|       agreed to in writing, Licensor provides the Work (and each | ||||
|       Contributor provides its Contributions) on an "AS IS" BASIS, | ||||
|       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
|       implied, including, without limitation, any warranties or conditions | ||||
|       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A | ||||
|       PARTICULAR PURPOSE. You are solely responsible for determining the | ||||
|       appropriateness of using or redistributing the Work and assume any | ||||
|       risks associated with Your exercise of permissions under this License. | ||||
|  | ||||
|    8. Limitation of Liability. In no event and under no legal theory, | ||||
|       whether in tort (including negligence), contract, or otherwise, | ||||
|       unless required by applicable law (such as deliberate and grossly | ||||
|       negligent acts) or agreed to in writing, shall any Contributor be | ||||
|       liable to You for damages, including any direct, indirect, special, | ||||
|       incidental, or consequential damages of any character arising as a | ||||
|       result of this License or out of the use or inability to use the | ||||
|       Work (including but not limited to damages for loss of goodwill, | ||||
|       work stoppage, computer failure or malfunction, or any and all | ||||
|       other commercial damages or losses), even if such Contributor | ||||
|       has been advised of the possibility of such damages. | ||||
|  | ||||
|    9. Accepting Warranty or Additional Liability. While redistributing | ||||
|       the Work or Derivative Works thereof, You may choose to offer, | ||||
|       and charge a fee for, acceptance of support, warranty, indemnity, | ||||
|       or other liability obligations and/or rights consistent with this | ||||
|       License. However, in accepting such obligations, You may act only | ||||
|       on Your own behalf and on Your sole responsibility, not on behalf | ||||
|       of any other Contributor, and only if You agree to indemnify, | ||||
|       defend, and hold each Contributor harmless for any liability | ||||
|       incurred by, or claims asserted against, such Contributor by reason | ||||
|       of your accepting any such warranty or additional liability. | ||||
|  | ||||
|    END OF TERMS AND CONDITIONS | ||||
|  | ||||
|    APPENDIX: How to apply the Apache License to your work. | ||||
|  | ||||
|       To apply the Apache License to your work, attach the following | ||||
|       boilerplate notice, with the fields enclosed by brackets "{}" | ||||
|       replaced with your own identifying information. (Don't include | ||||
|       the brackets!)  The text should be enclosed in the appropriate | ||||
|       comment syntax for the file format. We also recommend that a | ||||
|       file or class name and description of purpose be included on the | ||||
|       same "printed page" as the copyright notice for easier | ||||
|       identification within third-party archives. | ||||
|  | ||||
|    Copyright {yyyy} {name of copyright owner} | ||||
|  | ||||
|    Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|    you may not use this file except in compliance with the License. | ||||
|    You may obtain a copy of the License at | ||||
|  | ||||
|        http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  | ||||
|    Unless required by applicable law or agreed to in writing, software | ||||
|    distributed under the License is distributed on an "AS IS" BASIS, | ||||
|    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|    See the License for the specific language governing permissions and | ||||
|    limitations under the License. | ||||
|  | ||||
| ------------------------------------------------------------------------------------ | ||||
|  | ||||
| This product bundles various third-party components under other open source licenses. | ||||
| This section summarizes those components and their licenses. See licenses/ | ||||
| for text of these licenses. | ||||
| @ -1,32 +0,0 @@ | ||||
| /* | ||||
|  * Modified by HandH1998 | ||||
|  * Modified by Neural Magic | ||||
|  * Copyright (C) Marlin.2024 Elias Frantar | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *         http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
|  | ||||
| #pragma once | ||||
|  | ||||
| constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } | ||||
|  | ||||
| // Instances of `Vec` are used to organize groups of >>registers<<, as needed | ||||
| // for instance as inputs to tensor core operations. Consequently, all | ||||
| // corresponding index accesses must be compile-time constants, which is why we | ||||
| // extensively use `#pragma unroll` throughout the kernel code to guarantee | ||||
| // this. | ||||
| template <typename T, int n> | ||||
| struct Vec { | ||||
|   T elems[n]; | ||||
|   __device__ T& operator[](int i) { return elems[i]; } | ||||
| }; | ||||
| @ -1,89 +0,0 @@ | ||||
| /* | ||||
|  * Modified by HandH1998 | ||||
|  * Modified by Neural Magic | ||||
|  * Copyright (C) Marlin.2024 Elias Frantar | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *         http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
|  | ||||
| #pragma once | ||||
|  | ||||
| // Predicated asynchronous global->shared copy; used for inputs A where we apply | ||||
| // predication to handle batchsizes that are not multiples of 16. | ||||
| __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, | ||||
|                                       bool pred = true) { | ||||
|   const int BYTES = 16; | ||||
|   uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); | ||||
|   asm volatile( | ||||
|       "{\n" | ||||
|       "   .reg .pred p;\n" | ||||
|       "   setp.ne.b32 p, %0, 0;\n" | ||||
|       "   @p cp.async.cg.shared.global [%1], [%2], %3;\n" | ||||
|       "}\n" ::"r"((int)pred), | ||||
|       "r"(smem), "l"(glob_ptr), "n"(BYTES)); | ||||
| } | ||||
|  | ||||
| // Asynchronous global->shared copy | ||||
| __device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { | ||||
|   const int BYTES = 16; | ||||
|   uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); | ||||
|   asm volatile( | ||||
|       "{\n" | ||||
|       "   cp.async.cg.shared.global [%0], [%1], %2;\n" | ||||
|       "}\n" ::"r"(smem), | ||||
|       "l"(glob_ptr), "n"(BYTES)); | ||||
| } | ||||
|  | ||||
| // Async copy fence. | ||||
| __device__ inline void cp_async_fence() { | ||||
|   asm volatile("cp.async.commit_group;\n" ::); | ||||
| } | ||||
|  | ||||
| // Wait until at most `n` async copy stages are still pending. | ||||
| template <int n> | ||||
| __device__ inline void cp_async_wait() { | ||||
|   asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); | ||||
| } | ||||
|  | ||||
| // Wait until barrier reaches `count`, then lock for current threadblock. | ||||
| __device__ inline void barrier_acquire(int* lock, int count) { | ||||
|   if (threadIdx.x == 0) { | ||||
|     int state = -1; | ||||
|     do | ||||
|       // Guarantee that subsequent writes by this threadblock will be visible | ||||
|       // globally. | ||||
|       asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" | ||||
|                    : "=r"(state) | ||||
|                    : "l"(lock)); | ||||
|     while (state != count); | ||||
|   } | ||||
|   __syncthreads(); | ||||
| } | ||||
|  | ||||
| // Release barrier and increment visitation count. | ||||
| __device__ inline void barrier_release(int* lock, bool reset = false) { | ||||
|   __syncthreads(); | ||||
|   if (threadIdx.x == 0) { | ||||
|     if (reset) { | ||||
|       lock[0] = 0; | ||||
|       return; | ||||
|     } | ||||
|     int val = 1; | ||||
|     // Make sure that all writes since acquiring this barrier are visible | ||||
|     // globally, while releasing the barrier. | ||||
|     asm volatile("fence.acq_rel.gpu;\n"); | ||||
|     asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" | ||||
|                  : | ||||
|                  : "l"(lock), "r"(val)); | ||||
|   } | ||||
| } | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							| @ -41,8 +41,10 @@ __device__ inline void vectorize_with_alignment( | ||||
|  | ||||
|     for (int i = tid; i < num_vec; i += stride) { | ||||
|       vout_t tmp; | ||||
|       vec_op(tmp, v_in[i]); | ||||
|       v_out[i] = tmp; | ||||
|       // Make a local copy of the entire pack | ||||
|       vin_t src = v_in[i];  // <- encourages a single vector ld | ||||
|       vec_op(tmp, src); | ||||
|       v_out[i] = tmp;  // <- encourages a single vector st | ||||
|     } | ||||
|     return; | ||||
|   } | ||||
| @ -71,8 +73,10 @@ __device__ inline void vectorize_with_alignment( | ||||
|   // 2. vectorize the main part | ||||
|   for (int i = tid; i < num_vec; i += stride) { | ||||
|     vout_t tmp; | ||||
|     vec_op(tmp, v_in[i]); | ||||
|     v_out[i] = tmp; | ||||
|     // Make a local copy of the entire pack | ||||
|     vin_t src = v_in[i];  // <- encourages a single vector ld | ||||
|     vec_op(tmp, src); | ||||
|     v_out[i] = tmp;  // <- encourages a single vector st | ||||
|   } | ||||
|  | ||||
|   // 3. handle the tail | ||||
| @ -125,7 +129,8 @@ __device__ inline void vectorize_read_with_alignment(const InT* in, int len, | ||||
|     auto* v_in = reinterpret_cast<const vin_t*>(in); | ||||
|  | ||||
|     for (int i = tid; i < num_vec; i += stride) { | ||||
|       vec_op(v_in[i]); | ||||
|       vin_t tmp = v_in[i]; | ||||
|       vec_op(tmp); | ||||
|     } | ||||
|     return; | ||||
|   } | ||||
|  | ||||
| @ -241,14 +241,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { | ||||
|   // custom types: | ||||
|   // https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA | ||||
|  | ||||
|   // Marlin (Dense) Optimized Quantized GEMM for GPTQ. | ||||
|   ops.def( | ||||
|       "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " | ||||
|       "Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> " | ||||
|       "Tensor", | ||||
|       {stride_tag}); | ||||
|   // conditionally compiled so impl in source file | ||||
|  | ||||
|   // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ. | ||||
|   ops.def( | ||||
|       "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, " | ||||
| @ -317,6 +309,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { | ||||
|       "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, " | ||||
|       "SymInt size_n, int num_bits) -> Tensor"); | ||||
|   // conditionally compiled so impl registrations are in source file | ||||
|  | ||||
|   // CUTLASS w4a8 GEMM | ||||
|   ops.def( | ||||
|       "cutlass_w4a8_mm(" | ||||
|       "   Tensor A," | ||||
|       "   Tensor B," | ||||
|       "   Tensor group_scales," | ||||
|       "   int    group_size," | ||||
|       "   Tensor channel_scales," | ||||
|       "   Tensor token_scales," | ||||
|       "   ScalarType? out_type," | ||||
|       "   str?   maybe_schedule" | ||||
|       ") -> Tensor", | ||||
|       {stride_tag}); | ||||
|   // pack scales | ||||
|   ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor"); | ||||
|   // encode and reorder weight matrix | ||||
|   ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor"); | ||||
|   // conditionally compiled so impl registration is in source file | ||||
|  | ||||
| #endif | ||||
|  | ||||
|   // Dequantization for GGML. | ||||
| @ -353,15 +365,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { | ||||
|   ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size); | ||||
|  | ||||
| #ifndef USE_ROCM | ||||
|   // marlin_qqq_gemm for QQQ. | ||||
|   ops.def( | ||||
|       "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, " | ||||
|       "Tensor s_tok, Tensor s_ch, Tensor s_group, " | ||||
|       "Tensor! workspace, SymInt size_m, SymInt size_n, " | ||||
|       "SymInt size_k) -> Tensor", | ||||
|       {stride_tag}); | ||||
|   // conditionally compiled so impl registration is in source file | ||||
|  | ||||
|   // CUTLASS nvfp4 block scaled GEMM | ||||
|   ops.def( | ||||
|       "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b," | ||||
| @ -440,6 +443,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { | ||||
|       {stride_tag}); | ||||
|   ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data); | ||||
|  | ||||
|   // A function that computes problem sizes for each expert's multiplication | ||||
|   // used by the two mms called from fused MoE operation. It takes topk_ids as | ||||
|   // an input, and computes problem_sizes1 and problem_sizes2 only. | ||||
|   ops.def( | ||||
|       "get_cutlass_moe_mm_problem_sizes(Tensor topk_ids, " | ||||
|       "                                 Tensor! problem_sizes1, " | ||||
|       "                                 Tensor! problem_sizes2, " | ||||
|       "                                 int num_experts, int n, int k, " | ||||
|       "                                 Tensor? blockscale_offsets) -> ()", | ||||
|       {stride_tag}); | ||||
|   ops.impl("get_cutlass_moe_mm_problem_sizes", torch::kCUDA, | ||||
|            &get_cutlass_moe_mm_problem_sizes); | ||||
|  | ||||
|   // A function that computes data required to run fused MoE with w8a8 grouped | ||||
|   // GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs | ||||
|   // as an input, and computes expert_offsets (token start indices of each | ||||
| @ -676,11 +692,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { | ||||
|       "str kv_cache_dtype) -> ()"); | ||||
|   cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8); | ||||
|  | ||||
|   // Gather cache blocks from src_cache to dst. | ||||
|   // Gather cache blocks from src_cache to dst, dequantizing from | ||||
|   // src_cache's dtype to dst's dtype if necessary. | ||||
|   cache_ops.def( | ||||
|       "gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, " | ||||
|       "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()"); | ||||
|   cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache); | ||||
|       "gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, " | ||||
|       "                               Tensor block_table, Tensor cu_seq_lens, " | ||||
|       "                               int batch_size, " | ||||
|       "                               str kv_cache_dtype, " | ||||
|       "                               Tensor scale, Tensor? seq_starts) -> ()"); | ||||
|   cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA, | ||||
|                  &gather_and_maybe_dequant_cache); | ||||
| } | ||||
|  | ||||
| TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { | ||||
|  | ||||
| @ -372,31 +372,45 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist | ||||
|  | ||||
| # Install FlashInfer from source | ||||
| ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" | ||||
| # Keep this in sync with https://github.com/vllm-project/vllm/blob/main/requirements/cuda.txt | ||||
| # We use `--force-reinstall --no-deps` to avoid issues with the existing FlashInfer wheel. | ||||
| ARG FLASHINFER_GIT_REF="v0.2.11" | ||||
| # Keep this in sync with "flashinfer" extra in setup.py | ||||
| ARG FLASHINFER_GIT_REF="v0.2.14.post1" | ||||
| # Flag to control whether to compile FlashInfer AOT kernels | ||||
| # Set to "true" to enable AOT compilation: | ||||
| # docker build --build-arg FLASHINFER_AOT_COMPILE=true ... | ||||
| ARG FLASHINFER_AOT_COMPILE=false | ||||
| RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH' | ||||
|   . /etc/environment | ||||
|     git clone --depth 1 --recursive --shallow-submodules \ | ||||
|         --branch ${FLASHINFER_GIT_REF} \ | ||||
|         ${FLASHINFER_GIT_REPO} flashinfer | ||||
|     # Exclude CUDA arches for older versions (11.x and 12.0-12.7) | ||||
|     # TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg. | ||||
|     if [[ "${CUDA_VERSION}" == 11.* ]]; then | ||||
|         FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9" | ||||
|     elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then | ||||
|         FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" | ||||
|     else | ||||
|         # CUDA 12.8+ supports 10.0a and 12.0 | ||||
|         FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0" | ||||
|     fi | ||||
|     echo "🏗️  Building FlashInfer for arches: ${FI_TORCH_CUDA_ARCH_LIST}" | ||||
|     # Needed to build AOT kernels | ||||
|     pushd flashinfer | ||||
|         TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ | ||||
|             python3 -m flashinfer.aot | ||||
|         TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ | ||||
|             uv pip install --system --no-build-isolation --force-reinstall --no-deps . | ||||
|         if [ "${FLASHINFER_AOT_COMPILE}" = "true" ]; then | ||||
|             # Exclude CUDA arches for older versions (11.x and 12.0-12.7) | ||||
|             # TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg. | ||||
|             if [[ "${CUDA_VERSION}" == 11.* ]]; then | ||||
|                 FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9" | ||||
|             elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then | ||||
|                 FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" | ||||
|             else | ||||
|                 # CUDA 12.8+ supports 10.0a and 12.0 | ||||
|                 FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0" | ||||
|             fi | ||||
|             echo "🏗️  Installing FlashInfer with AOT compilation for arches: ${FI_TORCH_CUDA_ARCH_LIST}" | ||||
|             # Build AOT kernels | ||||
|             TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ | ||||
|                 python3 -m flashinfer.aot | ||||
|             # Install with no-build-isolation since we already built AOT kernels | ||||
|             TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ | ||||
|                 uv pip install --system --no-build-isolation . \ | ||||
|                 --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') | ||||
|             # Download pre-compiled cubins | ||||
|             TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ | ||||
|                 python3 -m flashinfer --download-cubin || echo "WARNING: Failed to download flashinfer cubins." | ||||
|         else | ||||
|             echo "🏗️  Installing FlashInfer without AOT compilation in JIT mode" | ||||
|             uv pip install --system . \ | ||||
|                 --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') | ||||
|         fi | ||||
|     popd | ||||
|     rm -rf flashinfer | ||||
| BASH | ||||
| @ -418,31 +432,19 @@ RUN --mount=type=cache,target=/root/.cache/uv \ | ||||
|         --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') | ||||
|  | ||||
| # Install DeepGEMM from source | ||||
| ARG DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git" | ||||
| ARG DEEPGEMM_GIT_REF="7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c" | ||||
| RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH' | ||||
|   . /etc/environment | ||||
|     CUDA_MAJOR="${CUDA_VERSION%%.*}" | ||||
|     CUDA_MINOR="${CUDA_VERSION#${CUDA_MAJOR}.}" | ||||
|     CUDA_MINOR="${CUDA_MINOR%%.*}" | ||||
|     if [ "$CUDA_MAJOR" -ge 12 ] && [ "$CUDA_MINOR" -ge 8 ]; then | ||||
|         git clone --recursive --shallow-submodules \ | ||||
|             ${DEEPGEMM_GIT_REPO} deepgemm | ||||
|         echo "🏗️  Building DeepGEMM" | ||||
|         pushd deepgemm | ||||
|             git checkout ${DEEPGEMM_GIT_REF} | ||||
|             # Build DeepGEMM | ||||
|             # (Based on https://github.com/deepseek-ai/DeepGEMM/blob/main/install.sh) | ||||
|             rm -rf build dist | ||||
|             rm -rf *.egg-info | ||||
|             python3 setup.py bdist_wheel | ||||
|             uv pip install --system dist/*.whl | ||||
|         popd | ||||
|         rm -rf deepgemm | ||||
|     else | ||||
|         echo "Skipping DeepGEMM installation (requires CUDA 12.8+ but got ${CUDA_VERSION})" | ||||
|     fi | ||||
| BASH | ||||
| COPY tools/install_deepgemm.sh /tmp/install_deepgemm.sh | ||||
| RUN --mount=type=cache,target=/root/.cache/uv \ | ||||
|     VLLM_DOCKER_BUILD_CONTEXT=1 /tmp/install_deepgemm.sh --cuda-version "${CUDA_VERSION}" --ref "${DEEPGEMM_GIT_REF}" \ | ||||
|     && rm /tmp/install_deepgemm.sh | ||||
|  | ||||
| # Install EP kernels(pplx-kernels and DeepEP), NixL | ||||
| COPY tools/ep_kernels/install_python_libraries.sh install_python_libraries.sh | ||||
| COPY tools/install_nixl.sh install_nixl.sh | ||||
| ENV CUDA_HOME=/usr/local/cuda | ||||
| RUN export TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-9.0a+PTX}" \ | ||||
|     && bash install_python_libraries.sh \ | ||||
|     && bash install_nixl.sh --force | ||||
|  | ||||
| #################### vLLM installation IMAGE #################### | ||||
|  | ||||
|  | ||||
| @ -71,7 +71,7 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace | ||||
| RUN cd /vllm-workspace \ | ||||
|     && rm -rf vllm \ | ||||
|     && python3 -m pip install -e tests/vllm_test_utils \ | ||||
|     && python3 -m pip install lm-eval[api]==0.4.4 \ | ||||
|     && python3 -m pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] \ | ||||
|     && python3 -m pip install pytest-shard | ||||
|  | ||||
| # ----------------------- | ||||
|  | ||||
| @ -16,7 +16,7 @@ ENV LANG=C.UTF-8 \ | ||||
| RUN microdnf install -y \ | ||||
|     which procps findutils tar vim git gcc gcc-gfortran g++ make patch zlib-devel \ | ||||
|     libjpeg-turbo-devel libtiff-devel libpng-devel libwebp-devel freetype-devel harfbuzz-devel \ | ||||
|     openssl-devel openblas openblas-devel autoconf automake libtool cmake numpy && \ | ||||
|     openssl-devel openblas openblas-devel autoconf automake libtool cmake numpy libsndfile && \ | ||||
|     microdnf clean all | ||||
|  | ||||
| # Python Installation | ||||
| @ -136,6 +136,71 @@ RUN --mount=type=cache,target=/root/.cache/uv \ | ||||
|     mkdir -p /tmp/hf-xet/dist && \ | ||||
|     cp dist/*.whl /tmp/hf-xet/dist/ | ||||
|  | ||||
| # Build numba | ||||
| FROM python-install AS numba-builder | ||||
|  | ||||
| ARG MAX_JOBS | ||||
| ARG NUMBA_VERSION=0.61.2 | ||||
|  | ||||
| WORKDIR /tmp | ||||
|  | ||||
| # Clone all required dependencies | ||||
| RUN --mount=type=cache,target=/root/.cache/uv \ | ||||
|     microdnf install ninja-build gcc gcc-c++ -y && \ | ||||
|     git clone --recursive https://github.com/llvm/llvm-project.git -b llvmorg-15.0.7  && \ | ||||
|     git clone --recursive https://github.com/numba/llvmlite.git -b v0.44.0 && \ | ||||
|     git clone --recursive https://github.com/numba/numba.git -b ${NUMBA_VERSION} && \ | ||||
|     cd llvm-project && mkdir build && cd  build && \ | ||||
|     uv pip install 'cmake<4' setuptools numpy && \ | ||||
|     export PREFIX=/usr/local && CMAKE_ARGS="${CMAKE_ARGS} -DLLVM_ENABLE_PROJECTS=lld;libunwind;compiler-rt" \ | ||||
|     CFLAGS="$(echo $CFLAGS | sed 's/-fno-plt //g')" \ | ||||
|     CXXFLAGS="$(echo $CXXFLAGS | sed 's/-fno-plt //g')" \ | ||||
|     CMAKE_ARGS="${CMAKE_ARGS} -DFFI_INCLUDE_DIR=$PREFIX/include" \ | ||||
|     CMAKE_ARGS="${CMAKE_ARGS} -DFFI_LIBRARY_DIR=$PREFIX/lib" \ | ||||
|     cmake -DCMAKE_INSTALL_PREFIX="${PREFIX}"               \ | ||||
|         -DCMAKE_BUILD_TYPE=Release                       \ | ||||
|         -DCMAKE_LIBRARY_PATH="${PREFIX}"                 \ | ||||
|         -DLLVM_ENABLE_LIBEDIT=OFF                        \ | ||||
|         -DLLVM_ENABLE_LIBXML2=OFF                        \ | ||||
|         -DLLVM_ENABLE_RTTI=ON                            \ | ||||
|         -DLLVM_ENABLE_TERMINFO=OFF                       \ | ||||
|         -DLLVM_INCLUDE_BENCHMARKS=OFF                    \ | ||||
|         -DLLVM_INCLUDE_DOCS=OFF                          \ | ||||
|         -DLLVM_INCLUDE_EXAMPLES=OFF                      \ | ||||
|         -DLLVM_INCLUDE_GO_TESTS=OFF                      \ | ||||
|         -DLLVM_INCLUDE_TESTS=OFF                         \ | ||||
|         -DLLVM_INCLUDE_UTILS=ON                          \ | ||||
|         -DLLVM_INSTALL_UTILS=ON                          \ | ||||
|         -DLLVM_UTILS_INSTALL_DIR=libexec/llvm            \ | ||||
|         -DLLVM_BUILD_LLVM_DYLIB=OFF                      \ | ||||
|         -DLLVM_LINK_LLVM_DYLIB=OFF                       \ | ||||
|         -DLLVM_EXPERIMENTAL_TARGETS_TO_BUILD=WebAssembly \ | ||||
|         -DLLVM_ENABLE_FFI=ON                             \ | ||||
|         -DLLVM_ENABLE_Z3_SOLVER=OFF                      \ | ||||
|         -DLLVM_OPTIMIZED_TABLEGEN=ON                     \ | ||||
|         -DCMAKE_POLICY_DEFAULT_CMP0111=NEW               \ | ||||
|         -DCOMPILER_RT_BUILD_BUILTINS=ON                  \ | ||||
|         -DCOMPILER_RT_BUILTINS_HIDE_SYMBOLS=OFF          \ | ||||
|         -DCOMPILER_RT_BUILD_LIBFUZZER=OFF                \ | ||||
|         -DCOMPILER_RT_BUILD_CRT=OFF                      \ | ||||
|         -DCOMPILER_RT_BUILD_MEMPROF=OFF                  \ | ||||
|         -DCOMPILER_RT_BUILD_PROFILE=OFF                  \ | ||||
|         -DCOMPILER_RT_BUILD_SANITIZERS=OFF               \ | ||||
|         -DCOMPILER_RT_BUILD_XRAY=OFF                     \ | ||||
|         -DCOMPILER_RT_BUILD_GWP_ASAN=OFF                 \ | ||||
|         -DCOMPILER_RT_BUILD_ORC=OFF                      \ | ||||
|         -DCOMPILER_RT_INCLUDE_TESTS=OFF                  \ | ||||
|         ${CMAKE_ARGS} -GNinja ../llvm                    \ | ||||
|  | ||||
|     && ninja install  . && \ | ||||
|     #  build llvmlite | ||||
|     cd ../../llvmlite && python setup.py bdist_wheel && \ | ||||
|     cd ../numba && \ | ||||
|     if ! grep '#include "dynamic_annotations.h"' numba/_dispatcher.cpp; then \ | ||||
|        sed -i '/#include "internal\/pycore_atomic.h"/i\#include "dynamic_annotations.h"' numba/_dispatcher.cpp; \ | ||||
|     fi && python setup.py bdist_wheel | ||||
|  | ||||
|  | ||||
| # Final build stage | ||||
| FROM python-install AS vllm-cpu | ||||
| ARG PYTHON_VERSION | ||||
| @ -163,23 +228,30 @@ RUN --mount=type=cache,target=/root/.cache/uv \ | ||||
|     --mount=type=bind,from=torch-vision,source=/tmp/vision/dist,target=/tmp/vision-wheels/ \ | ||||
|     --mount=type=bind,from=hf-xet-builder,source=/tmp/hf-xet/dist,target=/tmp/hf-xet-wheels/ \ | ||||
|     --mount=type=bind,from=torch,source=/tmp/pytorch/dist,target=/tmp/torch-wheels/ \ | ||||
|     --mount=type=bind,from=numba-builder,source=/tmp/llvmlite/dist,target=/tmp/llvmlite-wheels/ \ | ||||
|     --mount=type=bind,from=numba-builder,source=/tmp/numba/dist,target=/tmp/numba-wheels/ \ | ||||
|      sed -i '/^torch/d' requirements/build.txt && \ | ||||
|      ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl | head -n 1) && \ | ||||
|      VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl | head -n 1) && \ | ||||
|      HF_XET_WHL_FILE=$(ls /tmp/hf-xet-wheels/*.whl | head -n 1) && \ | ||||
|      TORCH_WHL_FILE=$(ls /tmp/torch-wheels/*.whl | head -n 1) && \ | ||||
|      ARROW_WHL_FILE=$(ls /tmp/arrow-wheels/pyarrow-*.whl) && \ | ||||
|      VISION_WHL_FILE=$(ls /tmp/vision-wheels/*.whl) && \ | ||||
|      HF_XET_WHL_FILE=$(ls /tmp/hf-xet-wheels/*.whl) && \ | ||||
|      TORCH_WHL_FILE=$(ls /tmp/torch-wheels/*.whl) && \ | ||||
|      LLVM_WHL_FILE=$(ls /tmp/llvmlite-wheels/*.whl) && \ | ||||
|      NUMBA_WHL_FILE=$(ls /tmp/numba-wheels/*.whl) && \ | ||||
|     uv pip install -v \     | ||||
|         $ARROW_WHL_FILE  \ | ||||
|         $VISION_WHL_FILE \ | ||||
|         $HF_XET_WHL_FILE \ | ||||
|         $TORCH_WHL_FILE \ | ||||
|         $LLVM_WHL_FILE \ | ||||
|         $NUMBA_WHL_FILE \ | ||||
|         --index-strategy unsafe-best-match \ | ||||
|         -r requirements/build.txt \ | ||||
|         -r requirements/cpu.txt  | ||||
|         -r requirements/cpu.txt | ||||
|  | ||||
|  | ||||
| # Build and install vllm | ||||
| RUN --mount=type=cache,target=/root/.cache/uv \ | ||||
|     VLLM_TARGET_DEVICE=cpu python setup.py bdist_wheel && \ | ||||
|     VLLM_TARGET_DEVICE=cpu VLLM_CPU_MOE_PREPACK=0 python setup.py bdist_wheel && \ | ||||
|     uv pip install "$(echo dist/*.whl)[tensorizer]" | ||||
|  | ||||
| # setup non-root user for vllm | ||||
| @ -196,4 +268,3 @@ WORKDIR /home/vllm | ||||
|  | ||||
| # Set the default entrypoint | ||||
| ENTRYPOINT ["python", "-m", "vllm.entrypoints.openai.api_server"] | ||||
|  | ||||
|  | ||||
| @ -7,7 +7,8 @@ WORKDIR /workspace/vllm | ||||
| # Install some basic utilities | ||||
| RUN apt-get update && apt-get install -y \ | ||||
|     git \ | ||||
|     ffmpeg libsm6 libxext6 libgl1 | ||||
|     ffmpeg libsm6 libxext6 libgl1 && \ | ||||
|     rm -rf /var/lib/apt/lists/* | ||||
|  | ||||
| # Build vLLM. | ||||
| COPY . . | ||||
| @ -16,6 +17,9 @@ RUN --mount=type=bind,source=.git,target=.git \ | ||||
|     if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi | ||||
|  | ||||
| # Remove existing versions of dependencies | ||||
| # TODO: These packages will remain as dead weight in the Docker image layers. | ||||
| # We should find a way to build the image without uninstalling these. | ||||
| # Consider using a different base image. | ||||
| RUN pip uninstall -y torch torch_xla torchvision | ||||
|  | ||||
| ENV VLLM_TARGET_DEVICE="tpu" | ||||
| @ -23,9 +27,10 @@ RUN --mount=type=cache,target=/root/.cache/pip \ | ||||
|     --mount=type=bind,source=.git,target=.git \ | ||||
|     python3 -m pip install \ | ||||
|         -r requirements/tpu.txt | ||||
| RUN python3 -m pip install -e . | ||||
|  | ||||
| RUN --mount=type=cache,target=/root/.cache/pip python3 -m pip install -e . | ||||
|  | ||||
| # install development dependencies (for testing) | ||||
| RUN python3 -m pip install -e tests/vllm_test_utils | ||||
| RUN --mount=type=cache,target=/root/.cache/pip python3 -m pip install -e tests/vllm_test_utils | ||||
|  | ||||
| CMD ["/bin/bash"] | ||||
|  | ||||
| @ -77,6 +77,7 @@ Internal data structures. | ||||
| - [vllm.multimodal.inputs.MultiModalFieldElem][] | ||||
| - [vllm.multimodal.inputs.MultiModalFieldConfig][] | ||||
| - [vllm.multimodal.inputs.MultiModalKwargsItem][] | ||||
| - [vllm.multimodal.inputs.MultiModalKwargsItems][] | ||||
| - [vllm.multimodal.inputs.MultiModalKwargs][] | ||||
| - [vllm.multimodal.inputs.MultiModalInputs][] | ||||
|  | ||||
|  | ||||
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 24 KiB | 
							
								
								
									
										
											BIN
										
									
								
								docs/assets/design/hybrid_kv_cache_manager/full_attn.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								docs/assets/design/hybrid_kv_cache_manager/full_attn.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 4.0 KiB | 
							
								
								
									
										
											BIN
										
									
								
								docs/assets/design/hybrid_kv_cache_manager/memory_layout.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								docs/assets/design/hybrid_kv_cache_manager/memory_layout.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 62 KiB | 
							
								
								
									
										
											BIN
										
									
								
								docs/assets/design/hybrid_kv_cache_manager/overview.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								docs/assets/design/hybrid_kv_cache_manager/overview.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 39 KiB | 
							
								
								
									
										
											BIN
										
									
								
								docs/assets/design/hybrid_kv_cache_manager/sw_attn.png
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								docs/assets/design/hybrid_kv_cache_manager/sw_attn.png
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 4.5 KiB | 
| @ -2,6 +2,7 @@ | ||||
|  | ||||
| We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below: | ||||
|  | ||||
| - [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg), August 23rd 2025. [[Slides]](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH) | ||||
| - [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA), August 2nd 2025. [[Slides]](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) [[Recording]](https://www.chaspark.com/#/live/1166916873711665152). | ||||
| - [NYC vLLM Meetup](https://lu.ma/c1rqyf1f), May 7th, 2025. [[Slides]](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing) | ||||
| - [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day), April 3rd 2025. [[Slides]](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). | ||||
|  | ||||
| @ -86,7 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", | ||||
|  | ||||
| If you run out of CPU RAM, try the following options: | ||||
|  | ||||
| - (Multi-modal models only) you can set the size of multi-modal processor cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB per API process + 4 GiB per engine core process) | ||||
| - (Multi-modal models only) you can set the size of multi-modal cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB). | ||||
| - (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB). | ||||
|  | ||||
| ## Multi-modal input limits | ||||
|  | ||||
| @ -48,7 +48,7 @@ You can tune the performance by adjusting `max_num_batched_tokens`: | ||||
|  | ||||
| - Smaller values (e.g., 2048) achieve better inter-token latency (ITL) because there are fewer prefills slowing down decodes. | ||||
| - Higher values achieve better time to first token (TTFT) as you can process more prefill tokens in a batch. | ||||
| - For optimal throughput, we recommend setting `max_num_batched_tokens > 8096` especially for smaller models on large GPUs. | ||||
| - For optimal throughput, we recommend setting `max_num_batched_tokens > 8192` especially for smaller models on large GPUs. | ||||
| - If `max_num_batched_tokens` is the same as `max_model_len`, that's almost the equivalent to the V0 default scheduling policy (except that it still prioritizes decodes). | ||||
|  | ||||
| ```python | ||||
| @ -129,6 +129,53 @@ Data parallelism replicates the entire model across multiple GPU sets and proces | ||||
| Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`. | ||||
| Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size. | ||||
|  | ||||
| ### Batch-level DP for Multi-Modal Encoders | ||||
|  | ||||
| By default, TP is used to shard the weights of multi-modal encoders just like for language decoders, | ||||
| in order to reduce the memory and compute load on each GPU. | ||||
|  | ||||
| However, since the size of multi-modal encoders is very small compared to language decoders, | ||||
| there is relatively little gain from TP. On the other hand, TP incurs significant communication | ||||
| overhead because of all-reduce being performed after every layer. | ||||
|  | ||||
| Given this, it may be advantageous to instead shard the batched input data using TP, essentially | ||||
| performing batch-level DP. This has been shown to improve the throughput by around 10% for | ||||
| `tensor_parallel_size=8`. For vision encoders that use hardware-unoptimized Conv3D operations, | ||||
| batch-level DP can provide another 40% increase to throughput compared to regular TP. | ||||
|  | ||||
| Nevertheless, since the weights of the multi-modal encoder are replicated across each TP rank, | ||||
| there will be a minor increase in memory consumption and may cause OOM if you can barely fit the model already. | ||||
|  | ||||
| You can enable batch-level DP by setting `mm_encoder_tp_mode="data"`, for example: | ||||
|  | ||||
| ```python | ||||
| from vllm import LLM | ||||
|  | ||||
| llm = LLM( | ||||
|     model="Qwen/Qwen2.5-VL-72B-Instruct", | ||||
|     tensor_parallel_size=4, | ||||
|     # When mm_encoder_tp_mode="data", | ||||
|     # the vision encoder uses TP=4 (not DP=1) to shard the input data, | ||||
|     # so the TP size becomes the effective DP size. | ||||
|     # Note that this is independent of the DP size for language decoder which is used in expert parallel setting. | ||||
|     mm_encoder_tp_mode="data", | ||||
|     # The language decoder uses TP=4 to shard the weights regardless | ||||
|     # of the setting of mm_encoder_tp_mode | ||||
| ) | ||||
| ``` | ||||
|  | ||||
| !! important | ||||
|     Batch-level DP is not to be confused with API request-level DP | ||||
|     (which is instead controlled by `data_parallel_size`). | ||||
|  | ||||
| The availability of batch-level DP is based on model implementation. | ||||
| Currently, the following models support `mm_encoder_tp_mode="data"`: | ||||
|  | ||||
| - Llama4 (<gh-pr:18368>) | ||||
| - MiniCPM-V-4 (<gh-pr:23327>) | ||||
| - Qwen2.5-VL (<gh-pr:22742>) | ||||
| - Step3 (<gh-pr:22697>) | ||||
|  | ||||
| ## Input Processing | ||||
|  | ||||
| ### Parallel Processing | ||||
| @ -149,21 +196,41 @@ vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2 | ||||
| !!! note | ||||
|     API server scale-out is only available for online inference. | ||||
|  | ||||
| !!! warning | ||||
|     By default, 8 CPU threads are used in each API server to load media items (e.g. images) | ||||
|     from request data. | ||||
|  | ||||
|     If you apply API server scale-out, consider adjusting `VLLM_MEDIA_LOADING_THREAD_COUNT` | ||||
|     to avoid CPU resource exhaustion. | ||||
|  | ||||
| !!! note | ||||
|     [Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled | ||||
|     API server scale-out disables [multi-modal IPC caching](#ipc-caching) | ||||
|     because it requires a one-to-one correspondance between API and engine core processes. | ||||
|  | ||||
|     This does not impact [multi-modal processor caching](#processor-caching). | ||||
|  | ||||
| ## Multi-Modal Caching | ||||
|  | ||||
| ### Processor Cache | ||||
|  | ||||
| By default, the multi-modal processor cache is enabled to avoid repeatedly processing | ||||
| the same multi-modal inputs via Hugging Face `AutoProcessor`, | ||||
| Multi-modal caching avoids repeated transfer or processing of the same multi-modal data, | ||||
| which commonly occurs in multi-turn conversations. | ||||
|  | ||||
| You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` | ||||
| (default 4 GiB per API process + 4 GiB per engine core process). | ||||
| If you do not benefit much from the cache, you can disable it completely via `mm_processor_cache_gb=0`. | ||||
| ### Processor Caching | ||||
|  | ||||
| Multi-modal processor caching is automatically enabled | ||||
| to avoid repeatedly processing the same multi-modal inputs in `BaseMultiModalProcessor`. | ||||
|  | ||||
| ### IPC Caching | ||||
|  | ||||
| Multi-modal IPC caching is automatically enabled when | ||||
| there is a one-to-one correspondance between API (`P0`) and engine core (`P1`) processes, | ||||
| to avoid repeatedly transferring the same multi-modal inputs between them. | ||||
|  | ||||
| ### Configuration | ||||
|  | ||||
| You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` (default 4 GiB). | ||||
|  | ||||
| If you do not benefit much from the cache, you can disable both IPC | ||||
| and processor caching completely via `mm_processor_cache_gb=0`. | ||||
|  | ||||
| Examples: | ||||
|  | ||||
| @ -176,3 +243,16 @@ llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", | ||||
| llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", | ||||
|           mm_processor_cache_gb=0) | ||||
| ``` | ||||
|  | ||||
| ### Cache Placement | ||||
|  | ||||
| Based on the configuration, the content of the multi-modal caches on `P0` and `P1` are as follows: | ||||
|  | ||||
| | Processor Caching | IPC Caching | `P0` Cache | `P1` Cache | Max. Memory | | ||||
| |-------------------|-------------|------------|------------|-------------| | ||||
| | ✅ | ✅ | K | K + V | `mm_processor_cache_gb * data_parallel_size` | | ||||
| | ✅ | ❌ | K + V | N/A | `mm_processor_cache_gb * api_server_count` | | ||||
| | ❌ | ❌ | N/A | N/A | `0` | | ||||
|  | ||||
| K: Stores the hashes of multi-modal items   | ||||
| V: Stores the processed tensor data of multi-modal items | ||||
|  | ||||
| @ -70,7 +70,7 @@ For example, max_model_len=512, padding_gap=64, the buckets will be [16, 32, 64, | ||||
|  | ||||
| The fewer tokens we pad, the less unnecessary computation TPU does, the better performance we can get. For example, if num_tokens=300, with exponential padding, we pad to 512, with the bucket_padding above, we pad to 320. | ||||
|  | ||||
| However, you need to be careful to choose the padding gap. If the gap is too small, it means the number of buckets is large, leading to increased warmup (precompile) time and higher memory to store the compiled graph. Too many compilaed graphs may lead to HBM OOM. Conversely, an overly large gap yields no performance improvement compared to the default exponential padding. | ||||
| However, you need to be careful to choose the padding gap. If the gap is too small, it means the number of buckets is large, leading to increased warmup (precompile) time and higher memory to store the compiled graph. Too many compiled graphs may lead to HBM OOM. Conversely, an overly large gap yields no performance improvement compared to the default exponential padding. | ||||
|  | ||||
| #### Quantization | ||||
|  | ||||
|  | ||||
| @ -629,7 +629,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies | ||||
|             self, | ||||
|             mm_items: MultiModalDataItems, | ||||
|             hf_processor_mm_kwargs: Mapping[str, object], | ||||
|             out_mm_kwargs: MultiModalKwargs, | ||||
|             out_mm_kwargs: MultiModalKwargsItems, | ||||
|         ) -> Sequence[PromptUpdate]: | ||||
|             hf_config = self.info.get_hf_config() | ||||
|             image_token_id = hf_config.image_token_index | ||||
| @ -778,7 +778,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies | ||||
|             self, | ||||
|             mm_items: MultiModalDataItems, | ||||
|             hf_processor_mm_kwargs: Mapping[str, object], | ||||
|             out_mm_kwargs: MultiModalKwargs, | ||||
|             out_mm_kwargs: MultiModalKwargsItems, | ||||
|         ) -> Sequence[PromptUpdate]: | ||||
|             hf_config = self.info.get_hf_config() | ||||
|             bos_token_id = hf_config.bos_token_id | ||||
|  | ||||
| @ -18,7 +18,7 @@ vllm serve Qwen/Qwen1.5-32B-Chat-AWQ --max-model-len 4096 | ||||
|  | ||||
| - Download and install [Anything LLM desktop](https://anythingllm.com/desktop). | ||||
|  | ||||
| - On the bottom left of open settings, AI Prooviders --> LLM: | ||||
| - On the bottom left of open settings, AI Providers --> LLM: | ||||
|     - LLM Provider: Generic OpenAI | ||||
|     - Base URL: http://{vllm server host}:{vllm server port}/v1 | ||||
|     - Chat Model Name: `Qwen/Qwen1.5-32B-Chat-AWQ` | ||||
|  | ||||
| @ -9,7 +9,7 @@ vLLM can be run on a cloud based GPU machine with [dstack](https://dstack.ai/), | ||||
| To install dstack client, run: | ||||
|  | ||||
| ```bash | ||||
| pip install "dstack[all] | ||||
| pip install dstack[all] | ||||
| dstack server | ||||
| ``` | ||||
|  | ||||
|  | ||||
| @ -133,7 +133,7 @@ class FusedMoEModularKernel: | ||||
| Typically a FusedMoEPrepareAndFinalize type is backed by an All2All Dispatch & Combine implementation / kernel. For example, | ||||
|  | ||||
| * PplxPrepareAndFinalize type is backed by Pplx All2All kernels, | ||||
| * DeepEPHTPrepareAndFinalize type is backed by DeepEP High-Throughtput All2All kernels, and | ||||
| * DeepEPHTPrepareAndFinalize type is backed by DeepEP High-Throughput All2All kernels, and | ||||
| * DeepEPLLPrepareAndFinalize type is backed by DeepEP Low-Latency All2All kernels. | ||||
|  | ||||
| #### Step 1: Add an All2All manager | ||||
| @ -183,7 +183,7 @@ implementations that input `FusedMoEActivationFormat.Standard` support chunking | ||||
|  | ||||
| #### maybe_make_prepare_finalize | ||||
|  | ||||
| The `maybe_make_prepare_finalize` method is responsbile for constructing an instance of `FusedMoEPrepareAndFinalize` when appropriate based on the current all2all backend, e.g. when EP + DP is enabled.  The base class method currently constructs all the `FusedMoEPrepareAndFinalize` objects for the EP+DP case.  Derived classes can override this method to construct prepare/finalize objects for different scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case. | ||||
| The `maybe_make_prepare_finalize` method is responsible for constructing an instance of `FusedMoEPrepareAndFinalize` when appropriate based on the current all2all backend, e.g. when EP + DP is enabled.  The base class method currently constructs all the `FusedMoEPrepareAndFinalize` objects for the EP+DP case.  Derived classes can override this method to construct prepare/finalize objects for different scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case. | ||||
| Please refer to the implementations in, | ||||
|  | ||||
| * `ModelOptNvFp4FusedMoE` | ||||
| @ -198,7 +198,7 @@ Please refer to the implementations in, | ||||
| * `CompressedTensorsW8A8Fp8MoECutlassMethod` | ||||
| * `Fp8MoEMethod` | ||||
| * `ModelOptNvFp4FusedMoE` | ||||
| dervied classes. | ||||
| derived classes. | ||||
|  | ||||
| #### init_prepare_finalize | ||||
|  | ||||
| @ -226,7 +226,7 @@ Doing this will add the new implementation to the test suite. | ||||
|  | ||||
| The unit test file [test_modular_kernel_combinations.py](gh-file:tests/kernels/moe/test_modular_kernel_combinations.py) can also be executed as a standalone script. | ||||
| Example: `python3 -m tests.kernels.moe.test_modular_kernel_combinations --pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts` | ||||
| As a side-effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. When invoked | ||||
| As a side effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. When invoked | ||||
| with incompatible types, the script will error. | ||||
|  | ||||
| ### How To Profile | ||||
|  | ||||
							
								
								
									
										245
									
								
								docs/design/hybrid_kv_cache_manager.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										245
									
								
								docs/design/hybrid_kv_cache_manager.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,245 @@ | ||||
| # Hybrid KV Cache Manager | ||||
|  | ||||
| !!! warning | ||||
|     This document was written based on commit [458e74](https://github.com/vllm-project/vllm/commit/458e74eb907f96069e6d8a4f3c9f457001fef2ea). This feature is still in its early stage and things may change. | ||||
|  | ||||
| ## What is a hybrid model? | ||||
|  | ||||
| Many recent "hybrid" LLMs combine multiple attention types within one model. For example: | ||||
|  | ||||
| 1. Sliding window attention (sw) + full attention (full): gpt-oss, Gemma 2/3, Ministral, cohere, etc. | ||||
| 2. Mamba + full: Bamba, Jamba, Minimax, etc. | ||||
| 3. Local chunked attention + full: Llama4 | ||||
|  | ||||
| To serve these models efficiently, our [KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager] must: | ||||
|  | ||||
| 1. Allocate different slots to different layer type, for example: | ||||
|     - Full attention layers: reserve slots for **all** tokens. | ||||
|     - Sliding window layers: reserve slots only for the most recent **`sliding_window_size`** tokens. | ||||
| 2. Support layer-specific prefix-cache rules, for example: | ||||
|     - Full attention: a cache hit prefix requires **all** tokens remain in the KV cache. | ||||
|     - Sliding window: a cache hit prefix only requires the last **`sliding_window_size`** tokens remain in the KV cache. | ||||
|  | ||||
| ## Definitions | ||||
|  | ||||
| 1. **kv hidden size**: The number of bytes to store one token's KV cache for a single layer. | ||||
| 2. **block**: the memory reserved for kv cache are divided into multiple *blocks* with the same *page size* (defined below) | ||||
| 3. **block size**: number of tokens inside a block | ||||
| 4. **page size**: the physical memory size of a block, defined as: | ||||
|  | ||||
|     $$ | ||||
|     \text{num_layers} \times \text{block_size} \times \text{kv_hidden_size} | ||||
|     $$ | ||||
|  | ||||
|     `num_layers` doesn't mean the total number of layers in the model. The exact number depends on the context in this doc. | ||||
|  | ||||
|     !!! note | ||||
|         This is different from `KVCacheSpec.page_size_bytes` in the code, which is defined as: | ||||
|  | ||||
|         $$ | ||||
|         \text{block_size} \times \text{kv_hidden_size} | ||||
|         $$ | ||||
|  | ||||
| ## Allocation | ||||
|  | ||||
| ### High level idea | ||||
|  | ||||
| We use a single memory pool for all layer types. The memory pool is split into multiple blocks with the same page size. [KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager] allocates different numbers of blocks to different layers according to its attention type. | ||||
|  | ||||
| The core challenge is ensuring every layer type uses the same **page size**.  For full-attention-only models, the page size is straightforward, defined as: | ||||
|  | ||||
| $$ | ||||
| \text{page_size} = \text{block_size} \times \text{num_hidden_layers} \times \text{kv_hidden_size} | ||||
| $$ | ||||
|  | ||||
| However, in hybrid models, `num_hidden_layers` varies by attention type, which would normally produce mismatched page sizes. The cases below show how we unify them. | ||||
|  | ||||
| ### Case 1: toy model | ||||
|  | ||||
| Let's start with a toy example: a model has 1 full attention layer and 3 sliding window attention layers. All layers have the same `kv_hidden_size`. | ||||
|  | ||||
| We let each block to hold `block_size` tokens for one layer, so: | ||||
|  | ||||
| $$ | ||||
| \text{page_size} = \text{kv_hidden_size} \times \text{block_size} | ||||
| $$ | ||||
|  | ||||
| [KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager] allocates a different number of blocks to each layer. | ||||
|  | ||||
| This case is only a toy example. For real models, please refer to the following cases. | ||||
|  | ||||
| ### Case 2: same `kv_hidden_size` and a regular pattern | ||||
|  | ||||
| When the model has more layers, e.g., 20 sliding window attention layers and 10 full attention layers with the same `kv_hidden_size`. Calling the allocator once per layer (30 calls) is OK but becomes inefficient. As a solution, we group the allocation of layers that need the same number of blocks to reduce the number of calls. | ||||
|  | ||||
| The grouping is feasible because there is usually a beautiful ratio between the number of different types of layers. For example: | ||||
|  | ||||
| - Gemma-2: 1 sw : 1 full | ||||
| - Llama 4: 3 local : 1 full | ||||
|  | ||||
| Our example can be regarded as 2 sw : 1 full. We can allocate blocks as if there are 2 sw and 1 full in the model, and repeat the result by 10 times to generate the `block_ids` for the 30 layers. The page size becomes: | ||||
|  | ||||
| $$ | ||||
| 10 \times \text{kv_hidden_size} \times \text{block_size} | ||||
| $$ | ||||
|  | ||||
| Assume `block_size` 16, sliding window size 32, request length 112, then for the above example model, we need to allocate 11 blocks (0-6 for full, 7-8 for sw group 1, 9-10 for sw group 2). | ||||
|  | ||||
|  | ||||
|  | ||||
| Here, "/" denotes no block needed (sliding‑window layers don't need slots for early tokens). | ||||
|  | ||||
| See the formal definition below. The layers are divided into multiple *KV Cache Groups* so that there is: | ||||
|  | ||||
| 1. **Identical attention type inside each group**: Each group only contains layers with the same attention type and thus need the same number of blocks for a given request. This enables layers in the same group share the same block ids without memory waste. | ||||
| 2. **Identical page size across groups**: Because our memory pool only have one page size. | ||||
|  | ||||
| Our example model is divided into 3 KV cache groups: | ||||
|  | ||||
| - Group 0: 10 full attention layers (full.0 - full.9) | ||||
| - Group 1: 10 sliding window attention layers (sw.0 - sw.9) | ||||
| - Group 2: 10 sliding window attention layers (sw.10 - sw.19) | ||||
|  | ||||
| Obviously, it satisfies rule 1. For rule 2, all 3 groups have | ||||
|  | ||||
| $$ | ||||
| 10 \times \text{kv_hidden_size} \times \text{block_size} | ||||
| $$ | ||||
|  | ||||
| as their page size. | ||||
|  | ||||
| ### Case 3: same `kv_hidden_size` and no regular pattern | ||||
|  | ||||
| Unfortunately, not all models have such a beautiful ratio, and approach in Case 2 will produce too many small groups. For example, Gemma-3-27b has 52 sliding window attention layers and 10 full attention layers. With the constraints in case 2, it would be 26 sliding window groups and 5 full attention groups, each contains 2 layers. The allocation is still inefficient. To reduce the number of kv cache groups, we group layers using the smallest layer count among all attention types. For example, min(52, 10)=10 layers per group in Gemma-3-27b. Then the grouping result is: | ||||
|  | ||||
| - Group 0: 10 full attention layers (full.0 - full.9) | ||||
| - Group 1: 10 sliding window attention layers (sw.0 - sw.9) | ||||
| - Group 2: 10 sliding window attention layers (sw.10 - sw.19) | ||||
| - ... | ||||
| - Group 6: 10 sliding window attention layers (sw.40 - sw.49) | ||||
| - Group 7: 2 sliding window attention layers (sw.50 - sw.51) and 8 padding layers | ||||
|  | ||||
| We will update this algorithm if this heuristic leads to a bad result when a new model comes out (e.g., 20 full + 30 sw, the group size should be 10 instead of 20). | ||||
|  | ||||
| This case happens in Gemma-3 series models, and models in case 2 but with eagle speculative decoding which introduce one full attention layer. The solution has some memory waste and is not perfect. Please report any cases where padding overhead becomes unacceptable so we can refine the algorithm. | ||||
|  | ||||
| ### Case 4: different `kv_hidden_size` (mainly hybrid mamba models) | ||||
|  | ||||
| Some architectures (e.g., Bamba, Jamba, Minimax) interleave standard attention layers with Mamba layers, where each Mamba layer's state size per token can be much larger than the attention layers' `kv_hidden_size`. Because we only support a single page size across all groups, we must reconcile these differing hidden sizes. | ||||
|  | ||||
| The current algorithm is: | ||||
|  | ||||
| 1. Increase the `block_size` of attention layers until | ||||
|     $$ | ||||
|     \text{block_size} \times \text{kv_hidden_size}_{\text{att}} \ge \text{state_size}_{\text{mamba}} | ||||
|     $$ | ||||
| 2. Pad the mamba state per layer to | ||||
|     $$ | ||||
|     \text{block_size} \times \text{kv_hidden_size}_{\text{att}} | ||||
|     $$ | ||||
| 3. Apply the grouping strategy in case 3. | ||||
|  | ||||
| !!! note | ||||
|     This can lead to more than 400 `block_size` for attention layers, which is too large. Another padding strategy is to increase `block_size` until | ||||
|  | ||||
|     $$ | ||||
|     \text{block_size} \times \text{kv_hidden_size}_{\text{att}} \times \text{num_attn_layers} \ge \text{state_size}_{\text{mamba}} | ||||
|     $$ | ||||
|  | ||||
|     This padding strategy is still a work in progress. | ||||
|  | ||||
| ### Case 5: KV sharing | ||||
|  | ||||
| KV sharing refers to a layer using the KV cache of another layer, e.g., gemma-3n. | ||||
| In these models, [KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager] ignores all layers with kv sharing and only allocates KV cache for layers that need kv cache, and some patches are made in model runner to apply the allocation result to kv sharing layers. | ||||
|  | ||||
| ## Prefix caching | ||||
|  | ||||
| For simplicity, we assume `block_size=1` in this section. | ||||
|  | ||||
| ### High level idea | ||||
|  | ||||
| The block pool uses a dict similar to `tuple(block_hash, group_id) -> block` to catch the full blocks. That means the same tokens of different groups are cached and evicted independently. | ||||
|  | ||||
| When a new request comes in, we check the cache hit prefix of each group, and return the intersection of these groups as the cached prefix of the request. See below for the detailed algorithm for checking the cache hit of one group & performing the intersection. | ||||
|  | ||||
| ### Case 0: full attention only models | ||||
|  | ||||
| For full attention layers, blocks are allocated for all tokens in the request. For details on the underlying design, see [Prefix Caching](prefix_caching.md) | ||||
|  | ||||
| To find the longest cache hit prefix of a request, we enumerate from left (the first block) to right (the last block), checking whether the block is cached, and exit when cache misses. For example, we will return the first 7 tokens (0-6) as the cache hit prefix in the below example (blue blocks are cached): | ||||
|  | ||||
|  | ||||
|  | ||||
| ### Case 1: sliding window attention only models | ||||
|  | ||||
| For sliding window attention layers, a naive implementation for memory allocation is to allocate `sliding_window_size` blocks and fill in the blocks in a round-robin way. But this naive implementation is not compatible with prefix caching so we didn't pick this design. In vLLM,  we allocate different blocks for different tokens and free blocks that are outside the sliding window. | ||||
|  | ||||
| For a new request, the cache hit prefix only requires the last `sliding_window_size - 1` tokens being cached. | ||||
| Let's say `sliding_window_size = 4` and `block_size = 1`, and the request is a 15-token prompt (blue blocks are cached): | ||||
|  | ||||
|  | ||||
|  | ||||
| There are 3 possible cache hit prefixes: | ||||
|  | ||||
| - cache hit length 5, compute prefill with [2, 3, 4] → [5, 6, …, 14] | ||||
| - cache hit length 6, compute prefill with [3, 4, 5] → [6, 7, …, 14] | ||||
| - cache hit length 14, compute prefill with [11, 12, 13] → [14] (most efficient) | ||||
|  | ||||
| We can check the cache hit from right to left, and early exit when we find a match.This is opposite from full attention, where we check from left to right and early exit when the match fails. One potential cons (compared to full attention) is that we end up iterating over the entire list of tokens when there's no match, which is often a common case. This could potentially cause non-negligible overheads, but fine with full + swa, as discussed below. | ||||
|  | ||||
| ### Case 2: sliding window attention + full attention models | ||||
|  | ||||
| The first problem is how to find the cache hit prefix. We need to "intersect" the cache hits of global and sliding window attention layers by: | ||||
|  | ||||
| 1. Get the longest cache hit for full attention (scanning from left to right) | ||||
| 2. Get the longest cache hit for sliding window attention that is within that length. Implemented by checking cache hits from right to left starting from the cache hit length of full attention. | ||||
|  | ||||
| It can be ensured that the resulting cache hit of sliding window attention layers is also a cache hit of full attention layers. This is more efficient than finding all possible prefixes of each group and doing the intersection, because our approach can exit early if there is no cache hit. | ||||
|  | ||||
| The algorithm applies to models with exactly two attention types full attention + X, where X can be an arbitrary efficient attention algorithm like sliding window, llama 4 local attention, and mamba. It doesn't support models without full attention layers, and models with more than 2 types of attention. This is enough for most hybrid models at the moment of writing this doc. | ||||
|  | ||||
| The second question is the cache eviction policy. For now, we use one LRU queue for all kv cache groups. The blocks are added to the LRU queue when freed, either because the request is finished or the block is out of the sliding window. | ||||
|  | ||||
| ### Case 3: mamba models | ||||
|  | ||||
| The prefix caching support of the mamba model is work in progress. Once implemented, models with mamba layer + full attention layer can be supported via the full attention + X algorithm in case 2. | ||||
|  | ||||
| ## Implementation | ||||
|  | ||||
| ### Overview | ||||
|  | ||||
|  | ||||
|  | ||||
| The `KVCacheManager` is organized into 3 layers: | ||||
|  | ||||
| - **[KVCacheManager][vllm.v1.core.kv_cache_manager.KVCacheManager]**: The interface between the scheduler and kv cache management system. | ||||
| - **[KVCacheCoordinator][vllm.v1.core.kv_cache_coordinator.KVCacheCoordinator]**: coordinate per-group SingleTypeKVCacheManagers to generate the allocation result of a request. Depending on the model's configuration, one of these coordinators is chosen: | ||||
|     - **[KVCacheCoordinatorNoPrefixCache][vllm.v1.core.kv_cache_coordinator.KVCacheCoordinatorNoPrefixCache]**: Used when prefix caching is disabled. | ||||
|     - **[UnitaryKVCacheCoordinator][vllm.v1.core.kv_cache_coordinator.UnitaryKVCacheCoordinator]**: If only one KV cache group. The prefix caching logic is simplified as no intersection is needed. | ||||
|     - **[HybridKVCacheCoordinator][vllm.v1.core.kv_cache_coordinator.HybridKVCacheCoordinator]**: Handles exactly two KV cache groups (must include one full‑attention group plus one other efficient‑attention group). Other cases are not implemented. You can disable prefix caching to use the KVCacheCoordinatorNoPrefixCache. | ||||
| - **[SingleTypeKVCacheManager][vllm.v1.core.single_type_kv_cache_manager.SingleTypeKVCacheManager]**: Each instance manages allocation and prefix caching for one KV cache group, implementing the attention‑type–specific logic (e.g., full attention, sliding window, Mamba). | ||||
|  | ||||
| The blue box in the above figure shows the case with 10 full attention layers and 20 sliding window attention layers, thus: | ||||
|  | ||||
| - use `HybridKVCacheCoordinator` | ||||
| - use 1 `FullAttentionManager` and 2 `SlidingWindowManager` for the 3 `KVCacheGroup`s. | ||||
|  | ||||
| ### Memory Layout | ||||
|  | ||||
| For a model with n `KVCacheGroup`s, each with m layers, we allocate m buffers. Each buffer is shared by n layers, one from each group. | ||||
|  | ||||
| The following figure is for a model with 10 full attention layers (full.0 - full.9) and 20 sliding window attention layers (sw.0-sw.19). It follows "case 2" in "Allocation" section and is divided into 3 groups: | ||||
|  | ||||
| - Group 0: 10 full attention layers (full.0 - full.9) | ||||
| - Group 1: 10 sliding window attention layers (sw.0 - sw.9) | ||||
| - Group 2: 10 sliding window attention layers (sw.10 - sw.19) | ||||
|  | ||||
| And for a request, we allocate 11 blocks with `block_id` 0-6 to group 0, 7-8 to group 1, and 9-10 to group 2. | ||||
|  | ||||
| With such an example, the physical memory is divided into 10 buffers (`KVCacheTensor` 0 - `KVCacheTensor` 9). Each buffer is shared by 3 layers (e.g., `KVCacheTensor` 0 is shared by full.0 from group 0, sw.0 from group 1, and sw.10 from group 2) and is divided into pieces with size `block_size * kv_hidden_size`. The KV cache of these 3 attention layers are saved to different pieces of the buffer based on the allocated `block_ids`: | ||||
|  | ||||
|  | ||||
|  | ||||
| !!! note | ||||
|     One logic "block" is mapped to 10 pieces in the 10 buffers of the physical memory. | ||||
| @ -565,7 +565,7 @@ model and then validate those tokens with the larger model. | ||||
| - `vllm:spec_decode_num_emitted_tokens_total` (Counter) | ||||
|  | ||||
| There is a PR under review (<gh-pr:12193>) to add "prompt lookup (ngram)" | ||||
| seculative decoding to v1. Other techniques will follow. We should | ||||
| speculative decoding to v1. Other techniques will follow. We should | ||||
| revisit the v0 metrics in this context. | ||||
|  | ||||
| !!! note | ||||
|  | ||||
| @ -77,7 +77,7 @@ The `multiproc_xpu_executor` forces the use of `spawn`. | ||||
|  | ||||
| There are other miscellaneous places hard-coding the use of `spawn`: | ||||
|  | ||||
| - <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L135> | ||||
| - <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/all_reduce_utils.py#L135> | ||||
| - <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/entrypoints/openai/api_server.py#L184> | ||||
|  | ||||
| Related PRs: | ||||
|  | ||||
| @ -422,7 +422,7 @@ a total of 128 * 16 / 256 = 8 inner iterations for a warp to handle | ||||
| a whole block of value tokens. And each `accs` in each thread | ||||
| contains 8 elements that accumulated at 8 different head positions. | ||||
| For the thread 0, the `accs` variable will have 8 elements, which | ||||
| are 0th, 32th … 224th elements of a value head that are accumulated | ||||
| are 0th, 32nd … 224th elements of a value head that are accumulated | ||||
| from all assigned 8 tokens. | ||||
|  | ||||
| ## LV | ||||
|  | ||||
| @ -2,6 +2,6 @@ | ||||
|  | ||||
| vLLM's examples are split into three categories: | ||||
|  | ||||
| - If you are using vLLM from within Python code, see [Offline Inference](./offline_inference/) | ||||
| - If you are using vLLM from an HTTP application or client, see [Online Serving](./online_serving/) | ||||
| - For examples of using some of vLLM's advanced features (e.g. LMCache or Tensorizer) which are not specific to either of the above use cases, see [Others](./others/) | ||||
| - If you are using vLLM from within Python code, see [Offline Inference](./offline_inference) | ||||
| - If you are using vLLM from an HTTP application or client, see [Online Serving](./online_serving) | ||||
| - For examples of using some of vLLM's advanced features (e.g. LMCache or Tensorizer) which are not specific to either of the above use cases, see [Others](./others) | ||||
|  | ||||
| @ -4,7 +4,6 @@ Quantization trades off model precision for smaller memory footprint, allowing l | ||||
|  | ||||
| Contents: | ||||
|  | ||||
| - [Supported Hardware](supported_hardware.md) | ||||
| - [AutoAWQ](auto_awq.md) | ||||
| - [AutoRound](auto_round.md) | ||||
| - [BitsAndBytes](bnb.md) | ||||
| @ -19,3 +18,50 @@ Contents: | ||||
| - [AMD Quark](quark.md) | ||||
| - [Quantized KV Cache](quantized_kvcache.md) | ||||
| - [TorchAO](torchao.md) | ||||
|  | ||||
| ## Supported Hardware | ||||
|  | ||||
| The table below shows the compatibility of various quantization implementations with different hardware platforms in vLLM: | ||||
|  | ||||
| <style> | ||||
| td:not(:first-child) { | ||||
|   text-align: center !important; | ||||
| } | ||||
| td { | ||||
|   padding: 0.5rem !important; | ||||
|   white-space: nowrap; | ||||
| } | ||||
|  | ||||
| th { | ||||
|   padding: 0.5rem !important; | ||||
|   min-width: 0 !important; | ||||
| } | ||||
|  | ||||
| th:not(:first-child) { | ||||
|   writing-mode: vertical-lr; | ||||
|   transform: rotate(180deg) | ||||
| } | ||||
| </style> | ||||
|  | ||||
| | Implementation        | Volta   | Turing   | Ampere   | Ada   | Hopper   | AMD GPU   | Intel GPU   | Intel Gaudi | x86 CPU   | AWS Neuron   | Google TPU   | | ||||
| |-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|--------------|--------------| | ||||
| | AWQ                   | ❌      | ✅︎       | ✅︎       | ✅︎    | ✅︎       | ❌         | ✅︎          | ❌         | ✅︎        | ❌          | ❌           | | ||||
| | GPTQ                  | ✅︎      | ✅︎       | ✅︎       | ✅︎    | ✅︎       | ❌         | ✅︎          | ❌         | ✅︎        | ❌          | ❌           | | ||||
| | Marlin (GPTQ/AWQ/FP8) | ❌      | ❌       | ✅︎       | ✅︎    | ✅︎       | ❌         | ❌          | ❌         | ❌        | ❌          | ❌           | | ||||
| | INT8 (W8A8)           | ❌      | ✅︎       | ✅︎       | ✅︎    | ✅︎       | ❌         | ❌          | ❌         | ✅︎        | ✅︎          | ✅︎           | | ||||
| | FP8 (W8A8)            | ❌      | ❌       | ❌       | ✅︎    | ✅︎       | ✅︎         | ❌          | ❌         | ❌        | ✅︎          | ❌           | | ||||
| | BitBLAS               | ✅︎      | ✅       | ✅︎       | ✅︎    | ✅︎       | ❌         | ❌          | ❌         | ❌        | ❌          | ❌           | | ||||
| | BitBLAS (GPTQ)        | ❌      | ❌       | ✅︎       | ✅︎    | ✅︎       | ❌         | ❌          | ❌         | ❌        | ❌          | ❌           | | ||||
| | bitsandbytes          | ✅︎      | ✅︎       | ✅︎       | ✅︎    | ✅︎       | ❌         | ❌          | ❌         | ❌        | ❌          | ❌           | | ||||
| | DeepSpeedFP           | ✅︎      | ✅︎       | ✅︎       | ✅︎    | ✅︎       | ❌         | ❌          | ❌         | ❌        | ❌          | ❌           | | ||||
| | GGUF                  | ✅︎      | ✅︎       | ✅︎       | ✅︎    | ✅︎       | ✅︎         | ❌          | ❌         | ❌        | ❌          | ❌           | | ||||
| | INC (W8A8)            | ❌      | ❌       | ❌       | ❌    | ❌       | ❌         | ❌          | ✅︎         | ❌        | ❌          | ❌           | | ||||
|  | ||||
| - Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0. | ||||
| - ✅︎ indicates that the quantization method is supported on the specified hardware. | ||||
| - ❌ indicates that the quantization method is not supported on the specified hardware. | ||||
|  | ||||
| !!! note | ||||
|     This compatibility chart is subject to change as vLLM continues to evolve and expand its support for different hardware platforms and quantization methods. | ||||
|  | ||||
|     For the most up-to-date information on hardware support and quantization methods, please refer to <gh-dir:vllm/model_executor/layers/quantization> or consult with the vLLM development team. | ||||
|  | ||||
| @ -5,7 +5,7 @@ vLLM now supports [BitBLAS](https://github.com/microsoft/BitBLAS) for more effic | ||||
| !!! note | ||||
|     Ensure your hardware supports the selected `dtype` (`torch.bfloat16` or `torch.float16`). | ||||
|     Most recent NVIDIA GPUs support `float16`, while `bfloat16` is more common on newer architectures like Ampere or Hopper. | ||||
|     For details see [supported hardware](supported_hardware.md). | ||||
|     For details see [supported hardware](README.md#supported-hardware). | ||||
|  | ||||
| Below are the steps to utilize BitBLAS with vLLM. | ||||
|  | ||||
|  | ||||
| @ -79,7 +79,7 @@ Since simple RTN does not require data for weight quantization and the activatio | ||||
| Install `vllm` and `lm-evaluation-harness` for evaluation: | ||||
|  | ||||
| ```bash | ||||
| pip install vllm lm-eval==0.4.4 | ||||
| pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] | ||||
| ``` | ||||
|  | ||||
| Load and run the model in `vllm`: | ||||
|  | ||||
| @ -7,7 +7,7 @@ Intel Gaudi supports quantization of various modules and functions, including, b | ||||
| [Supported Modules\\Supported Functions\\Custom Patched Modules](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-modules). | ||||
|  | ||||
| !!! note | ||||
|     Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vllm-hpu-extention](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package. | ||||
|     Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vLLM HPU extension](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package. | ||||
|  | ||||
| !!! note | ||||
|     `QUANT_CONFIG` is an environment variable that points to the measurement or quantization [JSON config file](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-json-config-file-options). | ||||
|  | ||||
| @ -18,7 +18,7 @@ pip install llmcompressor | ||||
| Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: | ||||
|  | ||||
| ```bash | ||||
| pip install vllm lm-eval==0.4.4 | ||||
| pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] | ||||
| ``` | ||||
|  | ||||
| ## Quantization Process | ||||
|  | ||||
| @ -19,7 +19,7 @@ pip install llmcompressor | ||||
| Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: | ||||
|  | ||||
| ```bash | ||||
| pip install vllm lm-eval==0.4.4 | ||||
| pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] | ||||
| ``` | ||||
|  | ||||
| ## Quantization Process | ||||
|  | ||||
| @ -20,7 +20,7 @@ for more installation details. | ||||
| Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: | ||||
|  | ||||
| ```bash | ||||
| pip install vllm lm-eval==0.4.4 | ||||
| pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] | ||||
| ``` | ||||
|  | ||||
| ## Quantization Process | ||||
|  | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user
	