Compare commits

..

450 Commits

Author SHA1 Message Date
d5bf492f16 Merge branch 'main' into optimize-prefix-caching-scheduling 2024-06-04 00:20:15 +00:00
f775a07e30 [FRONTEND] OpenAI tools support named functions (#5032) 2024-06-03 18:25:29 -05:00
4f0d17c05c New CI template on AWS stack (#5110)
Signed-off-by: kevin <kevin@anyscale.com>
2024-06-03 16:16:43 -07:00
10c38e3e46 [Misc]: Implement CPU/GPU swapping in BlockManagerV2 (#3834) 2024-06-03 13:37:11 -07:00
cafb8e06c5 [CI/BUILD] enable intel queue for longer CPU tests (#4113) 2024-06-03 10:39:50 -07:00
cbb2f59cc8 [Kernel] Pass a device pointer into the quantize kernel for the scales (#5159) 2024-06-03 09:52:30 -07:00
0ab278ca31 [Core] Remove unnecessary copies in flash attn backend (#5138) 2024-06-03 09:39:31 -07:00
7a64d24aad [Core] Support image processor (#4197) 2024-06-02 22:56:41 -07:00
8c7bab79f5 simplify code 2024-06-03 03:36:38 +00:00
dfbe60dc62 [Misc] Simplify code and fix type annotations in conftest.py (#5118) 2024-06-02 16:05:50 -07:00
a66cf40b20 [Kernel][ROCm][AMD] enable fused topk_softmax kernel for moe layer (#4927)
This PR enables the fused topk_softmax kernel used in moe layer for HIP
2024-06-02 14:13:26 -07:00
f790ad3c50 [Frontend][OpenAI] Support for returning max_model_len on /v1/models response (#4643) 2024-06-02 08:06:13 +00:00
ed59a7ed23 Update test_ignore_eos (#4898) 2024-06-02 02:21:53 +00:00
1936d7bab0 format 2024-06-02 00:02:54 +00:00
996cf2de5c Fix hashing logic for non-full blocks 2024-06-02 00:01:30 +00:00
044793d8df [BugFix] Prevent LLM.encode for non-generation Models (#5184)
Co-authored-by: mgoin <michael@neuralmagic.com>
2024-06-01 23:35:41 +00:00
c2d6d2f960 [Bugfix]: Fix issues related to prefix caching example (#5177) (#5180) 2024-06-01 15:53:52 -07:00
8279078e21 [Bugfix] Remove deprecated @abstractproperty (#5174) 2024-06-01 22:40:25 +00:00
b9c0605a8e [Feature][Kernel] Support bitsandbytes quantization and QLoRA (#4776) 2024-06-01 14:51:10 -06:00
37464a0f74 [Bugfix] Fix call to init_logger in openai server (#4765) 2024-06-01 17:18:50 +00:00
c354072828 [Minor] Fix the path typo in loader.py: save_sharded_states.py -> save_sharded_state.py (#5151)
Signed-off-by: Ye Cao <caoye.cao@alibaba-inc.com>
2024-06-01 17:11:22 +00:00
f081c3ce4b [Kernel] Update Cutlass fp8 configs (#5144)
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
2024-06-01 08:46:07 +00:00
260d119e86 [Kernel] Refactor CUTLASS kernels to always take scales that reside on the GPU (#5137) 2024-06-01 06:45:32 +00:00
a360ff80bb [CI/Build] CMakeLists: build all extensions' cmake targets at the same time (#5034) 2024-05-31 22:06:45 -06:00
1197e02141 [Build] Guard against older CUDA versions when building CUTLASS 3.x kernels (#5168) 2024-05-31 17:21:38 -07:00
657579113f [Doc] Add checkmark for GPTBigCodeForCausalLM LoRA support (#5171) 2024-05-31 17:20:19 -07:00
e9899fb7a4 [Model] Enable FP8 QKV in MoE and refine kernel tuning script (#5039) 2024-05-31 14:29:19 -07:00
a377f0bd5e [Misc]: optimize eager mode host time (#4196)
Co-authored-by: xuhao <xuhao@cambricon.com>
2024-05-31 13:14:50 +08:00
e9d3aa04f6 Revert "[Kernel] Marlin_24: Ensure the mma.sp instruction is using the ::ordered_metadata modifier (introduced with PTX 8.5)" (#5149) 2024-05-30 22:00:26 -07:00
a22dea54d3 [Model] Support MAP-NEO model (#5081)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
2024-05-30 19:24:41 -07:00
533c217792 Fix cutlass sm_90a vesrion in CMakeList 2024-05-31 02:13:01 +00:00
6d21fa1cad [Kernel] Marlin_24: Ensure the mma.sp instruction is using the ::ordered_metadata modifier (introduced with PTX 8.5) (#5136) 2024-05-30 21:02:11 -05:00
b35be5403f [Bugfix] Avoid Warnings in SparseML Activation Quantization (#5120) 2024-05-30 17:04:37 -07:00
45a1a69b98 [Build] Disable sm_90a in cu11 (#5141) 2024-05-30 14:37:16 -07:00
87a658c812 Bump version to v0.4.3 (#5046) 2024-05-30 11:13:46 -07:00
429d89720e add doc about serving option on dstack (#3074)
Co-authored-by: Roger Wang <ywang@roblox.com>
2024-05-30 10:11:07 -07:00
a9bcc7afb2 [Doc] Use intersphinx and update entrypoints docs (#5125) 2024-05-30 09:59:23 -07:00
d79d9eaaff [Misc] remove duplicate definition of seq_lens_tensor in model_runner.py (#5129) 2024-05-30 06:56:19 -07:00
f758505c73 [CI/Build] increase wheel size limit to 200 MB (#5130) 2024-05-30 06:29:48 -07:00
d910816c73 [Bugfix] Automatically Detect SparseML models (#5119) 2024-05-30 12:58:37 +00:00
87d41c849d [BUGFIX] [FRONTEND] Correct chat logprobs (#5029)
Co-authored-by: Breno Faria <breno.faria@intrafind.com>
2024-05-30 02:52:14 -07:00
e07aff9e52 [CI/Build] Docker cleanup functionality for amd servers (#5112)
Co-authored-by: Alexey Kondratiev <alexey.kondratiev@amd.com>
Co-authored-by: Alexei-V-Ivanov-AMD <156011006+Alexei-V-Ivanov-AMD@users.noreply.github.com>
Co-authored-by: Alexei V. Ivanov <alexei.ivanov@amd.com>
Co-authored-by: omkarkakarparthi <okakarpa>
2024-05-30 03:27:39 +00:00
5bf185a1c4 [Bugfix] gptq_marlin: Ensure g_idx_sort_indices is not a Parameter (#5108) 2024-05-30 00:30:18 +00:00
4fbcb0f27e [Doc][Build] update after removing vllm-nccl (#5103)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
2024-05-29 23:51:18 +00:00
7c3604fb68 [Bugfix] logprobs is not compatible with the OpenAI spec #4795 (#5031) 2024-05-29 16:13:22 -07:00
b1c255630d [Core] Avoid the need to pass None values to Sequence.inputs (#5099) 2024-05-29 16:05:01 -07:00
eb6c50cdc2 [Bugfix][CI/Build] Fix codespell failing to skip files in git diff (#5097) 2024-05-29 16:02:54 -07:00
eecd864388 [Bugfix][CI/Build] Fix test and improve code for merge_async_iterators (#5096) 2024-05-29 16:02:25 -07:00
ae495c74ea [Doc]Replace deprecated flag in readme (#4526) 2024-05-29 22:26:33 +00:00
4238bc82f2 [Core] Cross-attention KV caching and memory-management (towards eventual encoder/decoder model support) (#4837) 2024-05-29 16:09:13 +00:00
594392d27a [Core][Distributed] improve p2p access check (#4992) 2024-05-29 11:29:07 +00:00
18c1f16d86 [Bugfix] Fix arguments passed to Sequence in stop checker test (#5092) 2024-05-29 07:16:41 +00:00
5bd3c65072 [Core][Optimization] remove vllm-nccl (#5091) 2024-05-29 05:13:52 +00:00
616e600e0b [Misc] add gpu_memory_utilization arg (#5079)
Signed-off-by: pandyamarut <pandyamarut@gmail.com>
2024-05-28 17:16:18 -07:00
dfba529b40 [Bugfix] Remove the last EOS token unless explicitly specified (#5077) 2024-05-28 17:15:35 -07:00
5ae5ed1e60 [Core] Consolidate prompt arguments to LLM engines (#4328)
Co-authored-by: Roger Wang <ywang@roblox.com>
2024-05-28 13:29:31 -07:00
290f4ada2b [Docs] Add Dropbox as sponsors (#5089) 2024-05-28 10:29:09 -07:00
dd8de11f0a [Kernel][ROCm][AMD] Add fused_moe Triton configs for MI300X (#4951)
This PR adds Triton kernel configs for the MoE kernel for MI300X
2024-05-28 16:03:23 +00:00
9ba415588a [BugFix] Fix Embedding Models with TP>1 (#5075) 2024-05-28 08:32:42 -07:00
d4f3985907 [Core] Sliding window for block manager v2 (#4545)
Co-authored-by: Ruth Evans <ruthevans@Ruths-MacBook-Pro.local>
2024-05-28 11:07:07 +09:00
890aa93d27 [Model] Add support for falcon-11B (#5069) 2024-05-27 16:41:43 -07:00
fbdb7b3ee2 [Core] Allow AQLM on Pascal (#5058) 2024-05-27 15:26:14 -07:00
1102bef219 [Bugfix / Core] Prefix Caching Guards (merged with main) (#4846)
Co-authored-by: rsnm2 <rshaw@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
2024-05-27 15:18:17 -07:00
f17a1a8f96 [Misc] Make Serving Benchmark More User-friendly (#5044) 2024-05-25 17:28:16 +00:00
d5a1697772 [Dynamic Spec Decoding] Minor fix for disabling speculative decoding (#5000) 2024-05-25 10:00:14 -07:00
325c119961 [Misc] add logging level env var (#5045) 2024-05-24 23:49:49 -07:00
8e192ff967 [Kernel][Backend][Model] Blocksparse flash attention kernel and Phi-3-Small model (#4799)
Co-authored-by: beagleski <yunanzhang@microsoft.com>
Co-authored-by: bapatra <bapatra@microsoft.com>
Co-authored-by: Barun Patra <codedecde@users.noreply.github.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
2024-05-24 22:00:52 -07:00
e64fde4b01 [Core][Bugfix]: fix prefix caching for blockv2 (#4764)
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
2024-05-24 10:07:09 -07:00
919770957f [Bugfix] Fix Mistral v0.3 Weight Loading (#5005)
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
2024-05-24 12:28:27 +00:00
6a50f4cafa [Doc] add ccache guide in doc (#5012)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
2024-05-23 23:21:54 +00:00
e3470f8753 [Core]: Option To Use Prompt Token Ids Inside Logits Processor (#4985)
Co-authored-by: Elisei Smirnov <el.smirnov@innopolis.university>
2024-05-23 22:04:24 +00:00
a1242324c9 [Kernel] Initial Activation Quantization Support (#4525)
Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
2024-05-23 21:29:18 +00:00
5eda2ea02a [Core][1/N] Support send/recv in PyNCCL Groups (#4988)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
2024-05-23 09:54:48 -07:00
2ba80bed27 [Bugfix] Update Dockerfile.cpu to fix NameError: name 'vllm_ops' is not defined (#5009) 2024-05-23 09:08:58 -07:00
6066253296 Marlin 24 prefill performance improvement (about 25% better on average) (#4983) 2024-05-23 02:39:27 -04:00
ee3eea0a1b [Misc] Take user preference in attention selector (#4960) 2024-05-23 07:55:56 +09:00
a36de682d4 [Minor] Fix small typo in llama.py: QKVParallelLinear -> QuantizationConfig (#4991) 2024-05-22 22:26:56 +00:00
eb6d3c264d [Core] Eliminate parallel worker per-step task scheduling overhead (#4894) 2024-05-23 06:17:27 +09:00
97b030005c [Model] LoRA gptbigcode implementation (#3949) 2024-05-22 13:58:59 -07:00
a3a73ab069 [Misc] Load FP8 kv-cache scaling factors from checkpoints (#4893)
The 2nd PR for #4532.

This PR supports loading FP8 kv-cache scaling factors from a FP8 checkpoint (with .kv_scale parameter).
2024-05-22 13:28:20 -07:00
8674f9880e [Kernel] Fixup for CUTLASS kernels in CUDA graphs (#4954)
Pass the CUDA stream into the CUTLASS GEMMs, to avoid future issues with CUDA graphs
2024-05-22 14:10:43 +00:00
c74c913bfb [misc] remove comments that were supposed to be removed (#4977) 2024-05-22 09:02:58 -04:00
5f6d10c14c [CI/Build] Enforce style for C++ and CUDA code with clang-format (#4722) 2024-05-22 07:18:41 +00:00
9b9a10d6cb [Frontend] Dynamic RoPE scaling (#4638) 2024-05-22 01:32:35 -04:00
99eff67ba9 [Bugfix][Kernel] Add head size check for attention backend selection (#4944) 2024-05-21 15:33:25 -04:00
14772eeb8e [Bugfix] Fix flag name for max_seq_len_to_capture (#4935)
Signed-off-by: kerthcet <kerthcet@gmail.com>
2024-05-21 09:30:52 -07:00
757b62c495 [CI/Build] Codespell ignore build/ directory (#4945) 2024-05-21 09:06:10 -07:00
e941f88584 [Docs] Add acknowledgment for sponsors (#4925) 2024-05-21 00:17:25 -07:00
f12c3b5b3d [Model] Add Phi-2 LoRA support (#4886) 2024-05-21 14:24:17 +09:00
d130b573a0 [Model] add rope_scaling support for qwen2 (#4930) 2024-05-21 05:22:22 +00:00
65ae8c2c8f [Core] Fix scheduler considering "no LoRA" as "LoRA" (#4897) 2024-05-20 17:48:32 -07:00
c3af44722c [Doc]Add documentation to benchmarking script when running TGI (#4920) 2024-05-20 20:16:57 +00:00
1937e29848 [Core] Sharded State Loader download from HF (#4889) 2024-05-20 11:46:12 -07:00
f0eecee610 [Bugfix] Fix dummy weight for fp8 (#4916)
Allow dummy load format for fp8,
torch.uniform_ doesn't support FP8 at the moment

Co-authored-by: Mor Zusman <morz@ai21.com>
2024-05-20 18:44:25 +00:00
943e72ca56 [Build/CI] Enabling AMD Entrypoints Test (#4834)
Co-authored-by: Alexey Kondratiev <alexey.kondratiev@amd.com>
2024-05-20 11:29:28 -07:00
546a97ef69 [Misc]: allow user to specify port in distributed setting (#4914) 2024-05-20 17:45:06 +00:00
da5a0b539d Remove marlin warning (#4918) 2024-05-20 14:55:34 +00:00
6287537a0c [Model] LLaVA model refactor (#4910) 2024-05-20 08:11:25 +00:00
b57e6c5949 [Kernel] Add flash-attn back (#4907) 2024-05-19 18:11:30 -07:00
27ce85476e [Kernel] Add marlin_24 unit tests (#4901) 2024-05-19 11:37:34 -04:00
f68470e803 [Bugfix][Model] Add base class for vision-language models (#4809) 2024-05-19 00:13:33 -07:00
2e9a2227ec [Lora] Support long context lora (#4787)
Currently we need to call rotary embedding kernel for each LoRA, which makes it hard to serve multiple long context length LoRA. Add batched rotary embedding kernel and pipe it through.

It replaces the rotary embedding layer to the one that is aware of multiple cos-sin-cache per scaling factors.

Follow up of https://github.com/vllm-project/vllm/pull/3095/files
2024-05-18 16:05:23 +09:00
c0724fc915 [ROCm][Hardware][AMD] Adding Navi21 to fallback to naive attention if Triton is not used (#4658) 2024-05-18 05:09:11 +00:00
86b45ae065 [Bugfix] Relax tiktoken to >= 0.6.0 (#4890) 2024-05-17 12:58:52 -06:00
c5711ef985 [Doc] Update Ray Data distributed offline inference example (#4871) 2024-05-17 10:52:11 -07:00
48d5985a08 Sync huggingface modifications of qwen Moe model (#4774) 2024-05-17 09:43:19 -07:00
33e0823de5 [Bugfix] fix rope error when load models with different dtypes (#4835) 2024-05-17 18:43:34 +09:00
26148120b3 [Build/CI] Extending the set of AMD tests with Regression, Basic Correctness, Distributed, Engine, Llava Tests (#4797) 2024-05-16 20:58:25 -07:00
0150a10630 [Frontend] OpenAI API server: Do not add bos token by default when encoding (#4688) 2024-05-16 18:47:22 -07:00
8e7fb5d43a Support to serve vLLM on Kubernetes with LWS (#4829)
Signed-off-by: kerthcet <kerthcet@gmail.com>
2024-05-16 16:37:29 -07:00
9a31a817a8 [Bugfix] Fix FP8 KV cache support (#4869) 2024-05-16 22:42:29 +00:00
2060e93659 [Kernel] Add w8a8 CUTLASS kernels (#4749) 2024-05-16 18:32:50 -04:00
8435b207af [Kernel] Add punica dimension for Qwen1.5-32B LoRA (#4850)
Co-authored-by: Silencio <silencio@adsl-99-6-187-6.dsl.irvnca.sbcglobal.net>
2024-05-16 11:16:09 -07:00
10fa9eea21 [Misc] remove old comments (#4866) 2024-05-16 11:07:41 -07:00
e08188081b [Core][Distributed] remove graph mode function (#4818) 2024-05-16 10:59:52 -07:00
b5853f9963 [ROCm][AMD][Bugfix] adding a missing triton autotune config (#4845) 2024-05-16 10:46:52 -07:00
f09edd8a25 Add JSON output support for benchmark_latency and benchmark_throughput (#4848) 2024-05-16 10:02:56 -07:00
6979ade384 Add GPTQ Marlin 2:4 sparse structured support (#4790)
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
2024-05-16 12:56:15 -04:00
9216b9cc38 [Bugfix] Bypass authorization API token for preflight requests (#4862) 2024-05-16 09:42:21 -07:00
5e0391c040 [Frontend] Separate OpenAI Batch Runner usage from API Server (#4851) 2024-05-17 00:42:41 +09:00
dbc0754ddf [docs] Fix typo in examples filename openi -> openai (#4864) 2024-05-17 00:42:17 +09:00
99caa49106 [Kernel] add bfloat16 support for gptq marlin kernel (#4788) 2024-05-16 09:55:29 -04:00
5c342570d7 Add marlin unit tests and marlin benchmark script (#4815) 2024-05-16 09:36:49 -04:00
973617ae02 [Speculative decoding][Re-take] Enable TP>1 speculative decoding (#4840)
Co-authored-by: Cade Daniel <edacih@gmail.com>
Co-authored-by: Cade Daniel <cade@anyscale.com>
2024-05-16 00:53:51 -07:00
30e754390c [Core] Implement sharded state loader (#4690)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-05-15 22:11:54 -07:00
52f8107cf2 [Frontend] Support OpenAI batch file format (#4794)
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
2024-05-15 19:13:36 -04:00
fc0d9dfc3a [Frontend] Re-enable custom roles in Chat Completions API (#4758) 2024-05-15 14:58:46 -07:00
361c461a12 [Doc] Highlight the fourth meetup in the README (#4842) 2024-05-15 11:38:49 -07:00
a5675d348b [Bugfix] Properly set distributed_executor_backend in ParallelConfig (#4816) 2024-05-15 07:22:09 -07:00
e9cdd2b1e2 [CI/Build] Further decouple HuggingFace implementation from ours during tests (#4166) 2024-05-14 23:38:40 -07:00
65bf2ac165 [Core][2/N] Model runner refactoring part 2. Combine prepare prefill / decode to a single API (#4681)
This PR combines prepare_prompt and prepare_decode into a single API. This PR also coelsce the attn metadata for prefill/decode to a single class and allow to slice them when running attn backend.

It also refactors subquery_start_loc which was not refactored in the previous PR
2024-05-15 14:00:10 +09:00
8a7cc254a0 Revert "[Kernel] Use flash-attn for decoding (#3648)" (#4820)
Lora 3 & 4 test seems to have illegal memory access failure after this commit;

[2024-05-14 23:51:18,182 E 22 22] logging.cc:101: Unhandled exception: N3c105ErrorE. what(): CUDA error: an illegal memory access was encountered
<br class="Apple-interchange-newline">
Exmaple: https://buildkite.com/vllm/ci/builds/7382#018f793d-1527-4e1c-ab59-c3a34ec55241

This reverts commit 1356df5.

FILL IN THE PR DESCRIPTION HERE

FIX #xxxx (link existing issues this PR will resolve)
2024-05-15 11:52:45 +09:00
29bc01bf3b Add 4th meetup announcement to readme (#4817) 2024-05-14 18:33:06 -04:00
676a99982f [Core] Add MultiprocessingGPUExecutor (#4539)
Co-authored-by: SAHIL SUNEJA <suneja@us.ibm.com>
2024-05-14 10:38:59 -07:00
dc72402b57 [Bugfix][Doc] Fix CI failure in docs (#4804)
This PR fixes the CI failure introduced by #4798.

The failure originates from having duplicate target names in reST, and is fixed by changing the ref targets to anonymous ones. For more information, see this discussion.

I have also changed the format of the links to be more distinct from each other.
2024-05-15 01:57:08 +09:00
ccb63a8245 [Core][Hash][Automatic Prefix caching] Accelerating the hashing function by avoiding deep copies (#4696) 2024-05-14 21:34:33 +09:00
c579b750a0 [Doc] Add meetups to the doc (#4798) 2024-05-13 18:48:00 -07:00
4bfa7e7f75 [Doc] Add API reference for offline inference (#4710) 2024-05-13 17:47:42 -07:00
ac1fbf7fd2 [Doc] Shorten README by removing supported model list (#4796) 2024-05-13 16:23:54 -07:00
33d3914b1e [Bugfix] Fix dynamic FP8 quantization for Mixtral (#4793) 2024-05-13 19:00:27 -04:00
1356df53bd [Kernel] Use flash-attn for decoding (#3648)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
2024-05-13 15:50:33 -07:00
ce532ff45c [Speculative decoding] Improve n-gram efficiency (#4724) 2024-05-13 15:00:13 -07:00
8bc68e198c [Frontend] [Core] perf: Automatically detect vLLM-tensorized model, update tensorizer to version 2.9.0 (#4208) 2024-05-13 14:57:07 -07:00
0fca3cdcf2 [Misc] Enhance attention selector (#4751) 2024-05-13 10:47:25 -07:00
e7c46b9527 [Scheduler] Warning upon preemption and Swapping (#4647)
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
2024-05-13 23:50:44 +09:00
350f9e107f [CI/Build] Move test_utils.py to tests/utils.py (#4425)
Since #4335 was merged, I've noticed that the definition of ServerRunner in the tests is the same as in the test for OpenAI API. I have moved the class to the test utilities to avoid code duplication. (Although it only has been repeated twice so far, I will add another similar test suite in #4200 which would duplicate the code a third time)

Also, I have moved the test utilities file (test_utils.py) to under the test directory (tests/utils.py), since none of its code is actually used in the main package. Note that I have added __init__.py to each test subpackage and updated the ray.init() call in the test utilities file in order to relative import tests/utils.py.
2024-05-13 23:50:09 +09:00
702bee461f [Core][Distributed] refactor custom allreduce to support multiple tp groups (#4754) 2024-05-12 17:47:59 -07:00
a7be4d0072 [CORE] Improvement in ranks code (#4718) 2024-05-12 17:47:47 -07:00
a709e87a4f [CI/Build] Tweak Marlin Nondeterminism Issues (#4713) 2024-05-12 17:46:31 -07:00
6eaccb7353 [Model] Add support for IBM Granite Code models (#4636) 2024-05-11 21:27:24 -07:00
e254497b66 [Model][Misc] Add e5-mistral-7b-instruct and Embedding API (#3734) 2024-05-11 11:30:37 -07:00
4e12131089 [Core][Test] fix function name typo in custom allreduce (#4750) 2024-05-10 15:14:40 -07:00
fcc2994be6 [CI] Nits for bad initialization of SeqGroup in testing (#4748) 2024-05-10 18:01:01 -04:00
2e7796f2cf [Speculative decoding] CUDA graph support (#4295)
Co-authored-by: Cade Daniel <edacih@gmail.com>
2024-05-10 17:36:25 +00:00
706588a77d [Bugfix] Fix CLI arguments in OpenAI server docs (#4729) 2024-05-11 00:00:56 +09:00
6a0f617210 [Core] Fix circular reference which leaked llm instance in local dev env (#4737)
Storing exception frame is extremely prone to circular refernece because it contains the reference to objects.

When tensorizer is not installed, it leaks llm instance because error frame has references to various modules which cause circular reference problem.

I also found spec decoding has a circular reference issue, and I solved it using weakref.proxy.
2024-05-10 23:54:32 +09:00
dac6a3f6ed [Misc] Apply a couple g++ cleanups (#4719) 2024-05-10 13:37:05 +00:00
64b77dfd7e [Core]fix type annotation for swap_blocks (#4726) 2024-05-10 21:52:48 +09:00
51d4094fda chunked-prefill-doc-syntax (#4603)
Fix the docs: https://docs.vllm.ai/en/latest/models/performance.html

Co-authored-by: sang <rkooo567@gmail.com>
2024-05-10 14:13:23 +09:00
e965d46184 [Misc] Keep only one implementation of the create_dummy_prompt function. (#4716) 2024-05-09 21:42:38 -07:00
208b71bcc1 [Core][Distributed] refactor pynccl (#4591)
[Core][Distributed] refactor pynccl to hold multiple communicators (#4591)
2024-05-09 19:48:43 -07:00
c833101740 [Kernel] Refactor FP8 kv-cache with NVIDIA float8_e4m3 support (#4535) 2024-05-09 18:04:17 -06:00
379da6dcb5 [Kernel] [FP8] Improve FP8 linear layer performance (#4691)
This PR improves the FP8 performance of linear layers, which had been lacking before (#4118 (comment) and #4118 (comment)).

We noticed that CUBLASLt can find a better algorithm if the first dimension of the matrix is greater than 16. So this PR enlarges matrices appropriately during quantization. This improves FP8 performance and removes the performance regression vs. FP16, in many cases exceeding FP16 performance.

Here are benchmarks on llama3 70b (ITL numbers for 1000 input and 50 output tokens at fixed qps and at TP 4), all FP8 measurements are for dynamic quantization:

qps = 1: 24 ms (FP8, this PR), 32 ms (FP8, previous main), 26 ms (FP16)
qps = 2: 26 ms (FP8, this PR), 34ms (FP8, previous main), 28 ms (FP16) 
qps = 4: 33 ms (FP8, this PR), 44 ms (FP8, previous main), 36 ms (FP16)
qps = 6: 46 ms (FP8, this PR), 56 ms (FP8, previous main), 54 ms (FP16)
qps = 8: 85 ms (FP8, this PR), 85 ms (FP8, previous main), 138 ms (FP16)
2024-05-09 16:38:07 -07:00
ebce310b74 [Model] Snowflake arctic model implementation (#4652)
Co-authored-by: Dash Desai <1723932+iamontheinet@users.noreply.github.com>
Co-authored-by: Aurick Qiao <qiao@aurick.net>
Co-authored-by: Aurick Qiao <aurick.qiao@snowflake.com>
Co-authored-by: Aurick Qiao <aurickq@users.noreply.github.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
2024-05-09 22:37:14 +00:00
be0c5180ac [Bugfix] Add logs for all model dtype casting (#4717) 2024-05-09 18:36:25 +00:00
cea64430f6 [Bugfix] Update grafana.json (#4711) 2024-05-09 10:10:13 -07:00
a3c124570a [Bugfix] Fix CLI arguments in OpenAI server docs (#4709) 2024-05-09 09:53:14 -07:00
ff5abcd746 [ROCm] Add support for Punica kernels on AMD GPUs (#3140)
Co-authored-by: miloice <jeffaw99@hotmail.com>
2024-05-09 09:19:50 -07:00
0ee535b294 [Misc] Set block size at initialization & Fix test_model_runner (#4705) 2024-05-09 09:04:59 -07:00
190bc838e1 [Misc] Remove unnecessary ModelRunner imports (#4703) 2024-05-09 00:17:17 -07:00
f12b20decc [Frontend] Move async logic outside of constructor (#4674) 2024-05-08 22:48:33 -07:00
16bc0a098f [Frontend] add tok/s speed metric to llm class when using tqdm (#4400)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
2024-05-08 22:02:31 -07:00
e288df0632 [Bugfix] Fine-tune gptq_marlin configs to be more similar to marlin (#4626) 2024-05-08 17:14:31 -07:00
8b9241be3a [Speculative decoding] [Bugfix] Fix overallocation in ngram + spec logprobs (#4672) 2024-05-08 23:24:46 +00:00
f942efb5a3 [Dynamic Spec Decoding] Auto-disable by the running queue size (#4592)
Co-authored-by: Cade Daniel <edacih@gmail.com>
2024-05-08 21:44:00 +00:00
89579a201f [Misc] Use vllm-flash-attn instead of flash-attn (#4686) 2024-05-08 13:15:34 -07:00
230c4b38c1 [CI/Test] fix swap test for multi gpu (#4689) 2024-05-08 13:14:02 -07:00
20cfcdec99 [Core][Optimization] change python dict to pytorch tensor for blocks to swap (#4659) 2024-05-08 12:07:05 -07:00
ad932a221d [Core] Faster startup for LoRA enabled models (#4634) 2024-05-08 10:33:18 -07:00
5510cf0e8a [Misc] Add get_name method to attention backends (#4685) 2024-05-08 09:59:31 -07:00
0f9a6e3d22 [Bugfix][Kernel] allow non-power-of-2 for prefix prefill with alibi (#4573) 2024-05-08 09:19:58 -07:00
f6a593093a [CI] Make mistral tests pass (#4596) 2024-05-08 08:44:35 -07:00
d7740ea4dc [Core] Optimize sampler get_logprobs (#4594) 2024-05-08 08:42:28 -07:00
cc466a3290 [Core][Distributed] support cpu&device in broadcast tensor dict (#4660)
[Core][Distributed] support both cpu and device tensor in broadcast tensor dict (#4660)
2024-05-07 19:34:47 -07:00
8344f7742b [Bug fix][Core] fixup ngram not setup correctly (#4551)
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
Co-authored-by: Cade Daniel <edacih@gmail.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
2024-05-07 11:40:18 -07:00
469f85c782 [Core][Optimization] change copy-on-write from dict[int, list] to list (#4648) 2024-05-07 11:06:32 -07:00
10760da800 [Bugfix] Fixed error in slice_lora_b for MergedQKVParallelLinearWithLora (#4609) 2024-05-07 10:59:07 -07:00
478aed5827 [Build/CI] Fixing 'docker run' to re-enable AMD CI tests. (#4642) 2024-05-07 09:23:17 -07:00
63575bc2e1 [Core][Optimization] change python dict to pytorch tensor (#4607) 2024-05-06 21:30:27 -07:00
a98187cf72 [Kernel] Make static FP8 scaling more robust (#4570)
Previously FP8 static scaling works if the scales are overestimating the maxima of all activation tensors during computation. However this will not always be the case even if the scales were calibrated very carefully. For example, with the activations in my checkpoint

https://huggingface.co/pcmoritz/Mixtral-8x7B-v0.1-fp8-act-scale

(which was calibrated on https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k), I'm getting the following mostly random performance on MMLU:

|      Groups      |Version|Filter|n-shot|Metric|Value |   |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu              |N/A    |none  |     0|acc   |0.2295|±  |0.0035|
| - humanities     |N/A    |none  |     5|acc   |0.2421|±  |0.0062|
| - other          |N/A    |none  |     5|acc   |0.2398|±  |0.0076|
| - social_sciences|N/A    |none  |     5|acc   |0.2171|±  |0.0074|
| - stem           |N/A    |none  |     5|acc   |0.2125|±  |0.0073|
With the fix in this PR where the scaled activations are clamped between [-std::numeric_limits<c10::Float8_e4m3fn>::max(), std::numeric_limits<c10::Float8_e4m3fn>::max()] to make sure there are no NaNs, the performance is

|      Groups      |Version|Filter|n-shot|Metric|Value |   |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu              |N/A    |none  |     0|acc   |0.7008|±  |0.0036|
| - humanities     |N/A    |none  |     5|acc   |0.6453|±  |0.0065|
| - other          |N/A    |none  |     5|acc   |0.7692|±  |0.0072|
| - social_sciences|N/A    |none  |     5|acc   |0.8083|±  |0.0070|
| - stem           |N/A    |none  |     5|acc   |0.6115|±  |0.0083|
This is not perfect yet but is getting very close to the FP16 / dynamic activation scale performance.
2024-05-06 17:39:28 -07:00
bd99d22629 Update lm-format-enforcer to 0.10.1 (#4631) 2024-05-06 23:51:59 +00:00
19cb4716ee [CI] Add retry for agent lost (#4633) 2024-05-06 23:18:57 +00:00
e186d37cb1 [CI] use ccache actions properly in release workflow (#4629) 2024-05-06 22:23:36 +00:00
323f27b904 [Bugfix] Fix asyncio.Task not being subscriptable (#4623) 2024-05-06 09:31:05 -07:00
0650e5935b Disable cuda version check in vllm-openai image (#4530) 2024-05-05 16:58:55 -07:00
c7f2cf2b7f [CI] Reduce wheel size by not shipping debug symbols (#4602) 2024-05-04 21:28:58 -07:00
8d8357c8ed bump version to v0.4.2 (#4600) 2024-05-04 17:09:49 -07:00
4302987069 [Bugfix] Fix inappropriate content of model_name tag in Prometheus metrics (#3937) 2024-05-04 15:39:34 -07:00
021b1a2ab7 [CI] check size of the wheels (#4319) 2024-05-04 20:44:36 +00:00
2a052011ca [Kernel] Support MoE Fp8 Checkpoints for Mixtral (Static Weights with Dynamic/Static Activations) (#4527)
Follow on to #4332 to enable FP8 checkpoint loading for Mixtral and supersedes #4436.

This PR enables the following checkpoint loading features for Mixtral:

Supports loading fp8 checkpoints for Mixtral, such as this "nm-testing/Mixtral-8x7B-Instruct-v0.1-FP8" test model
Supports static or dynamic activation quantization with static weight quantization (all per tensor)
Supports different scales for each expert weight
Supports Fp8 in QKV layer
Notes:

The Expert Gate/Router always runs at half / full precision for now.
If there are different weight scales between QKV layer (for separate QKV weights), they are re-quantized using layer.weight_scale.max() so we can have a single gemm for performance.
2024-05-04 11:45:16 -07:00
36fb68f947 [Doc] Chunked Prefill Documentation (#4580) 2024-05-04 00:18:00 -07:00
bc8ad68455 [Misc][Refactor] Introduce ExecuteModelData (#4540) 2024-05-03 17:47:07 -07:00
344bf7cd2d [Misc] add installation time env vars (#4574) 2024-05-03 15:55:56 -07:00
ab50275111 [Speculative decoding] Support target-model logprobs (#4378) 2024-05-03 15:52:01 -07:00
43c413ec57 [Kernel] Use flashinfer for decoding (#4353)
Co-authored-by: LiuXiaoxuanPKU <llilyliupku@gmail.com>
2024-05-03 15:51:27 -07:00
f8e7adda21 Fix/async chat serving (#2727) 2024-05-03 11:04:14 -07:00
7e65477e5e [Bugfix] Allow "None" or "" to be passed to CLI for string args that default to None (#4586) 2024-05-03 10:32:21 -07:00
3521ba4f25 [Core][Model runner refactoring 1/N] Refactor attn metadata term (#4518) 2024-05-03 10:20:12 -07:00
2d7bce9cd5 [Doc] add env vars to the doc (#4572) 2024-05-03 05:13:49 +00:00
ce3f1eedf8 [Misc] remove chunk detected debug logs (#4571) 2024-05-03 04:48:08 +00:00
808632d3b4 [BugFix] Prevent the task of _force_log from being garbage collected (#4567) 2024-05-03 01:35:18 +00:00
344a5d0c33 [Core][Distributed] enable allreduce for multiple tp groups (#4566) 2024-05-02 17:32:33 -07:00
0f8a91401c [Core] Ignore infeasible swap requests. (#4557) 2024-05-02 14:31:20 -07:00
9b5c9f9484 [CI/Build] AMD CI pipeline with extended set of tests. (#4267)
Co-authored-by: simon-mo <simon.mo@hey.com>
2024-05-02 12:29:07 -07:00
32881f3f31 [kernel] fix sliding window in prefix prefill Triton kernel (#4405)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
2024-05-02 11:23:37 -07:00
5b8a7c1cb0 [Misc] centralize all usage of environment variables (#4548) 2024-05-02 11:13:25 -07:00
1ff0c73a79 [BugFix] Include target-device specific requirements.txt in sdist (#4559) 2024-05-02 10:52:51 -07:00
5ad60b0cbd [Misc] Exclude the tests directory from being packaged (#4552) 2024-05-02 10:50:25 -07:00
fb087af52e [mypy][7/N] Cover all directories (#4555) 2024-05-02 10:47:41 -07:00
7038e8b803 [Kernel] Support running GPTQ 8-bit models in Marlin (#4533) 2024-05-02 12:56:22 -04:00
2a85f93007 [Core][Distributed] enable multiple tp group (#4512)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
2024-05-02 04:28:21 +00:00
cf8cac8c70 [mypy][6/N] Fix all the core subdirectory typing (#4450)
Co-authored-by: Cade Daniel <edacih@gmail.com>
2024-05-02 03:01:00 +00:00
5e401bce17 [CI]Add regression tests to ensure the async engine generates metrics (#4524) 2024-05-01 19:57:12 -07:00
0d62fe58db [Bug fix][Core] assert num_new_tokens == 1 fails when SamplingParams.n is not 1 and max_tokens is large & Add tests for preemption (#4451) 2024-05-01 19:24:13 -07:00
b8afa8b95a [MISC] Rework logger to enable pythonic custom logging configuration to be provided (#4273) 2024-05-01 17:34:40 -07:00
826b82a260 [Misc] Fix expert_ids shape in MoE (#4517) 2024-05-01 23:47:59 +00:00
c9d852d601 [Misc] Remove Mixtral device="cuda" declarations (#4543)
Remove the device="cuda" declarations in mixtral as promised in #4343
2024-05-01 16:30:52 -07:00
6ef09b08f8 [Core][Distributed] fix pynccl del error (#4508) 2024-05-01 15:23:06 -07:00
Roy
3a922c1e7e [Bugfix][Core] Fix and refactor logging stats (#4336) 2024-05-01 20:08:14 +00:00
c47ba4aaa9 [Bugfix] Add validation for seed (#4529) 2024-05-01 19:31:22 +00:00
24bb4fe432 [Kernel] Update fused_moe tuning script for FP8 (#4457)
This PR updates the tuning script for the fused_moe kernel to support FP8 and also adds configurations for TP4. Note that for the configuration I removed num_warps and num_stages for small batch sizes since that improved performance and brought the benchmarks on par with the numbers before in that regime to make sure this is a strict improvement over the status quo.

All the numbers below are for mistralai/Mixtral-8x7B-Instruct-v0.1, 1000 input and 50 output tokens.

Before this PR (with static activation scaling):

qps = 1: 9.8 ms ITL, 0.49s e2e latency
qps = 2: 9.7 ms ITL, 0.49s e2e latency 
qps = 4: 10.1 ms ITL, 0.52s e2e latency
qps = 6: 11.9 ms ITL, 0.59s e2e latency
qps = 8: 14.0 ms ITL, 0.70s e2e latency
qps = 10: 15.7 ms ITL, 0.79s e2e latency

After this PR (with static activation scaling):

qps = 1: 9.8 ms ITL, 0.49s e2e latency
qps = 2: 9.7 ms ITL, 0.49s e2e latency
qps = 4: 10.2 ms ITL, 0.53s e2e latency
qps = 6: 11.9 ms ITL, 0.59s e2e latency
qps = 8: 11.9 ms ITL, 0.59s e2e latency
qps = 10: 12.1 ms ITL, 0.61s e2e latency
2024-05-01 11:47:38 -07:00
a657bfc48a [Core] Add multiproc_worker_utils for multiprocessing-based workers (#4357) 2024-05-01 18:41:59 +00:00
24750f4cad [Core] Enable prefix caching with block manager v2 enabled (#4142)
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
Co-authored-by: Sage Moore <sagemoore@utexas.edu>
2024-05-01 11:20:32 -07:00
b38e42fbca [Speculative decoding] Add ngram prompt lookup decoding (#4237)
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
2024-05-01 11:13:03 -07:00
8b798eec75 [CI/Build][Bugfix] VLLM_USE_PRECOMPILED should skip compilation (#4534)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
2024-05-01 18:01:50 +00:00
69909126a7 [Bugfix] Use random seed if seed is -1 (#4531) 2024-05-01 10:41:17 -07:00
e491c7e053 [Doc] update(example model): for OpenAI compatible serving (#4503) 2024-05-01 10:14:16 -07:00
4dc8026d86 [Bugfix] Fix 307 Redirect for /metrics (#4523) 2024-05-01 09:14:13 -07:00
a88bb9b032 [Bugfix] Fix the fp8 kv_cache check error that occurs when failing to obtain the CUDA version. (#4173)
Signed-off-by: AnyISalIn <anyisalin@gmail.com>
2024-05-01 09:11:03 -07:00
6f1df80436 [Test] Add ignore_eos test (#4519) 2024-05-01 08:45:42 -04:00
d6f4bd7cdd [Misc]Add customized information for models (#4132) 2024-04-30 21:18:14 -07:00
c3845d82dc Allow user to define whitespace pattern for outlines (#4305) 2024-04-30 20:48:39 -07:00
a822eb3413 [Misc] fix typo in block manager (#4453) 2024-04-30 20:41:32 -07:00
f458112e8a [Misc][Typo] type annotation fix (#4495) 2024-04-30 20:21:39 -07:00
2e240c69a9 [Core] Centralize GPU Worker construction (#4419) 2024-05-01 01:06:34 +00:00
ee37328da0 Unable to find Punica extension issue during source code installation (#4494)
Co-authored-by: Simon Mo <simon.mo@hey.com>
2024-05-01 00:42:09 +00:00
6ad58f42c5 fix_tokenizer_snapshot_download_bug (#4493) 2024-04-30 16:38:50 -07:00
dd1a50a8bc [Bugfix][Minor] Make ignore_eos effective (#4468) 2024-04-30 16:33:33 -07:00
715c2d854d [Frontend] [Core] Tensorizer: support dynamic num_readers, update version (#4467) 2024-04-30 16:32:13 -07:00
a494140433 [Frontend] Support complex message content for chat completions endpoint (#3467)
Co-authored-by: Lily Liu <lilyliupku@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
2024-04-30 16:28:46 -07:00
111815d482 [Kernel] Support Fp8 Checkpoints (Dynamic + Static) (#4332)
Co-authored-by: Philipp Moritz <pcmoritz@gmail.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: mgoin <michael@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
2024-04-30 21:46:12 +00:00
b31a1fb63c [Doc] add visualization for multi-stage dockerfile (#4456)
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
2024-04-30 17:41:59 +00:00
4bb53e2dde [BugFix] fix num_lookahead_slots missing in async executor (#4165)
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
2024-04-30 10:12:59 -07:00
26f2fb5113 [Core]Refactor gptq_marlin ops (#4466) 2024-04-30 08:14:47 -04:00
fa32207842 [Bugfix][Kernel] Fix compute_type for MoE kernel (#4463) 2024-04-29 22:05:40 -07:00
d627a3d837 [Misc] Upgrade to torch==2.3.0 (#4454) 2024-04-29 20:05:47 -04:00
f4f921b7f1 [Core][Distributed] use cpu group to broadcast metadata in cpu (#4444) 2024-04-29 13:52:22 -07:00
ac5ccf0156 [CI] hotfix: soft fail neuron test (#4458) 2024-04-29 19:50:01 +00:00
73c8d677e5 [Kernel] Marlin Expansion: Support AutoGPTQ Models with Marlin (#3922)
Co-authored-by: alexm <alexm@neuralmagic.com>
Co-authored-by: mgoin <michael@neuralmagic.com>
2024-04-29 09:35:34 -07:00
df29793dc7 [mypy][5/N] Support all typing on model executor (#4427) 2024-04-28 19:01:26 -07:00
03dd7d52bf [CI] clean docker cache for neuron (#4441) 2024-04-28 23:32:07 +00:00
bf480c5302 Add more Prometheus metrics (#2764)
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
2024-04-28 15:59:33 -07:00
9c7306ac11 [Misc] fix typo in llm_engine init logging (#4428) 2024-04-28 18:58:30 +08:00
4ea1f9678d [BugFix] Resolved Issues For LinearMethod --> QuantConfig (#4418) 2024-04-27 18:35:33 +00:00
ba4be44c32 [BugFix] Fix return type of executor execute_model methods (#4402) 2024-04-27 11:17:45 -07:00
d6e520e170 [Core] Support offline use of local cache for models (#4374)
Signed-off-by: Prashant Gupta <prashantgupta@us.ibm.com>
Co-authored-by: Travis Johnson <tjohnson31415@gmail.com>
2024-04-27 09:59:55 -07:00
81661da7b2 [BugFix] Fix min_tokens when eos_token_id is None (#4389)
Co-authored-by: DefTruth <31974251+deftruth@users.noreply.github.com>
2024-04-27 09:52:46 -07:00
dfea173148 [Bugfix] Abort requests when the connection to /v1/completions is interrupted (#4363) 2024-04-27 09:48:37 -07:00
Roy
7134303cbb [Bugfix][Core] Fix get decoding config from ray (#4335) 2024-04-27 11:30:08 +00:00
3da24c2df7 [Model] Phi-3 4k sliding window temp. fix (#4380) 2024-04-27 18:08:15 +08:00
eefeb16464 [Kernel] Full Tensor Parallelism for LoRA Layers (#3524)
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
2024-04-27 00:03:48 -07:00
18d23f642a [ROCm][Hardware][AMD] Enable group query attention for triton FA (#4406) 2024-04-26 23:37:40 -07:00
Roy
87f545ba6f [Misc] Fix logger format typo (#4396) 2024-04-27 13:45:02 +08:00
8947bc3c15 [Frontend][Bugfix] Disallow extra fields in OpenAI API (#4355) 2024-04-27 05:08:24 +00:00
12628d3c78 [Kernel] Optimize FP8 support for MoE kernel / Mixtral via static scales (#4343)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-04-27 04:49:59 +00:00
258a2c58d0 [Core] Introduce DistributedGPUExecutor abstract class (#4348) 2024-04-27 04:14:26 +00:00
aba47be3fe [Misc] add RFC issue template (#4401)
Co-authored-by: Simon Mo <simon.mo@hey.com>
2024-04-26 15:47:45 -07:00
a62aaf1df5 [Misc][Refactor] Generalize linear_method to be quant_method (#4373) 2024-04-26 16:41:14 -04:00
603ad84815 [Core] Refactoring sampler and support prompt logprob for chunked prefill (#4309) 2024-04-26 13:02:02 +00:00
a88081bf76 [CI] Disable non-lazy string operation on logging (#4326)
Co-authored-by: Danny Guinther <dguinther@neuralmagic.com>
2024-04-26 00:16:58 -07:00
2f30e7c72f [Frontend] Add --log-level option to api server (#4377) 2024-04-26 05:36:01 +00:00
a74dee9b62 [Bugfix] Fix parameter name in get_tokenizer (#4107) 2024-04-25 19:10:48 -07:00
cf29b7eda4 [ROCm][Hardware][AMD][Doc] Documentation update for ROCm (#4376)
Co-authored-by: WoosukKwon <woosuk.kwon@berkeley.edu>
2024-04-25 18:12:25 -07:00
efffb63f58 [Core] Move function tracing setup to util function (#4352) 2024-04-25 16:45:12 -07:00
15e7c675b0 [Core] Add shutdown() method to ExecutorBase (#4349) 2024-04-25 16:32:48 -07:00
Roy
b6dcb4d442 [Misc] Fix flash attention backend log (#4368) 2024-04-25 12:43:32 -07:00
b5b4a398a7 [Mypy] Typing lora folder (#4337) 2024-04-25 19:13:50 +00:00
f4bc4de1b1 [Core]refactor aqlm quant ops (#4351) 2024-04-25 15:03:56 -04:00
bd7a8eef25 [Doc] README Phi-3 name fix. (#4372)
Co-authored-by: Caio Mendes <caiocesart@microsoft.com>
2024-04-25 10:32:00 -07:00
7ee82bef1e [CI/Build] Adding functionality to reset the node's GPUs before processing. (#4213) 2024-04-25 09:37:20 -07:00
fbf152d976 [Bugfix][Model] Refactor OLMo model to support new HF format in transformers 4.40.0 (#4324)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-04-25 09:35:56 -07:00
479d69fad0 [Core] Move ray_utils.py from engine to executor package (#4347) 2024-04-25 06:52:22 +00:00
96e90fdeb3 [Model] Adds Phi-3 support (#4298) 2024-04-25 03:06:57 +00:00
a395a638c2 [Misc] Use public API in benchmark_throughput (#4300) 2024-04-24 21:10:24 +00:00
2768884ac4 [Doc] Add note for docker user (#4340)
Co-authored-by: Simon Mo <simon.mo@hey.com>
2024-04-24 21:09:44 +00:00
aae08249ac [Bugfix] Fix marlin kernel crash on H100 (#4218)
This PR addresses the Marlin kernel H100 crash that was reported here: neuralmagic#187.
The reason for the crash was the inline PTX assembly that introduced the async_copy with streaming behavior. The solution is to use the more standard PTX for async_copy (without the fractional L2 policy for "evict_first"). There is no performance difference between standard async_copy PTX and the previous one.
2024-04-24 10:35:01 -07:00
7923dcad12 [Misc] Update ShareGPT Dataset Sampling in Serving Benchmark (#4279) 2024-04-24 09:49:13 -07:00
3cd9b5bb2d [Core][Distributed] use existing torch.cuda.device (#4318)
[Core][Distributed] use existing torch.cuda.device context manager (#4318)
2024-04-24 09:00:20 -07:00
468d761b32 [Misc] Reduce supported Punica dtypes (#4304) 2024-04-23 18:54:33 -07:00
e4bf860a54 [CI][Build] change pynvml to nvidia-ml-py (#4302) 2024-04-23 18:33:12 -07:00
91f50a6fe2 [Core][Distributed] use cpu/gloo to initialize pynccl (#4248) 2024-04-23 18:32:19 -07:00
79a268c4ab [BUG] fixed fp8 conflict with aqlm (#4307)
Fixes fp8 iterface which broke in AQLM merge.
2024-04-23 18:26:33 -07:00
eace8bf0b9 [Kernel] FP8 support for MoE kernel / Mixtral (#4244)
This PR is the first step towards fixing https://github.com/vllm-project/vllm/pull/3208

It implements dynamic per-tensor scaling (see https://github.com/vllm-project/vllm/pull/4118), so users do not need to compute activation scales on a calibration dataset and they also don't need to convert their model checkpoints. It is enough to specify the `quantization="fp8"` argument. You can try out the PR like this:

```python
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=2, quantization="fp8")

outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

**Performance**: For this PR, the focus is on making the code clean (while still trying to get reasonable performance), there is a bunch of optimizations that we will submit as a follow up PR that significantly improve the performance (similar to the numbers in https://github.com/vllm-project/vllm/pull/3954). With this PR, the results are as follows:

<img width="725" alt="Screenshot 2024-04-21 at 1 31 50 PM" src="https://github.com/vllm-project/vllm/assets/113316/d8fe1118-07a0-4d4e-8530-37a77d465a03">


**Accuracy**: The accuracy with this PR on MMLU on `mistralai/Mixtral-8x7B-v0.1` is as follows:

```
|      Groups      |Version|Filter|n-shot|Metric|Value |   |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu              |N/A    |none  |     0|acc   |0.7018|±  |0.0036|
| - humanities     |N/A    |none  |     5|acc   |0.6472|±  |0.0065|
| - other          |N/A    |none  |     5|acc   |0.7673|±  |0.0072|
| - social_sciences|N/A    |none  |     5|acc   |0.8099|±  |0.0070|
| - stem           |N/A    |none  |     5|acc   |0.6131|±  |0.0083|
```
this compares favorably with the fp16 results which are
```
|      Groups      |Version|Filter|n-shot|Metric|Value |   |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu              |N/A    |none  |     0|acc   |0.7020|±  |0.1313|
| - humanities     |N/A    |none  |     5|acc   |0.6425|±  |0.1349|
| - other          |N/A    |none  |     5|acc   |0.7744|±  |0.1038|
| - social_sciences|N/A    |none  |     5|acc   |0.8131|±  |0.0695|
| - stem           |N/A    |none  |     5|acc   |0.6108|±  |0.1383|
```

Happy hacking!
2024-04-24 01:18:23 +00:00
1e8f4252aa [Bugfix][Frontend] Raise exception when file-like chat template fails to be opened (#4292) 2024-04-23 18:19:03 +00:00
2b7949c1c2 AQLM CUDA support (#3287)
Co-authored-by: mgoin <michael@neuralmagic.com>
2024-04-23 13:59:33 -04:00
62b5166bd4 [CI] Add ccache for wheel builds job (#4281) 2024-04-23 09:51:41 -07:00
d86285a4a4 [Core][Logging] Add last frame information for better debugging (#4278) 2024-04-23 09:45:52 -07:00
d87f39e9a9 [Bugfix] Add init_cached_hf_modules to RayWorkerWrapper (#4286) 2024-04-23 09:28:35 -07:00
d3c8180ac4 [Bugfix] Fixing max token error message for openai compatible server (#4016) 2024-04-23 19:06:29 +08:00
62b8aebc6f [Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. (#3951) 2024-04-23 08:02:36 +00:00
050f285ff6 [Core] Scheduling optimization 2 (#4280) 2024-04-23 08:02:11 +00:00
8f2ea22bde [Core] Some simplification of WorkerWrapper changes (#4183) 2024-04-23 07:49:08 +00:00
0ae11f78ab [Mypy] Part 3 fix typing for nested directories for most of directory (#4161) 2024-04-22 21:32:44 -07:00
34128a697e Fix autodoc directives (#4272)
Co-authored-by: Harry Mellor <hmellor@oxts.com>
2024-04-23 01:53:01 +00:00
c1b4e4157c [Core][Distributed] use absolute path for library file (#4271) 2024-04-22 17:21:48 -07:00
ceaf4ed003 [Doc] Update the SkyPilot doc with serving and Llama-3 (#4276) 2024-04-22 15:34:31 -07:00
ad8d696a99 [Core] Scheduler perf fix (#4270) 2024-04-22 21:11:06 +00:00
3d925165f2 Add example scripts to documentation (#4225)
Co-authored-by: Harry Mellor <hmellor@oxts.com>
2024-04-22 16:36:54 +00:00
1543680691 [Bugfix] Ensure download_weights_from_hf(..) inside loader is using the revision parameter (#4217) 2024-04-22 09:10:48 -07:00
077f0a2e8a [Frontend] Enable support for CPU backend in AsyncLLMEngine. (#3993)
Signed-off-by: Tao He <sighingnow@gmail.com>
2024-04-22 09:19:51 +00:00
e73ed0f1c6 [Bugfix] Fix type annotations in CPU model runner (#4256) 2024-04-22 00:54:16 -07:00
296cdf8ac7 [Misc] Add vision language model support to CPU backend (#3968) 2024-04-22 00:44:16 -07:00
747b1a7147 [Core][Distributed] fix _is_full_nvlink detection (#4233) 2024-04-21 23:04:16 -07:00
95e5b087cf [AMD][Hardware][Misc][Bugfix] xformer cleanup and light navi logic and CI fixes and refactoring (#4129) 2024-04-21 21:57:24 -07:00
a37d815b83 Make initialization of tokenizer and detokenizer optional (#3748)
Co-authored-by: Yun Ding <yunding@nvidia.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
2024-04-21 22:06:46 +00:00
7f2593b164 [Doc]: Update the doc of adding new models (#4236) 2024-04-21 09:57:08 -07:00
fe7d648fe5 Don't show default value for flags in EngineArgs (#4223)
Co-authored-by: Harry Mellor <hmellor@oxts.com>
2024-04-21 09:15:28 -07:00
cc74b2b232 Updating lm-format-enforcer version and adding links to decoding libraries in docs (#4222) 2024-04-20 08:33:16 +00:00
91528575ec [Frontend] multiple sampling params support (#3570) 2024-04-20 00:11:57 -07:00
a22cdea371 [Kernel][FP8] Initial support with dynamic per-tensor scaling (#4118)
Provide an initial support to FP8 computation. This PR is inspired by HuggingFace TGI: huggingface/text-generation-inference#1726

This feature can be enabled with --quantization fp8 or -q fp8 when launching an engine.

Algorithm:
We still load a model checkpoint in FP16/BF16. After the weights are loaded, Fp8LinearMethod calculates the per-tensor scaling factor of weights and quantizes the weights accordingly. The scaling factor will then be stored for future use. Meanwhile, the per-tensor scaling factor for activations is calculated in every forward pass.

Initial Results:
Currently tested Mistral-7B on 1xH100. With prompt length ~5 and decoding length 128:

BF16: 1.47s
FP8: 1.66s
I'll try to use larger models and try to find more performance bottleneck. Meanwhile, you're welcome to try this code.
2024-04-20 04:28:57 +00:00
682789d402 Fix missing docs and out of sync EngineArgs (#4219)
Co-authored-by: Harry Mellor <hmellor@oxts.com>
2024-04-19 20:51:33 -07:00
138485a82d [Bugfix] Add fix for JSON whitespace (#4189)
Co-authored-by: Ubuntu <ubuntu@ip-172-31-13-147.ec2.internal>
2024-04-19 20:49:22 -07:00
bc9df1571b Pass tokenizer_revision when getting tokenizer in openai serving (#4214) 2024-04-19 17:13:56 -07:00
15b86408a8 [Misc] add nccl in collect env (#4211) 2024-04-19 19:44:51 +00:00
7be4f5628f [Bugfix][Core] Restore logging of stats in the async engine (#4150) 2024-04-19 08:08:26 -07:00
8f20fc04bf [Misc] fix docstrings (#4191)
Co-authored-by: Zhong Wang <wangzhong@infini-ai.com>
2024-04-19 08:18:33 +00:00
221d93ecbf Bump version of 0.4.1 (#4177) 2024-04-19 01:00:22 -07:00
d17c8477f1 [Bugfix] Fix LoRA loading check (#4138)
Co-authored-by: simon-mo <simon.mo@hey.com>
2024-04-19 00:59:54 -07:00
a134ef6f5e Support eos_token_id from generation_config.json (#4182) 2024-04-19 04:13:36 +00:00
8a7a3e4436 [Core] add an option to log every function call to for debugging hang/crash in distributed inference (#4079)
Co-authored-by: Simon Mo <simon.mo@hey.com>
2024-04-18 16:15:12 -07:00
8f9c28fd40 [Bugfix] Fix CustomAllreduce nvlink topology detection (#3974)
[Bugfix] Fix CustomAllreduce pcie nvlink topology detection (#3974) (#4159)
2024-04-18 15:32:47 -07:00
cd2f63fb36 [CI/CD] add neuron docker and ci test scripts (#3571) 2024-04-18 15:26:01 -07:00
87fa80c91f [Misc] Bump transformers to latest version (#4176) 2024-04-18 14:36:39 -07:00
e1bb2fd52d [Bugfix] Support logprobs when using guided_json and other constrained decoding fields (#4149) 2024-04-18 21:12:55 +00:00
705578ae14 [Docs] document that Meta Llama 3 is supported (#4175) 2024-04-18 10:55:48 -07:00
e8cc7967ff [Bugfix][Kernel] allow non-power-of-two head sizes in prefix prefill (#4128) 2024-04-18 00:51:28 -07:00
53b018edcb [Bugfix] Get available quantization methods from quantization registry (#4098) 2024-04-18 00:21:55 -07:00
66ded03067 Allow model to be served under multiple names (#2894)
Co-authored-by: Alexandre Payot <alexandrep@graphcore.ai>
2024-04-18 00:16:26 -07:00
6dc1fc9cfe [Core] nccl integrity check and test (#4155)
[Core] Add integrity check during initialization; add test for it (#4155)
2024-04-17 22:28:52 -07:00
533d2a1f39 [Typing] Mypy typing part 2 (#4043)
Co-authored-by: SangBin Cho <sangcho@sangcho-LT93GQWG9C.local>
2024-04-17 17:28:43 -07:00
a53222544c [Kernel] Add punica dimension for Swallow-MS-7B LoRA (#4134) 2024-04-17 10:02:45 -07:00
fe3b5bbc23 [Bugfix] fix output parsing error for trtllm backend (#4137)
Co-authored-by: Roger Wang <ywang@roblox.com>
2024-04-17 11:07:23 +00:00
8438e0569e [Core] RayWorkerVllm --> WorkerWrapper to reduce duplication (#4024)
[Core] replace narrow-usage RayWorkerVllm to general WorkerWrapper to reduce code duplication (#4024)
2024-04-17 08:34:33 +00:00
11d652bd4f [CI] Move CPU/AMD tests to after wait (#4123) 2024-04-16 22:53:26 -07:00
d150e4f89f [Misc] [CI] Fix CI failure caught after merge (#4126) 2024-04-16 17:56:01 -07:00
e95cd87959 [Speculative decoding 6/9] Integrate speculative decoding with LLMEngine (#3894) 2024-04-16 13:09:21 -07:00
69e1d2fb69 [Core] Refactor model loading code (#4097) 2024-04-16 11:34:39 -07:00
05434764cd LM Format Enforcer Guided Decoding Support (#3868)
Co-authored-by: Simon Mo <simon.mo@hey.com>
2024-04-16 05:54:57 +00:00
4e7ee664e2 [Core] Fix engine-use-ray broken (#4105) 2024-04-16 05:24:53 +00:00
37e84a403d [Typing] Fix Sequence type GenericAlias only available after Python 3.9. (#4092) 2024-04-15 14:47:31 -07:00
4695397dcf [Bugfix] Fix ray workers profiling with nsight (#4095) 2024-04-15 14:24:45 -07:00
d619ae2d19 [Doc] Add better clarity for tensorizer usage (#4090)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
2024-04-15 13:28:25 -07:00
eb46fbfda2 [Core] Simplifications to executor classes (#4071) 2024-04-15 13:05:09 -07:00
0003e9154b [Misc][Minor] Fix CPU block num log in CPUExecutor. (#4088) 2024-04-15 08:35:55 -07:00
e11e200736 [Bugfix] Fix filelock version requirement (#4075) 2024-04-14 21:50:08 -07:00
Roy
8db1bf32f8 [Misc] Upgrade triton to 2.2.0 (#4061) 2024-04-14 17:43:54 -07:00
aceb17cf2d [Docs] document that mixtral 8x22b is supported (#4073) 2024-04-14 14:35:55 -07:00
563c54f760 [BugFix] Fix tensorizer extra in setup.py (#4072) 2024-04-14 14:12:42 -07:00
2cd6b4f362 [Core] avoid too many cuda context by caching p2p test (#4021) 2024-04-13 23:40:21 -07:00
711a000255 [Frontend] [Core] feat: Add model loading using tensorizer (#3476) 2024-04-13 17:13:01 -07:00
989ae2538d [Kernel] Add punica dimension for Baichuan-13B (#4053) 2024-04-13 07:55:05 -07:00
0a430b4ae2 [Bugfix] fix_small_bug_in_neuron_executor (#4051) 2024-04-13 07:54:03 -07:00
ec8e3c695f [Bugfix] fix_log_time_in_metrics (#4050) 2024-04-13 07:52:36 -07:00
98afde19fc [Core][Distributed] improve logging for init dist (#4042) 2024-04-13 07:12:53 -07:00
5c2e66e487 [Bugfix] More type hint fixes for py 3.8 (#4039) 2024-04-12 21:07:04 -07:00
546e721168 [CI/Test] expand ruff and yapf for all supported python version (#4037) 2024-04-13 01:43:37 +00:00
b8aacac31a [Bugfix] Fix LoRA bug (#4032) 2024-04-12 16:56:37 -07:00
d04973ad54 Fix triton compilation issue (#3984)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-04-12 16:41:26 -07:00
fbb9d9eef4 [Core] fix custom allreduce default value (#4040) 2024-04-12 16:40:39 -07:00
09473ee41c [mypy] Add mypy type annotation part 1 (#4006) 2024-04-12 14:35:50 -07:00
d4ec9ffb95 [Misc] Fix typo in scheduler.py (#4022) 2024-04-12 13:56:04 -07:00
96b6a6d790 [Bugfix] fix type hint for py 3.8 (#4036) 2024-04-12 19:35:44 +00:00
36729bac13 [Test] Test multiple attn backend for chunked prefill. (#4023) 2024-04-12 09:56:57 -07:00
7fd3949a0b [Frontend][Core] Move merge_async_iterators to utils (#4026) 2024-04-12 05:30:54 +00:00
1096717ae9 [Core] Support LoRA on quantized models (#4012) 2024-04-11 21:02:44 -07:00
c2b4a1bce9 [Doc] Add typing hints / mypy types cleanup (#3816)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
2024-04-11 17:17:21 -07:00
e46a60aa4c [BugFix] Fix handling of stop strings and stop token ids (#3672) 2024-04-11 15:34:12 -07:00
1e96c3341a Add extra punica sizes to support bigger vocabs (#4015) 2024-04-11 22:18:57 +00:00
95e7d4a97c Fix echo/logprob OpenAI completion bug (#3441)
Co-authored-by: Dylan Hawk <dylanwawk@gmail.com>
2024-04-11 22:15:50 +00:00
559eb852f8 [Core] init_distributed_environment align with init_process_group(#4014)
[Core][Distributed] make init_distributed_environment compatible with init_process_group (#4014)
2024-04-11 14:00:48 -07:00
a10d3056da [Core] Set linear_weights directly on the layer (#3977) 2024-04-11 16:35:51 -04:00
8afca50889 [Hardware][Intel] Isolate CPUModelRunner and ModelRunner for better maintenance (#3824) 2024-04-11 11:56:49 -07:00
08ccee1e83 punica fix-bgmv-kernel-640 (#4007) 2024-04-11 08:59:26 -07:00
c1dc547129 [Kernel] Fused MoE Config for Mixtral 8x22 (#4002) 2024-04-11 07:50:00 -07:00
f3d0bf7589 [Doc][Installation] delete python setup.py develop (#3989) 2024-04-11 03:33:02 +00:00
e9da5a40c6 [Misc] Add indirection layer for custom ops (#3913) 2024-04-10 20:26:07 -07:00
e42df7227d [Test] Add xformer and flash attn tests (#3961)
Co-authored-by: Simon Mo <simon.mo@hey.com>
2024-04-11 03:09:50 +00:00
caada5e50a [Core][Model] torch.compile for layernorm in commandr (#3985)
[Core][Model] Use torch.compile to accelerate layernorm in commandr (#3985)
2024-04-11 01:48:26 +00:00
67b4221a61 [Core][5/N] Fully working chunked prefill e2e (#3884) 2024-04-10 17:56:48 -07:00
63e7176f26 [Core][Refactor] move parallel_utils into vllm/distributed (#3950)
[WIP][Core][Refactor] move vllm/model_executor/parallel_utils into vllm/distributed and vllm/device_communicators (#3950)
2024-04-10 15:33:30 -07:00
934d3662f7 [Bugfix] handle hf_config with architectures == None (#3982)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
2024-04-10 22:28:25 +00:00
92cd2e2f21 [Doc] Fix getting stared to use publicly available model (#3963) 2024-04-10 18:05:52 +00:00
e4c4072c94 [Bugfix] Remove key sorting for guided_json parameter in OpenAi compatible Server (#3945) 2024-04-10 10:15:51 -07:00
e35397468f [Doc] Add doc to state our model support policy (#3948)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
2024-04-10 17:03:02 +00:00
8b317c6dd0 [Model][AMD] ROCm support for 256 head dims for Gemma (#3972) 2024-04-10 08:12:00 -07:00
bd3c144e0b [Bugfix][ROCm] Add numba to Dockerfile.rocm (#3962) 2024-04-10 07:37:17 -07:00
0258b7a94b [Bugfix] handle prompt_logprobs in _apply_min_tokens_penalty (#3876)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
2024-04-10 01:39:56 -07:00
b3104b2a10 [Bugfix] Fix logits processor when prompt_logprobs is not None (#3899) 2024-04-10 00:09:36 -07:00
c2e00af523 [Bugfix] fix utils.py/merge_dict func TypeError: 'type' object is not subscriptable (#3955)
Co-authored-by: tianyi_zhao <tianyi.zhao@transwarp.io>
2024-04-10 04:49:11 +00:00
c013d32c75 [Benchmark] Add cpu options to bench scripts (#3915) 2024-04-09 21:30:03 -07:00
11dd6ebb89 [Misc] Avoid loading incorrect LoRA config (#3777) 2024-04-09 19:47:15 -07:00
6c0b04515f [ROCm][Hardware][AMD] Use Triton Kernel for default FA on ROCm (#3643)
Co-authored-by: jpvillam <jpvillam@amd.com>
Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-04-09 15:10:47 -07:00
e23a43aef8 [Bugfix] Fix KeyError on loading GPT-NeoX (#3925) 2024-04-09 12:11:31 -07:00
e7c7067b45 [Misc] [Core] Implement RFC "Augment BaseExecutor interfaces to enable hardware-agnostic speculative decoding" (#3837) 2024-04-09 11:44:15 -07:00
6d592eb430 [Core] separate distributed_init from worker (#3904) 2024-04-09 08:49:02 +00:00
Roy
d036198e23 [BugFix][Model] Fix commandr RoPE max_position_embeddings (#3919) 2024-04-09 06:17:21 +08:00
59a6abf3c9 [Hotfix][CI/Build][Kernel] CUDA 11.8 does not support layernorm optimizations (#3782) 2024-04-08 14:31:02 -07:00
bc0c0192d1 [Bugfix] Enable Proper attention_bias Usage in Llama Model Configuration (#3767)
Co-authored-by: roy <jasonailu87@gmail.com>
2024-04-08 19:42:35 +00:00
f46864d68d [Bugfix] Added Command-R GPTQ support (#3849)
Co-authored-by: Egor Tolmachev <t333ga@gmail.com>
2024-04-08 14:59:38 +00:00
b4543c8f6b [Model] add minicpm (#3893) 2024-04-08 18:28:36 +08:00
0ce0539d47 [Bugfix] Fix Llava inference with Tensor Parallelism. (#3883) 2024-04-07 22:54:13 +08:00
2f19283549 [Core] latency optimization (#3890) 2024-04-06 19:14:06 -07:00
95baec828f [Core] enable out-of-tree model register (#3871) 2024-04-06 17:11:41 -07:00
e4be7d70bb [CI/Benchmark] add more iteration and use median for robust latency benchmark (#3889) 2024-04-06 21:32:30 +00:00
54951ac4bf [Bugfix] Fix incorrect output on OLMo models in Tensor Parallelism (#3869) 2024-04-05 12:02:09 -07:00
18de883489 [Chunked Prefill][4/n] Chunked prefill scheduler. (#3853) 2024-04-05 10:17:58 -07:00
1d7c940d74 Add option to completion API to truncate prompt tokens (#3144) 2024-04-05 10:15:42 -07:00
cfaf49a167 [Misc] Define common requirements (#3841) 2024-04-05 00:39:17 -07:00
9edec652e2 [Bugfix] Fixing requirements.txt (#3865) 2024-04-04 23:46:01 -07:00
e0dd4d3589 [Misc] Fix linter issues in examples/fp8/quantizer/quantize.py (#3864) 2024-04-04 21:57:33 -07:00
e5043a3e75 [Misc] Add pytest marker to opt-out of global test cleanup (#3863) 2024-04-04 21:54:16 -07:00
d03d64fd2e [CI/Build] refactor dockerfile & fix pip cache
[CI/Build] fix pip cache with vllm_nccl & refactor dockerfile to build wheels (#3859)
2024-04-04 21:53:16 -07:00
78107fa091 [Doc]Add asynchronous engine arguments to documentation. (#3810)
Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
2024-04-04 21:52:01 -07:00
c391e4b68e [Core] improve robustness of pynccl (#3860) 2024-04-04 16:52:12 -07:00
9117f892f0 [Model] Cohere CommandR+ (#3829) 2024-04-04 13:31:49 -07:00
db2a6a41e2 [Hardware][CPU] Update cpu torch to match default of 2.2.1 (#3854) 2024-04-04 19:49:49 +00:00
ca81ff5196 [Core] manage nccl via a pypi package & upgrade to pt 2.2.1 (#3805) 2024-04-04 10:26:19 -07:00
b7782002e1 [Benchmark] Refactor sample_requests in benchmark_throughput (#3613)
Co-authored-by: Roger Wang <ywang@roblox.com>
2024-04-04 09:56:22 +00:00
819a309c0f [Bugfix] Fix args in benchmark_serving (#3836)
Co-authored-by: Roger Wang <ywang@roblox.com>
2024-04-04 07:41:05 +00:00
aabe8f40f2 [Core] [Frontend] Make detokenization optional (#3749)
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
2024-04-03 21:52:18 -07:00
498eb5cfa3 [Bugfix] Add kv_scale input parameter to CPU backend (#3840) 2024-04-04 04:33:08 +00:00
537ee25f43 [Core] Enable hf_transfer by default if available (#3817) 2024-04-04 04:02:43 +00:00
294f8f6665 [BugFix] Pass tokenizer_config to local_tokenizer_group (#3754)
Signed-off-by: Tao He <sighingnow@gmail.com>
2024-04-03 20:31:46 -07:00
b95047f2da [Misc] Publish 3rd meetup slides (#3835) 2024-04-03 15:46:10 -07:00
2ff767b513 Enable scaled FP8 (e4m3fn) KV cache on ROCm (AMD GPU) (#3290)
Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: HaiShaw <hixiao@gmail.com>
Co-authored-by: AdrianAbeyta <Adrian.Abeyta@amd.com>
Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com>
Co-authored-by: root <root@gt-pla-u18-08.pla.dcgpu>
Co-authored-by: mawong-amd <156021403+mawong-amd@users.noreply.github.com>
Co-authored-by: ttbachyinsda <ttbachyinsda@outlook.com>
Co-authored-by: guofangze <guofangze@kuaishou.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: jacobthebanana <50071502+jacobthebanana@users.noreply.github.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-04-03 14:15:55 -07:00
3dcb3e8b98 [3/N] Refactor scheduler for chunked prefill scheduling (#3550) 2024-04-03 14:13:49 -07:00
c64cf38673 [Doc] Update contribution guidelines for better onboarding (#3819) 2024-04-03 07:31:43 +00:00
76b889bf1d [Doc] Update README.md (#3806) 2024-04-02 23:11:10 -07:00
c9b506dad4 [BugFix] Use different mechanism to get vllm version in is_cpu() (#3804) 2024-04-02 23:06:25 -07:00
5757d90e26 [Speculative decoding] Adding configuration object for speculative decoding (#3706)
Co-authored-by: Lily Liu <lilyliupku@gmail.com>
2024-04-03 00:40:57 +00:00
582 changed files with 64382 additions and 15720 deletions

View File

@ -0,0 +1,36 @@
import os
import zipfile
MAX_SIZE_MB = 200
def print_top_10_largest_files(zip_file):
with zipfile.ZipFile(zip_file, 'r') as z:
file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()]
file_sizes.sort(key=lambda x: x[1], reverse=True)
for f, size in file_sizes[:10]:
print(f"{f}: {size/(1024*1024)} MBs uncompressed.")
def check_wheel_size(directory):
for root, _, files in os.walk(directory):
for f in files:
if f.endswith(".whl"):
wheel_path = os.path.join(root, f)
wheel_size = os.path.getsize(wheel_path)
wheel_size_mb = wheel_size / (1024 * 1024)
if wheel_size_mb > MAX_SIZE_MB:
print(
f"Wheel {wheel_path} is too large ({wheel_size_mb} MB) "
f"compare to the allowed size ({MAX_SIZE_MB} MB).")
print_top_10_largest_files(wheel_path)
return 1
else:
print(f"Wheel {wheel_path} is within the allowed size "
f"({wheel_size_mb} MB).")
return 0
if __name__ == "__main__":
import sys
sys.exit(check_wheel_size(sys.argv[1]))

View File

@ -1,38 +1,73 @@
# This script build the ROCm docker image and run the API server inside the container. # This script runs test inside the corresponding ROCm docker container.
# It serves a sanity check for compilation and basic model usage.
set -ex set -ex
# Print ROCm version # Print ROCm version
echo "--- ROCm info"
rocminfo rocminfo
# Try building the docker image # cleanup older docker images
docker build -t rocm -f Dockerfile.rocm . cleanup_docker() {
# Get Docker's root directory
# Setup cleanup docker_root=$(docker info -f '{{.DockerRootDir}}')
remove_docker_container() { docker rm -f rocm || true; } if [ -z "$docker_root" ]; then
trap remove_docker_container EXIT echo "Failed to determine Docker root directory."
remove_docker_container exit 1
fi
# Run the image echo "Docker root directory: $docker_root"
docker run --device /dev/kfd --device /dev/dri --network host --name rocm rocm python3 -m vllm.entrypoints.api_server & # Check disk usage of the filesystem where Docker's root directory is located
disk_usage=$(df "$docker_root" | tail -1 | awk '{print $5}' | sed 's/%//')
# Wait for the server to start # Define the threshold
wait_for_server_to_start() { threshold=70
timeout=300 if [ "$disk_usage" -gt "$threshold" ]; then
counter=0 echo "Disk usage is above $threshold%. Cleaning up Docker images and volumes..."
# Remove dangling images (those that are not tagged and not used by any container)
while [ "$(curl -s -o /dev/null -w ''%{http_code}'' localhost:8000/health)" != "200" ]; do docker image prune -f
sleep 1 # Remove unused volumes
counter=$((counter + 1)) docker volume prune -f
if [ $counter -ge $timeout ]; then echo "Docker images and volumes cleanup completed."
echo "Timeout after $timeout seconds" else
break echo "Disk usage is below $threshold%. No cleanup needed."
fi fi
done
} }
wait_for_server_to_start
# Test a simple prompt # Call the cleanup docker function
curl -X POST -H "Content-Type: application/json" \ cleanup_docker
localhost:8000/generate \
-d '{"prompt": "San Francisco is a"}' echo "--- Resetting GPUs"
echo "reset" > /opt/amdgpu/etc/gpu_state
while true; do
sleep 3
if grep -q clean /opt/amdgpu/etc/gpu_state; then
echo "GPUs state is \"clean\""
break
fi
done
echo "--- Building container"
sha=$(git rev-parse --short HEAD)
image_name=rocm_${sha}
container_name=rocm_${sha}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)
docker build \
-t ${image_name} \
-f Dockerfile.rocm \
--progress plain \
.
remove_docker_container() {
docker rm -f ${container_name} || docker image rm -f ${image_name} || true
}
trap remove_docker_container EXIT
echo "--- Running container"
docker run \
--device /dev/kfd --device /dev/dri \
--network host \
--rm \
-e HF_TOKEN \
--name ${container_name} \
${image_name} \
/bin/bash -c "${@}"

View File

@ -9,10 +9,10 @@ cd "$(dirname "${BASH_SOURCE[0]}")/.."
(which wget && which curl) || (apt-get update && apt-get install -y wget curl) (which wget && which curl) || (apt-get update && apt-get install -y wget curl)
# run python-based benchmarks and upload the result to buildkite # run python-based benchmarks and upload the result to buildkite
python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt python3 benchmarks/benchmark_latency.py --output-json latency_results.json 2>&1 | tee benchmark_latency.txt
bench_latency_exit_code=$? bench_latency_exit_code=$?
python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 --output-json throughput_results.json 2>&1 | tee benchmark_throughput.txt
bench_throughput_exit_code=$? bench_throughput_exit_code=$?
# run server-based benchmarks and upload the result to buildkite # run server-based benchmarks and upload the result to buildkite
@ -53,6 +53,11 @@ echo '```' >> benchmark_results.md
tail -n 20 benchmark_serving.txt >> benchmark_results.md # last 20 lines tail -n 20 benchmark_serving.txt >> benchmark_results.md # last 20 lines
echo '```' >> benchmark_results.md echo '```' >> benchmark_results.md
# if the agent binary is not found, skip uploading the results, exit 0
if [ ! -f /workspace/buildkite-agent ]; then
exit 0
fi
# upload the results to buildkite # upload the results to buildkite
/workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md /workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md
@ -69,4 +74,5 @@ if [ $bench_serving_exit_code -ne 0 ]; then
exit $bench_serving_exit_code exit $bench_serving_exit_code
fi fi
/workspace/buildkite-agent artifact upload openai-*.json rm ShareGPT_V3_unfiltered_cleaned_split.json
/workspace/buildkite-agent artifact upload "*.json"

View File

@ -10,5 +10,15 @@ remove_docker_container() { docker rm -f cpu-test || true; }
trap remove_docker_container EXIT trap remove_docker_container EXIT
remove_docker_container remove_docker_container
# Run the image and launch offline inference # Run the image
docker run --network host --env VLLM_CPU_KVCACHE_SPACE=1 --name cpu-test cpu-test python3 examples/offline_inference.py docker run -itd -v ~/.cache/huggingface:/root/.cache/huggingface --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test cpu-test
# offline inference
docker exec cpu-test bash -c "python3 examples/offline_inference.py"
# Run basic model test
docker exec cpu-test bash -c "cd tests;
pip install pytest Pillow protobuf
bash ../.buildkite/download-images.sh
cd ../
pytest -v -s tests/models --ignore=tests/models/test_llava.py --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py"

View File

@ -0,0 +1,51 @@
# This script build the Neuron docker image and run the API server inside the container.
# It serves a sanity check for compilation and basic model usage.
set -e
# Try building the docker image
aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com
# prune old image and containers to save disk space, and only once a day
# by using a timestamp file in tmp.
if [ -f /tmp/neuron-docker-build-timestamp ]; then
last_build=$(cat /tmp/neuron-docker-build-timestamp)
current_time=$(date +%s)
if [ $((current_time - last_build)) -gt 86400 ]; then
docker system prune -f
echo $current_time > /tmp/neuron-docker-build-timestamp
fi
else
echo $(date +%s) > /tmp/neuron-docker-build-timestamp
fi
docker build -t neuron -f Dockerfile.neuron .
# Setup cleanup
remove_docker_container() { docker rm -f neuron || true; }
trap remove_docker_container EXIT
remove_docker_container
# Run the image
docker run --device=/dev/neuron0 --device=/dev/neuron1 --network host --name neuron neuron python3 -m vllm.entrypoints.api_server \
--model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --max-num-seqs 8 --max-model-len 128 --block-size 128 --device neuron --tensor-parallel-size 2 &
# Wait for the server to start
wait_for_server_to_start() {
timeout=300
counter=0
while [ "$(curl -s -o /dev/null -w ''%{http_code}'' localhost:8000/health)" != "200" ]; do
sleep 1
counter=$((counter + 1))
if [ $counter -ge $timeout ]; then
echo "Timeout after $timeout seconds"
break
fi
done
}
wait_for_server_to_start
# Test a simple prompt
curl -X POST -H "Content-Type: application/json" \
localhost:8000/generate \
-d '{"prompt": "San Francisco is a"}'

View File

@ -5,92 +5,161 @@
steps: steps:
- label: Regression Test - label: Regression Test
mirror_hardwares: [amd]
command: pytest -v -s test_regression.py command: pytest -v -s test_regression.py
working_dir: "/vllm-workspace/tests" # optional working_dir: "/vllm-workspace/tests" # optional
- label: AsyncEngine Test - label: AsyncEngine Test
#mirror_hardwares: [amd]
command: pytest -v -s async_engine command: pytest -v -s async_engine
- label: Basic Correctness Test - label: Basic Correctness Test
command: pytest -v -s basic_correctness mirror_hardwares: [amd]
commands:
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
- label: Core Test - label: Core Test
mirror_hardwares: [amd]
command: pytest -v -s core command: pytest -v -s core
- label: Distributed Comm Ops Test - label: Distributed Comm Ops Test
command: pytest -v -s test_comm_ops.py #mirror_hardwares: [amd]
working_dir: "/vllm-workspace/tests/distributed" command: pytest -v -s distributed/test_comm_ops.py
num_gpus: 2 # only support 1 or 2 for now. working_dir: "/vllm-workspace/tests"
num_gpus: 2
- label: Distributed Tests - label: Distributed Tests
working_dir: "/vllm-workspace/tests/distributed" mirror_hardwares: [amd]
num_gpus: 2 # only support 1 or 2 for now. working_dir: "/vllm-workspace/tests"
num_gpus: 2
commands: commands:
- pytest -v -s test_pynccl.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- pytest -v -s spec_decode/e2e/test_integration_dist.py
- label: Distributed Tests (Multiple Groups)
#mirror_hardwares: [amd]
working_dir: "/vllm-workspace/tests"
num_gpus: 4
commands:
- pytest -v -s distributed/test_pynccl.py
- label: Engine Test - label: Engine Test
command: pytest -v -s engine tokenization test_sequence.py test_config.py mirror_hardwares: [amd]
command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py
- label: Entrypoints Test - label: Entrypoints Test
command: pytest -v -s entrypoints mirror_hardwares: [amd]
commands:
- pytest -v -s test_inputs.py
- pytest -v -s entrypoints -m llm
- pytest -v -s entrypoints -m openai
- label: Examples Test - label: Examples Test
working_dir: "/vllm-workspace/examples" working_dir: "/vllm-workspace/examples"
mirror_hardwares: [amd]
commands: commands:
# install aws cli for llava_example.py # install aws cli for llava_example.py
- pip install awscli # install tensorizer for tensorize_vllm_model.py
- pip install awscli tensorizer
- python3 offline_inference.py - python3 offline_inference.py
- python3 offline_inference_with_prefix.py - python3 offline_inference_with_prefix.py
- python3 llm_engine_example.py - python3 llm_engine_example.py
- python3 llava_example.py - python3 llava_example.py
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- label: Kernels Test %N - label: Kernels Test %N
#mirror_hardwares: [amd]
command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 4 parallelism: 4
- label: Models Test - label: Models Test
#mirror_hardwares: [amd]
commands: commands:
- bash ../.buildkite/download-images.sh - bash ../.buildkite/download-images.sh
- pytest -v -s models --ignore=models/test_llava.py --ignore=models/test_mistral.py - pytest -v -s models --ignore=models/test_llava.py
- label: Llava Test - label: Llava Test
mirror_hardwares: [amd]
commands: commands:
- bash ../.buildkite/download-images.sh - bash ../.buildkite/download-images.sh
- pytest -v -s models/test_llava.py - pytest -v -s models/test_llava.py
- label: Prefix Caching Test - label: Prefix Caching Test
mirror_hardwares: [amd]
commands: commands:
- pytest -v -s prefix_caching - pytest -v -s prefix_caching
- label: Samplers Test - label: Samplers Test
#mirror_hardwares: [amd]
command: pytest -v -s samplers command: pytest -v -s samplers
- label: LogitsProcessor Test - label: LogitsProcessor Test
mirror_hardwares: [amd]
command: pytest -v -s test_logits_processor.py command: pytest -v -s test_logits_processor.py
- label: Utils Test
command: pytest -v -s test_utils.py
- label: Worker Test - label: Worker Test
mirror_hardwares: [amd]
command: pytest -v -s worker command: pytest -v -s worker
- label: Speculative decoding tests - label: Speculative decoding tests
#mirror_hardwares: [amd]
command: pytest -v -s spec_decode command: pytest -v -s spec_decode
- label: LoRA Test %N - label: LoRA Test %N
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT #mirror_hardwares: [amd]
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
parallelism: 4 parallelism: 4
- label: LoRA Long Context (Distributed)
#mirror_hardwares: [amd]
num_gpus: 4
# This test runs llama 13B, so it is required to run on 4 GPUs.
commands:
# Temporarily run this way because we cannot clean up GPU mem usage
# for multi GPU tests.
# TODO(sang): Fix it.
- pytest -v -s lora/test_long_context.py::test_rotary_emb_replaced
- pytest -v -s lora/test_long_context.py::test_batched_rope_kernel
- pytest -v -s lora/test_long_context.py::test_self_consistency
- pytest -v -s lora/test_long_context.py::test_quality
- pytest -v -s lora/test_long_context.py::test_max_len
- label: Tensorizer Test
#mirror_hardwares: [amd]
command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader
- label: Metrics Test - label: Metrics Test
mirror_hardwares: [amd]
command: pytest -v -s metrics command: pytest -v -s metrics
- label: Quantization Test
#mirror_hardwares: [amd]
command: pytest -v -s quantization
- label: Benchmarks - label: Benchmarks
working_dir: "/vllm-workspace/.buildkite" working_dir: "/vllm-workspace/.buildkite"
mirror_hardwares: [amd]
commands: commands:
- pip install aiohttp - pip install aiohttp
- bash run-benchmarks.sh - bash run-benchmarks.sh
- label: Documentation Build - label: Documentation Build
working_dir: "/vllm-workspace/docs" working_dir: "/vllm-workspace/test_docs/docs"
no_gpu: True no_gpu: True
commands: commands:
- pip install -r requirements-docs.txt - pip install -r requirements-docs.txt

View File

@ -0,0 +1,59 @@
{% set docker_image = "public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT" %}
{% set default_working_dir = "/vllm-workspace/tests" %}
steps:
- label: ":docker: build image"
agents:
queue: cpu_queue
commands:
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
- "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
- "docker push {{ docker_image }}"
env:
DOCKER_BUILDKIT: "1"
retry:
automatic:
- exit_status: -1 # Agent was lost
limit: 5
- exit_status: -10 # Agent was lost
limit: 5
- wait
{% for step in steps %}
- label: "{{ step.label }}"
agents:
{% if step.no_gpu %}
queue: cpu_queue
{% elif step.num_gpus == 2 or step.num_gpus == 4 %}
queue: gpu_4_queue
{% else %}
queue: gpu_1_queue
{% endif %}
soft_fail: true
{% if step.parallelism %}
parallelism: {{ step.parallelism }}
{% endif %}
retry:
automatic:
- exit_status: -1 # Agent was lost
limit: 5
- exit_status: -10 # Agent was lost
limit: 5
plugins:
- docker#v5.2.0:
image: {{ docker_image }}
always-pull: true
propagate-environment: true
{% if not step.no_gpu %}
gpus: all
{% endif %}
command: ["bash", "-c", "cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}"]
environment:
- VLLM_USAGE_SOURCE=ci-test
- HF_TOKEN
{% if step.label == "Speculative decoding tests" %}
- VLLM_ATTENTION_BACKEND=XFORMERS
{% endif %}
volumes:
- /dev/shm:/dev/shm
{% endfor %}

View File

@ -3,16 +3,8 @@
{% set default_working_dir = "/vllm-workspace/tests" %} {% set default_working_dir = "/vllm-workspace/tests" %}
steps: steps:
- label: "AMD Test"
agents:
queue: amd
command: bash .buildkite/run-amd-test.sh
- label: "CPU Test"
command: bash .buildkite/run-cpu-test.sh
- label: ":docker: build image" - label: ":docker: build image"
commands: commands:
- "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ." - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
- "docker push {{ docker_image }}" - "docker push {{ docker_image }}"
env: env:
@ -21,8 +13,37 @@ steps:
automatic: automatic:
- exit_status: -1 # Agent was lost - exit_status: -1 # Agent was lost
limit: 5 limit: 5
- exit_status: -10 # Agent was lost
limit: 5
- wait - wait
- group: "AMD Tests"
depends_on: ~
steps:
{% for step in steps %}
{% if step.mirror_hardwares and "amd" in step.mirror_hardwares %}
- label: "AMD: {{ step.label }}"
agents:
queue: amd
command: bash .buildkite/run-amd-test.sh "cd {{ (step.working_dir or default_working_dir) | safe }} ; {{ step.command or (step.commands | join(" ; ")) | safe }}"
env:
DOCKER_BUILDKIT: "1"
{% endif %}
{% endfor %}
- label: "Neuron Test"
depends_on: ~
agents:
queue: neuron
command: bash .buildkite/run-neuron-test.sh
soft_fail: true
- label: "Intel Test"
depends_on: ~
agents:
queue: intel
command: bash .buildkite/run-cpu-test.sh
{% for step in steps %} {% for step in steps %}
- label: "{{ step.label }}" - label: "{{ step.label }}"
agents: agents:
@ -35,9 +56,14 @@ steps:
automatic: automatic:
- exit_status: -1 # Agent was lost - exit_status: -1 # Agent was lost
limit: 5 limit: 5
- exit_status: -10 # Agent was lost
limit: 5
plugins: plugins:
- kubernetes: - kubernetes:
podSpec: podSpec:
{% if step.num_gpus %}
priorityClassName: gpu-priority-cls-{{ step.num_gpus }}
{% endif %}
volumes: volumes:
- name: dshm - name: dshm
emptyDir: emptyDir:

26
.clang-format Normal file
View File

@ -0,0 +1,26 @@
BasedOnStyle: Google
UseTab: Never
IndentWidth: 2
ColumnLimit: 80
# Force pointers to the type for C++.
DerivePointerAlignment: false
PointerAlignment: Left
# Reordering #include statements can (and currently will) introduce errors
SortIncludes: false
# Style choices
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
IndentPPDirectives: BeforeHash
IncludeCategories:
- Regex: '^<'
Priority: 4
- Regex: '^"(llvm|llvm-c|clang|clang-c|mlir|mlir-c)/'
Priority: 3
- Regex: '^"(qoda|\.\.)/'
Priority: 2
- Regex: '.*'
Priority: 1

View File

@ -18,6 +18,7 @@ body:
# For security purposes, please feel free to check the contents of collect_env.py before running it. # For security purposes, please feel free to check the contents of collect_env.py before running it.
python collect_env.py python collect_env.py
``` ```
It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues.
value: | value: |
```text ```text
The output of `python collect_env.py` The output of `python collect_env.py`

View File

@ -18,6 +18,7 @@ body:
# For security purposes, please feel free to check the contents of collect_env.py before running it. # For security purposes, please feel free to check the contents of collect_env.py before running it.
python collect_env.py python collect_env.py
``` ```
It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues.
value: | value: |
```text ```text
The output of `python collect_env.py` The output of `python collect_env.py`

View File

@ -18,6 +18,7 @@ body:
# For security purposes, please feel free to check the contents of collect_env.py before running it. # For security purposes, please feel free to check the contents of collect_env.py before running it.
python collect_env.py python collect_env.py
``` ```
It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues.
value: | value: |
```text ```text
The output of `python collect_env.py` The output of `python collect_env.py`
@ -57,6 +58,10 @@ body:
If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com. If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com.
Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.
Please set the environment variable `export VLLM_LOGGING_LEVEL=DEBUG` to turn on more logging to help debugging potential issues.
If you experienced crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1` . All the function calls in vllm will be recorded. Inspect these log files, and tell which function crashes or hangs.
placeholder: | placeholder: |
A clear and concise description of what the bug is. A clear and concise description of what the bug is.

View File

@ -39,6 +39,7 @@ body:
# For security purposes, please feel free to check the contents of collect_env.py before running it. # For security purposes, please feel free to check the contents of collect_env.py before running it.
python collect_env.py python collect_env.py
``` ```
It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues.
value: | value: |
```text ```text
The output of `python collect_env.py` The output of `python collect_env.py`

49
.github/ISSUE_TEMPLATE/750-RFC.yml vendored Normal file
View File

@ -0,0 +1,49 @@
name: 💬 Request for comments (RFC).
description: Ask for feedback on major architectural changes or design choices.
title: "[RFC]: "
labels: ["RFC"]
body:
- type: markdown
attributes:
value: >
#### Please take a look at previous [RFCs](https://github.com/vllm-project/vllm/issues?q=label%3ARFC+sort%3Aupdated-desc) for reference.
- type: textarea
attributes:
label: Motivation.
description: >
The motivation of the RFC.
validations:
required: true
- type: textarea
attributes:
label: Proposed Change.
description: >
The proposed change of the RFC.
validations:
required: true
- type: textarea
attributes:
label: Feedback Period.
description: >
The feedback period of the RFC. Usually at least one week.
validations:
required: false
- type: textarea
attributes:
label: CC List.
description: >
The list of people you want to CC.
validations:
required: false
- type: textarea
attributes:
label: Any Other Things.
description: >
Any other things you would like to mention.
validations:
required: false
- type: markdown
attributes:
value: >
Thanks for contributing 🎉!

42
.github/workflows/clang-format.yml vendored Normal file
View File

@ -0,0 +1,42 @@
name: clang-format
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
pull_request:
branches:
- main
jobs:
clang-format:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.11"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install clang-format==18.1.5
- name: Running clang-format
run: |
EXCLUDES=(
'csrc/moe/topk_softmax_kernels.cu'
'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu'
'csrc/punica/bgmv/bgmv_config.h'
'csrc/punica/bgmv/bgmv_impl.cuh'
'csrc/punica/bgmv/vec_dtypes.cuh'
'csrc/punica/punica_ops.cu'
'csrc/punica/type_convert.h'
)
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
| grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
| xargs clang-format --dry-run --Werror

51
.github/workflows/mypy.yaml vendored Normal file
View File

@ -0,0 +1,51 @@
name: mypy
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
pull_request:
branches:
- main
jobs:
ruff:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install mypy==1.9.0
pip install types-setuptools
pip install types-PyYAML
pip install types-requests
pip install types-setuptools
- name: Mypy
run: |
mypy vllm/attention --config-file pyproject.toml
mypy vllm/core --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
mypy vllm/multimodal --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/*.py --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
mypy vllm/logging --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml

View File

@ -49,13 +49,19 @@ jobs:
matrix: matrix:
os: ['ubuntu-20.04'] os: ['ubuntu-20.04']
python-version: ['3.8', '3.9', '3.10', '3.11'] python-version: ['3.8', '3.9', '3.10', '3.11']
pytorch-version: ['2.1.2'] # Must be the most recent version that meets requirements.txt. pytorch-version: ['2.3.0'] # Must be the most recent version that meets requirements-cuda.txt.
cuda-version: ['11.8', '12.1'] cuda-version: ['11.8', '12.1']
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v3
- name: Setup ccache
uses: hendrikmuhs/ccache-action@v1.2
with:
create-symlink: true
key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }}
- name: Set up Linux Env - name: Set up Linux Env
if: ${{ runner.os == 'Linux' }} if: ${{ runner.os == 'Linux' }}
run: | run: |
@ -76,6 +82,8 @@ jobs:
- name: Build wheel - name: Build wheel
shell: bash shell: bash
env:
CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size
run: | run: |
bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }} bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
wheel_name=$(ls dist/*whl | xargs -n 1 basename) wheel_name=$(ls dist/*whl | xargs -n 1 basename)

View File

@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
python-version: ["3.10"] python-version: ["3.8", "3.9", "3.10", "3.11"]
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}

View File

@ -9,7 +9,7 @@ LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH
# Install requirements # Install requirements
$python_executable -m pip install wheel packaging $python_executable -m pip install wheel packaging
$python_executable -m pip install -r requirements.txt $python_executable -m pip install -r requirements-cuda.txt
# Limit the number of parallel jobs to avoid OOM # Limit the number of parallel jobs to avoid OOM
export MAX_JOBS=1 export MAX_JOBS=1

View File

@ -8,7 +8,7 @@ module.exports = async (github, context, core) => {
generate_release_notes: true, generate_release_notes: true,
name: process.env.RELEASE_TAG, name: process.env.RELEASE_TAG,
owner: context.repo.owner, owner: context.repo.owner,
prerelease: false, prerelease: true,
repo: context.repo.repo, repo: context.repo.repo,
tag_name: process.env.RELEASE_TAG, tag_name: process.env.RELEASE_TAG,
}); });

View File

@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
python-version: ["3.10"] python-version: ["3.8", "3.9", "3.10", "3.11"]
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}

3
.gitignore vendored
View File

@ -70,6 +70,8 @@ instance/
# Sphinx documentation # Sphinx documentation
docs/_build/ docs/_build/
docs/source/getting_started/examples/*.rst
!**/*.template.rst
# PyBuilder # PyBuilder
.pybuilder/ .pybuilder/
@ -181,6 +183,7 @@ _build/
# hip files generated by PyTorch # hip files generated by PyTorch
*.hip *.hip
*_hip* *_hip*
hip_compat.h
# Benchmark dataset # Benchmark dataset
*.json *.json

View File

@ -19,7 +19,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11")
set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0") set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")
# Supported AMD GPU architectures. # Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100") set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100")
# #
# Supported/expected torch versions for CUDA/ROCm. # Supported/expected torch versions for CUDA/ROCm.
@ -31,7 +31,7 @@ set(HIP_SUPPORTED_ARCHS "gfx908;gfx90a;gfx942;gfx1100")
# requirements.txt files and should be kept consistent. The ROCm torch # requirements.txt files and should be kept consistent. The ROCm torch
# versions are derived from Dockerfile.rocm # versions are derived from Dockerfile.rocm
# #
set(TORCH_SUPPORTED_VERSION_CUDA "2.1.2") set(TORCH_SUPPORTED_VERSION_CUDA "2.3.0")
set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1") set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1")
set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1") set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1")
@ -167,15 +167,47 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu" "csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu" "csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu" "csrc/moe_align_block_size_kernels.cu"
"csrc/pybind.cpp") "csrc/pybind.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA") if(VLLM_GPU_LANG STREQUAL "CUDA")
include(FetchContent)
SET(CUTLASS_ENABLE_HEADERS_ONLY=ON)
FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
# CUTLASS 3.5.0
GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc
)
FetchContent_MakeAvailable(cutlass)
list(APPEND VLLM_EXT_SRC list(APPEND VLLM_EXT_SRC
"csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/marlin/marlin_cuda_kernel.cu" "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
"csrc/custom_all_reduce.cu") "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/custom_all_reduce.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu")
#
# The CUTLASS kernels for Hopper require sm90a to be enabled.
# This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a.
# That adds an extra 17MB to compiled binary, so instead we selectively enable it.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
set_source_files_properties(
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu"
PROPERTIES
COMPILE_FLAGS
"-gencode arch=compute_90a,code=sm_90a")
endif()
endif() endif()
define_gpu_extension_target( define_gpu_extension_target(
@ -185,6 +217,7 @@ define_gpu_extension_target(
SOURCES ${VLLM_EXT_SRC} SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS} COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES} ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
WITH_SOABI) WITH_SOABI)
# #
@ -210,24 +243,13 @@ define_gpu_extension_target(
set(VLLM_PUNICA_EXT_SRC set(VLLM_PUNICA_EXT_SRC
"csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu" "csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu"
"csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu"
"csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu"
"csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu" "csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu"
"csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu"
"csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu" "csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu"
"csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu" "csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
"csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu" "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu" "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
"csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu" "csrc/punica/punica_ops.cu"
"csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu" "csrc/punica/punica_pybind.cpp")
"csrc/punica/punica_ops.cc")
# #
# Copy GPU compilation flags+update for punica # Copy GPU compilation flags+update for punica
@ -251,6 +273,9 @@ if (${VLLM_GPU_LANG} STREQUAL "CUDA")
endif() endif()
endforeach() endforeach()
message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}") message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
elseif(${VLLM_GPU_LANG} STREQUAL "HIP")
set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES})
message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
endif() endif()
if (VLLM_PUNICA_GPU_ARCHES) if (VLLM_PUNICA_GPU_ARCHES)
@ -285,9 +310,7 @@ add_custom_target(default)
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
message(STATUS "Enabling C extension.") message(STATUS "Enabling C extension.")
add_dependencies(default _C) add_dependencies(default _C)
endif()
if(VLLM_GPU_LANG STREQUAL "CUDA")
message(STATUS "Enabling moe extension.") message(STATUS "Enabling moe extension.")
add_dependencies(default _moe_C) add_dependencies(default _moe_C)

View File

@ -21,7 +21,6 @@ Express your support on Twitter if vLLM aids you, or simply offer your appreciat
### Build from source ### Build from source
```bash ```bash
pip install -r requirements.txt
pip install -e . # This may take several minutes. pip install -e . # This may take several minutes.
``` ```
@ -30,6 +29,8 @@ pip install -e . # This may take several minutes.
```bash ```bash
pip install -r requirements-dev.txt pip install -r requirements-dev.txt
# linting and formatting
bash format.sh
# Static type checking # Static type checking
mypy mypy
# Unit tests # Unit tests

View File

@ -1,8 +1,13 @@
# The vLLM Dockerfile is used to construct vLLM image that can be directly used # The vLLM Dockerfile is used to construct vLLM image that can be directly used
# to run the OpenAI compatible server. # to run the OpenAI compatible server.
# Please update any changes made here to
# docs/source/dev/dockerfile/dockerfile.rst and
# docs/source/assets/dev/dockerfile-stages-dependency.png
#################### BASE BUILD IMAGE #################### #################### BASE BUILD IMAGE ####################
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev # prepare basic build environment
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS dev
RUN apt-get update -y \ RUN apt-get update -y \
&& apt-get install -y python3-pip git && apt-get install -y python3-pip git
@ -11,23 +16,31 @@ RUN apt-get update -y \
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully # https://github.com/pytorch/pytorch/issues/107960 -- hopefully
# this won't be needed for future versions of this docker image # this won't be needed for future versions of this docker image
# or future versions of triton. # or future versions of triton.
RUN ldconfig /usr/local/cuda-12.1/compat/ RUN ldconfig /usr/local/cuda-12.4/compat/
WORKDIR /workspace WORKDIR /workspace
# install build and runtime dependencies # install build and runtime dependencies
COPY requirements.txt requirements.txt COPY requirements-common.txt requirements-common.txt
COPY requirements-cuda.txt requirements-cuda.txt
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements.txt pip install -r requirements-cuda.txt
# install development dependencies # install development dependencies
COPY requirements-dev.txt requirements-dev.txt COPY requirements-dev.txt requirements-dev.txt
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements-dev.txt pip install -r requirements-dev.txt
# cuda arch list used by torch
# can be useful for both `dev` and `test`
# explicitly set the list to avoid issues with torch 2.2
# see https://github.com/pytorch/pytorch/pull/123243
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
#################### BASE BUILD IMAGE #################### #################### BASE BUILD IMAGE ####################
#################### EXTENSION BUILD IMAGE #################### #################### WHEEL BUILD IMAGE ####################
FROM dev AS build FROM dev AS build
# install build dependencies # install build dependencies
@ -38,18 +51,16 @@ RUN --mount=type=cache,target=/root/.cache/pip \
# install compiler cache to speed up compilation leveraging local or remote caching # install compiler cache to speed up compilation leveraging local or remote caching
RUN apt-get update -y && apt-get install -y ccache RUN apt-get update -y && apt-get install -y ccache
# copy input files # files and directories related to build wheels
COPY csrc csrc COPY csrc csrc
COPY setup.py setup.py COPY setup.py setup.py
COPY cmake cmake COPY cmake cmake
COPY CMakeLists.txt CMakeLists.txt COPY CMakeLists.txt CMakeLists.txt
COPY requirements.txt requirements.txt COPY requirements-common.txt requirements-common.txt
COPY requirements-cuda.txt requirements-cuda.txt
COPY pyproject.toml pyproject.toml COPY pyproject.toml pyproject.toml
COPY vllm/__init__.py vllm/__init__.py COPY vllm vllm
# cuda arch list used by torch
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
# max jobs used by Ninja to build extensions # max jobs used by Ninja to build extensions
ARG max_jobs=2 ARG max_jobs=2
ENV MAX_JOBS=${max_jobs} ENV MAX_JOBS=${max_jobs}
@ -61,77 +72,64 @@ ENV VLLM_INSTALL_PUNICA_KERNELS=1
ENV CCACHE_DIR=/root/.cache/ccache ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \ RUN --mount=type=cache,target=/root/.cache/ccache \
python3 setup.py build_ext --inplace --mount=type=cache,target=/root/.cache/pip \
python3 setup.py bdist_wheel --dist-dir=dist
# check the size of the wheel, we cannot upload wheels larger than 100MB
COPY .buildkite/check-wheel-size.py check-wheel-size.py
RUN python3 check-wheel-size.py dist
#################### EXTENSION Build IMAGE #################### #################### EXTENSION Build IMAGE ####################
#################### FLASH_ATTENTION Build IMAGE #################### #################### vLLM installation IMAGE ####################
FROM dev as flash-attn-builder # image with vLLM installed
# max jobs used for build FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base
ARG max_jobs=2 WORKDIR /vllm-workspace
ENV MAX_JOBS=${max_jobs}
# flash attention version
ARG flash_attn_version=v2.5.6
ENV FLASH_ATTN_VERSION=${flash_attn_version}
WORKDIR /usr/src/flash-attention-v2 RUN apt-get update -y \
&& apt-get install -y python3-pip git vim
# Download the wheel or build it if a pre-compiled release doesn't exist # Workaround for https://github.com/openai/triton/issues/2507 and
RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \ # https://github.com/pytorch/pytorch/issues/107960 -- hopefully
--no-build-isolation --no-deps --no-cache-dir # this won't be needed for future versions of this docker image
# or future versions of triton.
RUN ldconfig /usr/local/cuda-12.4/compat/
# install vllm wheel first, so that torch etc will be installed
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
--mount=type=cache,target=/root/.cache/pip \
pip install dist/*.whl --verbose
#################### vLLM installation IMAGE ####################
#################### FLASH_ATTENTION Build IMAGE ####################
#################### TEST IMAGE #################### #################### TEST IMAGE ####################
# image to run unit testing suite # image to run unit testing suite
FROM dev AS test # note that this uses vllm installed by `pip`
FROM vllm-base AS test
# copy pytorch extensions separately to avoid having to rebuild
# when python code changes
WORKDIR /vllm-workspace
# ADD is used to preserve directory structure
ADD . /vllm-workspace/ ADD . /vllm-workspace/
COPY --from=build /workspace/vllm/*.so /vllm-workspace/vllm/
# Install flash attention (from pre-built wheel)
RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
# ignore build dependencies installation because we are using pre-complied extensions
RUN rm pyproject.toml
RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip install . --verbose
#################### TEST IMAGE ####################
# install development dependencies (for testing)
#################### RUNTIME BASE IMAGE ####################
# We used base cuda image because pytorch installs its own cuda libraries.
# However pynccl depends on cuda libraries so we had to switch to the runtime image
# In the future it would be nice to get a container with pytorch and cuda without duplicating cuda
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 AS vllm-base
# libnccl required for ray
RUN apt-get update -y \
&& apt-get install -y python3-pip
WORKDIR /workspace
COPY requirements.txt requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements.txt pip install -r requirements-dev.txt
# Install flash attention (from pre-built wheel) # doc requires source code
RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \ # we hide them inside `test_docs/` , so that this source code
pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir # will not be imported by other tests
RUN mkdir test_docs
#################### RUNTIME BASE IMAGE #################### RUN mv docs test_docs/
RUN mv vllm test_docs/
#################### TEST IMAGE ####################
#################### OPENAI API SERVER #################### #################### OPENAI API SERVER ####################
# openai api server alternative # openai api server alternative
FROM vllm-base AS vllm-openai FROM vllm-base AS vllm-openai
# install additional dependencies for openai api server # install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
pip install accelerate hf_transfer modelscope pip install accelerate hf_transfer modelscope
COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY vllm vllm
ENV VLLM_USAGE_SOURCE production-docker-image ENV VLLM_USAGE_SOURCE production-docker-image
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]

View File

@ -1,6 +1,6 @@
# This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform. # This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform.
FROM ubuntu:22.04 FROM ubuntu:22.04 AS cpu-test-1
RUN apt-get update -y \ RUN apt-get update -y \
&& apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \ && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \
@ -9,6 +9,8 @@ RUN apt-get update -y \
RUN pip install --upgrade pip \ RUN pip install --upgrade pip \
&& pip install wheel packaging ninja setuptools>=49.4.0 numpy && pip install wheel packaging ninja setuptools>=49.4.0 numpy
FROM cpu-test-1 AS build
COPY ./ /workspace/vllm COPY ./ /workspace/vllm
WORKDIR /workspace/vllm WORKDIR /workspace/vllm
@ -17,4 +19,8 @@ RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.py
RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install
WORKDIR /workspace/
RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
CMD ["/bin/bash"] CMD ["/bin/bash"]

36
Dockerfile.neuron Normal file
View File

@ -0,0 +1,36 @@
# default base image
ARG BASE_IMAGE="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference-neuronx:2.1.1-neuronx-py310-sdk2.17.0-ubuntu20.04"
FROM $BASE_IMAGE
RUN echo "Base image is $BASE_IMAGE"
# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y
### Mount Point ###
# When launching the container, mount the code directory to /app
ARG APP_MOUNT=/app
VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT}
RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
RUN python3 -m pip install sentencepiece transformers==4.36.2 -U
RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
RUN python3 -m pip install --pre neuronx-cc==2.12.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
COPY ./vllm /app/vllm/vllm
COPY ./setup.py /app/vllm/setup.py
COPY ./requirements-common.txt /app/vllm/requirements-common.txt
COPY ./requirements-neuron.txt /app/vllm/requirements-neuron.txt
RUN cd /app/vllm \
&& python3 -m pip install -U -r requirements-neuron.txt
ENV VLLM_BUILD_WITH_NEURON 1
RUN cd /app/vllm \
&& pip install -e . \
&& cd ..
CMD ["/bin/bash"]

View File

@ -14,7 +14,7 @@ RUN echo "Base image is $BASE_IMAGE"
ARG FA_GFX_ARCHS="gfx90a;gfx942" ARG FA_GFX_ARCHS="gfx90a;gfx942"
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS" RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
ARG FA_BRANCH="3d2b6f5" ARG FA_BRANCH="ae7928c"
RUN echo "FA_BRANCH is $FA_BRANCH" RUN echo "FA_BRANCH is $FA_BRANCH"
# whether to build flash-attention # whether to build flash-attention
@ -23,6 +23,9 @@ RUN echo "FA_BRANCH is $FA_BRANCH"
# In that case, we need to use the python reference attention implementation in vllm # In that case, we need to use the python reference attention implementation in vllm
ARG BUILD_FA="1" ARG BUILD_FA="1"
# whether to build triton on rocm
ARG BUILD_TRITON="1"
# Install some basic utilities # Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y RUN apt-get update && apt-get install python3 python3-pip -y
@ -43,7 +46,7 @@ RUN apt-get update && apt-get install -y \
### Mount Point ### ### Mount Point ###
# When launching the container, mount the code directory to /app # When launching the container, mount the code directory to /app
ARG APP_MOUNT=/app ARG APP_MOUNT=/vllm-workspace
VOLUME [ ${APP_MOUNT} ] VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT} WORKDIR ${APP_MOUNT}
@ -75,21 +78,38 @@ RUN if [ "$BUILD_FA" = "1" ]; then \
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \ RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
COPY ./ /app/vllm # build triton
RUN if [ "$BUILD_TRITON" = "1" ]; then \
mkdir -p libs \
&& cd libs \
&& pip uninstall -y triton \
&& git clone https://github.com/ROCm/triton.git \
&& cd triton/python \
&& pip3 install . \
&& cd ../..; \
fi
RUN python3 -m pip install --upgrade pip WORKDIR /vllm-workspace
RUN python3 -m pip install xformers==0.0.23 --no-deps COPY . .
RUN cd /app \ #RUN python3 -m pip install pynvml # to be removed eventually
&& cd vllm \ RUN python3 -m pip install --upgrade pip numba
&& pip install -U -r requirements-rocm.txt \
&& if [ "$BUILD_FA" = "1" ]; then \ # make sure punica kernels are built (for LoRA)
bash patch_xformers.rocm.sh; fi \ ENV VLLM_INSTALL_PUNICA_KERNELS=1
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \ # Workaround for ray >= 2.10.0
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
ENV VLLM_NCCL_SO_PATH=/opt/rocm/lib/librccl.so
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -U -r requirements-rocm.txt \
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \
&& python3 setup.py install \ && python3 setup.py install \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.cpython-39-x86_64-linux-gnu.so vllm/ \
&& cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.cpython-39-x86_64-linux-gnu.so vllm/ \
&& cd .. && cd ..
RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3
CMD ["/bin/bash"] CMD ["/bin/bash"]

View File

@ -1,5 +1,9 @@
include LICENSE include LICENSE
include requirements.txt include requirements-common.txt
include requirements-cuda.txt
include requirements-rocm.txt
include requirements-neuron.txt
include requirements-cpu.txt
include CMakeLists.txt include CMakeLists.txt
recursive-include cmake * recursive-include cmake *

View File

@ -16,16 +16,17 @@ Easy, fast, and cheap LLM serving for everyone
--- ---
**The Third vLLM Bay Area Meetup (April 2nd 6pm-8:30pm PT)** **The Fourth vLLM Bay Area Meetup (June 11th 5:30pm-8pm PT)**
We are thrilled to announce our third vLLM Meetup! We are thrilled to announce our fourth vLLM Meetup!
The vLLM team will share recent updates and roadmap. The vLLM team will share recent updates and roadmap.
We will also have vLLM collaborators from Roblox coming up to the stage to discuss their experience in deploying LLMs with vLLM. We will also have vLLM collaborators from BentoML and Cloudflare coming up to the stage to discuss their experience in deploying LLMs with vLLM.
Please register [here](https://robloxandvllmmeetup2024.splashthat.com/) and join us! Please register [here](https://lu.ma/agivllm) and join us!
--- ---
*Latest News* 🔥 *Latest News* 🔥
- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing). - [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
- [2024/01] Added ROCm 6.0 support to vLLM. - [2024/01] Added ROCm 6.0 support to vLLM.
- [2023/12] Added ROCm 5.7 support to vLLM. - [2023/12] Added ROCm 5.7 support to vLLM.
@ -61,39 +62,14 @@ vLLM is flexible and easy to use with:
- (Experimental) Prefix caching support - (Experimental) Prefix caching support
- (Experimental) Multi-lora support - (Experimental) Multi-lora support
vLLM seamlessly supports many Hugging Face models, including the following architectures: vLLM seamlessly supports most popular open-source models on HuggingFace, including:
- Transformer-like LLMs (e.g., Llama)
- Mixture-of-Expert LLMs (e.g., Mixtral)
- Multi-modal LLMs (e.g., LLaVA)
- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.) Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html).
- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.)
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.) ## Getting Started
- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
- Command-R (`CohereForAI/c4ai-command-r-v01`, etc.)
- DBRX (`databricks/dbrx-base`, `databricks/dbrx-instruct` etc.)
- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.)
- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
- Gemma (`google/gemma-2b`, `google/gemma-7b`, etc.)
- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.)
- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.)
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.)
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.)
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
- Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.)
- Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.)
- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
- Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.)
- Xverse (`xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.)
- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)
Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source): Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
@ -101,9 +77,7 @@ Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/get
pip install vllm pip install vllm
``` ```
## Getting Started Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to learn more.
Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started.
- [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html) - [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html)
- [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html) - [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html)
- [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html) - [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html)
@ -113,6 +87,32 @@ Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started
We welcome and value any contributions and collaborations. We welcome and value any contributions and collaborations.
Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved. Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.
## Sponsors
vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support!
<!-- Note: Please sort them in alphabetical order. -->
<!-- Note: Please keep these consistent with docs/source/community/sponsors.md -->
- a16z
- AMD
- Anyscale
- AWS
- Crusoe Cloud
- Databricks
- DeepInfra
- Dropbox
- Lambda Lab
- NVIDIA
- Replicate
- Roblox
- RunPod
- Trainy
- UC Berkeley
- UC San Diego
We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM.
## Citation ## Citation
If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180): If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):

View File

@ -27,8 +27,8 @@ class RequestFuncInput:
class RequestFuncOutput: class RequestFuncOutput:
generated_text: str = "" generated_text: str = ""
success: bool = False success: bool = False
latency: float = 0 latency: float = 0.0
ttft: float = 0 # Time to first token ttft: float = 0.0 # Time to first token
itl: List[float] = field( itl: List[float] = field(
default_factory=list) # List of inter-token latencies default_factory=list) # List of inter-token latencies
prompt_len: int = 0 prompt_len: int = 0
@ -58,23 +58,24 @@ async def async_request_tgi(
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
ttft = 0 ttft = 0.0
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload) as response: async with session.post(url=api_url, json=payload) as response:
if response.status == 200: if response.status == 200:
async for chunk in response.content: async for chunk_bytes in response.content:
chunk = chunk.strip() chunk_bytes = chunk_bytes.strip()
if not chunk: if not chunk_bytes:
continue continue
chunk = remove_prefix(chunk.decode("utf-8"), "data:") chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data:")
data = json.loads(chunk) data = json.loads(chunk)
timestamp = time.perf_counter() timestamp = time.perf_counter()
# First token # First token
if ttft == 0: if ttft == 0.0:
ttft = time.perf_counter() - st ttft = time.perf_counter() - st
output.ttft = ttft output.ttft = ttft
@ -88,6 +89,9 @@ async def async_request_tgi(
output.latency = most_recent_timestamp - st output.latency = most_recent_timestamp - st
output.success = True output.success = True
output.generated_text = data["generated_text"] output.generated_text = data["generated_text"]
else:
output.error = response.reason or ""
output.success = False
except Exception: except Exception:
output.success = False output.success = False
exc_info = sys.exc_info() exc_info = sys.exc_info()
@ -119,23 +123,25 @@ async def async_request_trt_llm(
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
ttft = 0 ttft = 0.0
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload) as response: async with session.post(url=api_url, json=payload) as response:
if response.status == 200: if response.status == 200:
async for chunk in response.content: async for chunk_bytes in response.content:
chunk = chunk.strip() chunk_bytes = chunk_bytes.strip()
if not chunk: if not chunk_bytes:
continue continue
chunk = remove_prefix(chunk.decode("utf-8"), "data:") chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data:")
data = json.loads(chunk) data = json.loads(chunk)
output.generated_text += data["text_output"]
timestamp = time.perf_counter() timestamp = time.perf_counter()
# First token # First token
if ttft == 0: if ttft == 0.0:
ttft = time.perf_counter() - st ttft = time.perf_counter() - st
output.ttft = ttft output.ttft = ttft
@ -147,11 +153,10 @@ async def async_request_trt_llm(
most_recent_timestamp = timestamp most_recent_timestamp = timestamp
output.latency = most_recent_timestamp - st output.latency = most_recent_timestamp - st
output.generated_text = json.loads(data)["text_output"]
output.success = True output.success = True
else: else:
output.error = response.reason output.error = response.reason or ""
output.success = False output.success = False
except Exception: except Exception:
output.success = False output.success = False
@ -195,7 +200,7 @@ async def async_request_deepspeed_mii(
output.generated_text = parsed_resp["text"][0] output.generated_text = parsed_resp["text"][0]
output.success = True output.success = True
else: else:
output.error = response.reason output.error = response.reason or ""
output.success = False output.success = False
except Exception: except Exception:
output.success = False output.success = False
@ -234,19 +239,20 @@ async def async_request_openai_completions(
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
generated_text = "" generated_text = ""
ttft = 0 ttft = 0.0
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload, async with session.post(url=api_url, json=payload,
headers=headers) as response: headers=headers) as response:
if response.status == 200: if response.status == 200:
async for chunk in response.content: async for chunk_bytes in response.content:
chunk = chunk.strip() chunk_bytes = chunk_bytes.strip()
if not chunk: if not chunk_bytes:
continue continue
chunk = remove_prefix(chunk.decode("utf-8"), "data: ") chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data: ")
if chunk == "[DONE]": if chunk == "[DONE]":
latency = time.perf_counter() - st latency = time.perf_counter() - st
else: else:
@ -255,7 +261,7 @@ async def async_request_openai_completions(
if data["choices"][0]["text"]: if data["choices"][0]["text"]:
timestamp = time.perf_counter() timestamp = time.perf_counter()
# First token # First token
if ttft == 0: if ttft == 0.0:
ttft = time.perf_counter() - st ttft = time.perf_counter() - st
output.ttft = ttft output.ttft = ttft
@ -273,6 +279,9 @@ async def async_request_openai_completions(
output.generated_text = generated_text output.generated_text = generated_text
output.success = True output.success = True
output.latency = latency output.latency = latency
else:
output.error = response.reason or ""
output.success = False
except Exception: except Exception:
output.success = False output.success = False
exc_info = sys.exc_info() exc_info = sys.exc_info()
@ -315,19 +324,20 @@ async def async_request_openai_chat_completions(
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
generated_text = "" generated_text = ""
ttft = 0 ttft = 0.0
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload, async with session.post(url=api_url, json=payload,
headers=headers) as response: headers=headers) as response:
if response.status == 200: if response.status == 200:
async for chunk in response.content: async for chunk_bytes in response.content:
chunk = chunk.strip() chunk_bytes = chunk_bytes.strip()
if not chunk: if not chunk_bytes:
continue continue
chunk = remove_prefix(chunk.decode("utf-8"), "data: ") chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data: ")
if chunk == "[DONE]": if chunk == "[DONE]":
latency = time.perf_counter() - st latency = time.perf_counter() - st
else: else:
@ -337,7 +347,7 @@ async def async_request_openai_chat_completions(
delta = data["choices"][0]["delta"] delta = data["choices"][0]["delta"]
if delta.get("content", None): if delta.get("content", None):
# First token # First token
if ttft == 0: if ttft == 0.0:
ttft = time.perf_counter() - st ttft = time.perf_counter() - st
output.ttft = ttft output.ttft = ttft
@ -354,7 +364,7 @@ async def async_request_openai_chat_completions(
output.success = True output.success = True
output.latency = latency output.latency = latency
else: else:
output.error = response.reason output.error = response.reason or ""
output.success = False output.success = False
except Exception: except Exception:
output.success = False output.success = False

View File

@ -1,14 +1,17 @@
"""Benchmark the latency of processing a single batch of requests.""" """Benchmark the latency of processing a single batch of requests."""
import argparse import argparse
import json
import time import time
from pathlib import Path from pathlib import Path
from typing import Optional from typing import List, Optional
import numpy as np import numpy as np
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.inputs import PromptStrictInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
@ -17,6 +20,8 @@ def main(args: argparse.Namespace):
# NOTE(woosuk): If the request cannot be processed in a single batch, # NOTE(woosuk): If the request cannot be processed in a single batch,
# the engine will automatically process the request in multiple batches. # the engine will automatically process the request in multiple batches.
llm = LLM(model=args.model, llm = LLM(model=args.model,
speculative_model=args.speculative_model,
num_speculative_tokens=args.num_speculative_tokens,
tokenizer=args.tokenizer, tokenizer=args.tokenizer,
quantization=args.quantization, quantization=args.quantization,
tensor_parallel_size=args.tensor_parallel_size, tensor_parallel_size=args.tensor_parallel_size,
@ -24,11 +29,14 @@ def main(args: argparse.Namespace):
dtype=args.dtype, dtype=args.dtype,
enforce_eager=args.enforce_eager, enforce_eager=args.enforce_eager,
kv_cache_dtype=args.kv_cache_dtype, kv_cache_dtype=args.kv_cache_dtype,
quantization_param_path=args.quantization_param_path,
device=args.device, device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight, ray_workers_use_nsight=args.ray_workers_use_nsight,
use_v2_block_manager=args.use_v2_block_manager,
enable_chunked_prefill=args.enable_chunked_prefill, enable_chunked_prefill=args.enable_chunked_prefill,
download_dir=args.download_dir, download_dir=args.download_dir,
block_size=args.block_size) block_size=args.block_size,
gpu_memory_utilization=args.gpu_memory_utilization)
sampling_params = SamplingParams( sampling_params = SamplingParams(
n=args.n, n=args.n,
@ -42,7 +50,9 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids = np.random.randint(10000, dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size, size=(args.batch_size,
args.input_len)) args.input_len))
dummy_prompt_token_ids = dummy_prompt_token_ids.tolist() dummy_inputs: List[PromptStrictInputs] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]
def run_to_completion(profile_dir: Optional[str] = None): def run_to_completion(profile_dir: Optional[str] = None):
if profile_dir: if profile_dir:
@ -53,13 +63,13 @@ def main(args: argparse.Namespace):
], ],
on_trace_ready=torch.profiler.tensorboard_trace_handler( on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir))) as p: str(profile_dir))) as p:
llm.generate(prompt_token_ids=dummy_prompt_token_ids, llm.generate(dummy_inputs,
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=False) use_tqdm=False)
print(p.key_averages()) print(p.key_averages())
else: else:
start_time = time.perf_counter() start_time = time.perf_counter()
llm.generate(prompt_token_ids=dummy_prompt_token_ids, llm.generate(dummy_inputs,
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=False) use_tqdm=False)
end_time = time.perf_counter() end_time = time.perf_counter()
@ -67,7 +77,8 @@ def main(args: argparse.Namespace):
return latency return latency
print("Warming up...") print("Warming up...")
run_to_completion(profile_dir=None) for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
run_to_completion(profile_dir=None)
if args.profile: if args.profile:
profile_dir = args.profile_result_dir profile_dir = args.profile_result_dir
@ -83,7 +94,22 @@ def main(args: argparse.Namespace):
latencies = [] latencies = []
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"): for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
latencies.append(run_to_completion(profile_dir=None)) latencies.append(run_to_completion(profile_dir=None))
latencies = np.array(latencies)
percentages = [10, 25, 50, 75, 90]
percentiles = np.percentile(latencies, percentages)
print(f'Avg latency: {np.mean(latencies)} seconds') print(f'Avg latency: {np.mean(latencies)} seconds')
for percentage, percentile in zip(percentages, percentiles):
print(f'{percentage}% percentile latency: {percentile} seconds')
# Output JSON results if specified
if args.output_json:
results = {
"avg_latency": np.mean(latencies),
"latencies": latencies.tolist(),
"percentiles": dict(zip(percentages, percentiles.tolist())),
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
if __name__ == '__main__': if __name__ == '__main__':
@ -91,10 +117,12 @@ if __name__ == '__main__':
description='Benchmark the latency of processing a single batch of ' description='Benchmark the latency of processing a single batch of '
'requests till completion.') 'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m') parser.add_argument('--model', type=str, default='facebook/opt-125m')
parser.add_argument('--speculative-model', type=str, default=None)
parser.add_argument('--num-speculative-tokens', type=int, default=None)
parser.add_argument('--tokenizer', type=str, default=None) parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization', parser.add_argument('--quantization',
'-q', '-q',
choices=['awq', 'gptq', 'squeezellm', None], choices=[*QUANTIZATION_METHODS, None],
default=None) default=None)
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--input-len', type=int, default=32) parser.add_argument('--input-len', type=int, default=32)
@ -105,9 +133,13 @@ if __name__ == '__main__':
default=1, default=1,
help='Number of generated sequences per prompt.') help='Number of generated sequences per prompt.')
parser.add_argument('--use-beam-search', action='store_true') parser.add_argument('--use-beam-search', action='store_true')
parser.add_argument('--num-iters-warmup',
type=int,
default=10,
help='Number of iterations to run for warmup.')
parser.add_argument('--num-iters', parser.add_argument('--num-iters',
type=int, type=int,
default=3, default=30,
help='Number of iterations to run.') help='Number of iterations to run.')
parser.add_argument('--trust-remote-code', parser.add_argument('--trust-remote-code',
action='store_true', action='store_true',
@ -125,12 +157,23 @@ if __name__ == '__main__':
action='store_true', action='store_true',
help='enforce eager mode and disable CUDA graph') help='enforce eager mode and disable CUDA graph')
parser.add_argument( parser.add_argument(
"--kv-cache-dtype", '--kv-cache-dtype',
type=str, type=str,
choices=['auto', 'fp8_e5m2'], choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default='auto', default="auto",
help= help='Data type for kv cache storage. If "auto", will use model '
'Data type for kv cache storage. If "auto", will use model data type.') 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=str,
default=None,
help='Path to the JSON file containing the KV cache scaling factors. '
'This should generally be supplied, when KV cache dtype is FP8. '
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.')
parser.add_argument( parser.add_argument(
'--profile', '--profile',
action='store_true', action='store_true',
@ -145,18 +188,18 @@ if __name__ == '__main__':
"--device", "--device",
type=str, type=str,
default="cuda", default="cuda",
choices=["cuda"], choices=["cuda", "cpu"],
help='device type for vLLM execution, supporting CUDA only currently.') help='device type for vLLM execution, supporting CUDA and CPU.')
parser.add_argument('--block-size', parser.add_argument('--block-size',
type=int, type=int,
default=16, default=16,
help='block size of key/value cache') help='block size of key/value cache')
parser.add_argument( parser.add_argument(
'--enable-chunked-prefill', '--enable-chunked-prefill',
type=bool, action='store_true',
default=False,
help='If True, the prefill requests can be chunked based on the ' help='If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens') 'max_num_batched_tokens')
parser.add_argument('--use-v2-block-manager', action='store_true')
parser.add_argument( parser.add_argument(
"--ray-workers-use-nsight", "--ray-workers-use-nsight",
action='store_true', action='store_true',
@ -167,5 +210,16 @@ if __name__ == '__main__':
default=None, default=None,
help='directory to download and load the weights, ' help='directory to download and load the weights, '
'default to the default cache dir of huggingface') 'default to the default cache dir of huggingface')
parser.add_argument(
'--output-json',
type=str,
default=None,
help='Path to save the latency results in JSON format.')
parser.add_argument('--gpu-memory-utilization',
type=float,
default=0.9,
help='the fraction of GPU memory to be used for '
'the model executor, which can range from 0 to 1.'
'If unspecified, will use the default value of 0.9.')
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@ -16,20 +16,22 @@ def test_prefix(llm=None, sampling_params=None, prompts=None):
def main(args): def main(args):
llm = LLM(model="baichuan-inc/Baichuan2-13B-Chat", llm = LLM(model=args.model,
tokenizer_mode='auto', tokenizer_mode='auto',
trust_remote_code=True, trust_remote_code=True,
enforce_eager=True, enforce_eager=True,
use_v2_block_manager=args.use_v2_block_manager,
tensor_parallel_size=args.tensor_parallel_size,
enable_prefix_caching=args.enable_prefix_caching) enable_prefix_caching=args.enable_prefix_caching)
num_prompts = 100 num_prompts = 100
prompts = [PROMPT] * num_prompts prompts = [PROMPT] * num_prompts
sampling_params = SamplingParams(temperature=0, max_tokens=100) sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
print("------warm up------") print("------warm up------")
test_prefix( test_prefix(
llm=llm, llm=llm,
prompts=prompts[:1], prompts=prompts,
sampling_params=sampling_params, sampling_params=sampling_params,
) )
@ -45,8 +47,16 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Benchmark the performance with or without automatic ' description='Benchmark the performance with or without automatic '
'prefix caching.') 'prefix caching.')
parser.add_argument('--model',
type=str,
default='baichuan-inc/Baichuan2-13B-Chat')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--output-len', type=int, default=10)
parser.add_argument('--enable-prefix-caching', parser.add_argument('--enable-prefix-caching',
action='store_true', action='store_true',
help='enable prefix caching') help='enable prefix caching')
parser.add_argument('--use-v2-block-manager',
action='store_true',
help='Use BlockSpaceMangerV2')
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)

View File

@ -17,6 +17,10 @@ On the client side, run:
--dataset-path <path to dataset> \ --dataset-path <path to dataset> \
--request-rate <request_rate> \ # By default <request_rate> is inf --request-rate <request_rate> \ # By default <request_rate> is inf
--num-prompts <num_prompts> # By default <num_prompts> is 1000 --num-prompts <num_prompts> # By default <num_prompts> is 1000
when using tgi backend, add
--endpoint /generate_stream
to the end of the command above.
""" """
import argparse import argparse
import asyncio import asyncio
@ -27,7 +31,7 @@ import time
import warnings import warnings
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import AsyncGenerator, List, Tuple from typing import AsyncGenerator, List, Optional, Tuple
import numpy as np import numpy as np
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
@ -58,7 +62,11 @@ def sample_sharegpt_requests(
dataset_path: str, dataset_path: str,
num_requests: int, num_requests: int,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, int, int]]: ) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
# Load the dataset. # Load the dataset.
with open(dataset_path) as f: with open(dataset_path) as f:
dataset = json.load(f) dataset = json.load(f)
@ -68,38 +76,32 @@ def sample_sharegpt_requests(
dataset = [(data["conversations"][0]["value"], dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset] data["conversations"][1]["value"]) for data in dataset]
# some of these will be filtered out, so sample more than we need # Shuffle the dataset.
sampled_indices = random.sample(range(len(dataset)), random.shuffle(dataset)
int(num_requests * 1.2))
dataset = [dataset[i] for i in sampled_indices]
# Tokenize the prompts and completions. # Filter out sequences that are too long or too short
prompts = [prompt for prompt, _ in dataset]
prompt_token_ids = tokenizer(prompts).input_ids
completions = [completion for _, completion in dataset]
completion_token_ids = tokenizer(completions).input_ids
tokenized_dataset = []
for i in range(len(dataset)):
output_len = len(completion_token_ids[i])
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
# Filter out too long sequences.
filtered_dataset: List[Tuple[str, int, int]] = [] filtered_dataset: List[Tuple[str, int, int]] = []
for prompt, prompt_token_ids, output_len in tokenized_dataset: for i in range(len(dataset)):
if len(filtered_dataset) == num_requests:
break
# Tokenize the prompts and completions.
prompt = dataset[i][0]
prompt_token_ids = tokenizer(prompt).input_ids
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids) prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
) if fixed_output_len is None else fixed_output_len
if prompt_len < 4 or output_len < 4: if prompt_len < 4 or output_len < 4:
# Prune too short sequences. # Prune too short sequences.
# This is because TGI causes errors when the input or output length
# is too short.
continue continue
if prompt_len > 1024 or prompt_len + output_len > 2048: if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences. # Prune too long sequences.
continue continue
filtered_dataset.append((prompt, prompt_len, output_len)) filtered_dataset.append((prompt, prompt_len, output_len))
# Sample the requests. return filtered_dataset
sampled_requests = random.sample(filtered_dataset, num_requests)
return sampled_requests
def sample_sonnet_requests( def sample_sonnet_requests(
@ -110,7 +112,9 @@ def sample_sonnet_requests(
prefix_len: int, prefix_len: int,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, str, int, int]]: ) -> List[Tuple[str, str, int, int]]:
assert input_len > prefix_len, "input_len must be greater than prefix_len." assert (
input_len > prefix_len
), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'."
# Load the dataset. # Load the dataset.
with open(dataset_path) as f: with open(dataset_path) as f:
@ -131,8 +135,9 @@ def sample_sonnet_requests(
base_message, add_generation_prompt=True, tokenize=False) base_message, add_generation_prompt=True, tokenize=False)
base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids) base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids)
assert (input_len > base_prompt_offset assert (
), f"Please set 'args.input-len' higher than {base_prompt_offset}." input_len > base_prompt_offset
), f"Please set 'args.sonnet-input-len' higher than {base_prompt_offset}."
num_input_lines = round( num_input_lines = round(
(input_len - base_prompt_offset) / average_poem_len) (input_len - base_prompt_offset) / average_poem_len)
@ -140,7 +145,7 @@ def sample_sonnet_requests(
# prompt are fixed poem lines. # prompt are fixed poem lines.
assert ( assert (
prefix_len > base_prompt_offset prefix_len > base_prompt_offset
), f"Please set 'args.prefix-len' higher than {base_prompt_offset}." ), f"Please set 'args.sonnet-prefix-len' higher than {base_prompt_offset}."
num_prefix_lines = round( num_prefix_lines = round(
(prefix_len - base_prompt_offset) / average_poem_len) (prefix_len - base_prompt_offset) / average_poem_len)
@ -210,6 +215,11 @@ def calculate_metrics(
else: else:
actual_output_lens.append(0) actual_output_lens.append(0)
if completed == 0:
warnings.warn(
"All requests failed. This is likely due to a misconfiguration "
"on the benchmark arguments.",
stacklevel=2)
metrics = BenchmarkMetrics( metrics = BenchmarkMetrics(
completed=completed, completed=completed,
total_input=total_input, total_input=total_input,
@ -221,9 +231,9 @@ def calculate_metrics(
1000, # ttfts is empty if streaming is not supported by backend 1000, # ttfts is empty if streaming is not supported by backend
median_ttft_ms=np.median(ttfts or 0) * 1000, median_ttft_ms=np.median(ttfts or 0) * 1000,
p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
mean_tpot_ms=np.mean(tpots) * 1000, mean_tpot_ms=np.mean(tpots or 0) * 1000,
median_tpot_ms=np.median(tpots) * 1000, median_tpot_ms=np.median(tpots or 0) * 1000,
p99_tpot_ms=np.percentile(tpots, 99) * 1000, p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
) )
return metrics, actual_output_lens return metrics, actual_output_lens
@ -245,6 +255,24 @@ async def benchmark(
else: else:
raise ValueError(f"Unknown backend: {backend}") raise ValueError(f"Unknown backend: {backend}")
print("Starting initial single prompt test run...")
test_prompt, test_prompt_len, test_output_len = input_requests[0]
test_input = RequestFuncInput(
model=model_id,
prompt=test_prompt,
api_url=api_url,
prompt_len=test_prompt_len,
output_len=test_output_len,
best_of=best_of,
use_beam_search=use_beam_search,
)
test_output = await request_func(request_func_input=test_input)
if not test_output.success:
raise ValueError(
"Initial test run failed - Please make sure benchmark arguments "
f"are correctly specified. Error: {test_output.error}")
else:
print("Initial test run completed. Starting main benchmark run...")
print(f"Traffic request rate: {request_rate}") print(f"Traffic request rate: {request_rate}")
pbar = None if disable_tqdm else tqdm(total=len(input_requests)) pbar = None if disable_tqdm else tqdm(total=len(input_requests))
@ -358,6 +386,7 @@ def main(args: argparse.Namespace):
dataset_path=args.dataset, dataset_path=args.dataset,
num_requests=args.num_prompts, num_requests=args.num_prompts,
tokenizer=tokenizer, tokenizer=tokenizer,
fixed_output_len=args.sharegpt_output_len,
) )
elif args.dataset_name == "sharegpt": elif args.dataset_name == "sharegpt":
@ -365,6 +394,7 @@ def main(args: argparse.Namespace):
dataset_path=args.dataset_path, dataset_path=args.dataset_path,
num_requests=args.num_prompts, num_requests=args.num_prompts,
tokenizer=tokenizer, tokenizer=tokenizer,
fixed_output_len=args.sharegpt_output_len,
) )
elif args.dataset_name == "sonnet": elif args.dataset_name == "sonnet":
@ -373,9 +403,9 @@ def main(args: argparse.Namespace):
input_requests = sample_sonnet_requests( input_requests = sample_sonnet_requests(
dataset_path=args.dataset_path, dataset_path=args.dataset_path,
num_requests=args.num_prompts, num_requests=args.num_prompts,
input_len=args.input_len, input_len=args.sonnet_input_len,
output_len=args.output_len, output_len=args.sonnet_output_len,
prefix_len=args.prefix_len, prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer, tokenizer=tokenizer,
) )
input_requests = [(prompt, prompt_len, output_len) input_requests = [(prompt, prompt_len, output_len)
@ -388,9 +418,9 @@ def main(args: argparse.Namespace):
input_requests = sample_sonnet_requests( input_requests = sample_sonnet_requests(
dataset_path=args.dataset_path, dataset_path=args.dataset_path,
num_requests=args.num_prompts, num_requests=args.num_prompts,
input_len=args.input_len, input_len=args.sonnet_input_len,
output_len=args.output_len, output_len=args.sonnet_output_len,
prefix_len=args.prefix_len, prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer, tokenizer=tokenizer,
) )
input_requests = [(prompt_formatted, prompt_len, output_len) input_requests = [(prompt_formatted, prompt_len, output_len)
@ -521,6 +551,12 @@ if __name__ == "__main__":
default=1000, default=1000,
help="Number of prompts to process.", help="Number of prompts to process.",
) )
parser.add_argument(
"--sharegpt-output-len",
type=int,
default=None,
help="Output length for each request. Overrides the output length "
"from the ShareGPT dataset.")
parser.add_argument( parser.add_argument(
"--sonnet-input-len", "--sonnet-input-len",
type=int, type=int,

View File

@ -10,6 +10,8 @@ from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer, from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase) PreTrainedTokenizerBase)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
def sample_requests( def sample_requests(
dataset_path: str, dataset_path: str,
@ -29,22 +31,23 @@ def sample_requests(
dataset = [(data["conversations"][0]["value"], dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset] data["conversations"][1]["value"]) for data in dataset]
# Tokenize the prompts and completions. # Shuffle the dataset.
prompts = [prompt for prompt, _ in dataset] random.shuffle(dataset)
prompt_token_ids = tokenizer(prompts).input_ids
completions = [completion for _, completion in dataset]
completion_token_ids = tokenizer(completions).input_ids
tokenized_dataset = []
for i in range(len(dataset)):
output_len = len(completion_token_ids[i])
if fixed_output_len is not None:
output_len = fixed_output_len
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
# Filter out too long sequences. # Filter out sequences that are too long or too short
filtered_dataset: List[Tuple[str, int, int]] = [] filtered_dataset: List[Tuple[str, int, int]] = []
for prompt, prompt_token_ids, output_len in tokenized_dataset: for i in range(len(dataset)):
if len(filtered_dataset) == num_requests:
break
# Tokenize the prompts and completions.
prompt = dataset[i][0]
prompt_token_ids = tokenizer(prompt).input_ids
completion = dataset[i][1]
completion_token_ids = tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids) prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
) if fixed_output_len is None else fixed_output_len
if prompt_len < 4 or output_len < 4: if prompt_len < 4 or output_len < 4:
# Prune too short sequences. # Prune too short sequences.
continue continue
@ -53,9 +56,7 @@ def sample_requests(
continue continue
filtered_dataset.append((prompt, prompt_len, output_len)) filtered_dataset.append((prompt, prompt_len, output_len))
# Sample the requests. return filtered_dataset
sampled_requests = random.sample(filtered_dataset, num_requests)
return sampled_requests
def run_vllm( def run_vllm(
@ -72,47 +73,52 @@ def run_vllm(
max_model_len: Optional[int], max_model_len: Optional[int],
enforce_eager: bool, enforce_eager: bool,
kv_cache_dtype: str, kv_cache_dtype: str,
quantization_param_path: Optional[str],
device: str, device: str,
enable_prefix_caching: bool, enable_prefix_caching: bool,
enable_chunked_prefill: bool,
max_num_batched_tokens: int,
gpu_memory_utilization: float = 0.9, gpu_memory_utilization: float = 0.9,
download_dir: Optional[str] = None, download_dir: Optional[str] = None,
) -> float: ) -> float:
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
llm = LLM(model=model, llm = LLM(
tokenizer=tokenizer, model=model,
quantization=quantization, tokenizer=tokenizer,
tensor_parallel_size=tensor_parallel_size, quantization=quantization,
seed=seed, tensor_parallel_size=tensor_parallel_size,
trust_remote_code=trust_remote_code, seed=seed,
dtype=dtype, trust_remote_code=trust_remote_code,
max_model_len=max_model_len, dtype=dtype,
gpu_memory_utilization=gpu_memory_utilization, max_model_len=max_model_len,
enforce_eager=enforce_eager, gpu_memory_utilization=gpu_memory_utilization,
kv_cache_dtype=kv_cache_dtype, enforce_eager=enforce_eager,
device=device, kv_cache_dtype=kv_cache_dtype,
enable_prefix_caching=enable_prefix_caching, quantization_param_path=quantization_param_path,
download_dir=download_dir) device=device,
enable_prefix_caching=enable_prefix_caching,
download_dir=download_dir,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
)
# Add the requests to the engine. # Add the requests to the engine.
prompts = []
sampling_params = []
for prompt, _, output_len in requests: for prompt, _, output_len in requests:
sampling_params = SamplingParams( prompts.append(prompt)
n=n, sampling_params.append(
temperature=0.0 if use_beam_search else 1.0, SamplingParams(
top_p=1.0, n=n,
use_beam_search=use_beam_search, temperature=0.0 if use_beam_search else 1.0,
ignore_eos=True, top_p=1.0,
max_tokens=output_len, use_beam_search=use_beam_search,
) ignore_eos=True,
# FIXME(woosuk): Do not use internal method. max_tokens=output_len,
llm._add_request( ))
prompt=prompt,
prompt_token_ids=None,
sampling_params=sampling_params,
)
start = time.perf_counter() start = time.perf_counter()
# FIXME(woosuk): Do not use internal method. llm.generate(prompts, sampling_params, use_tqdm=True)
llm._run_engine(use_tqdm=True)
end = time.perf_counter() end = time.perf_counter()
return end - start return end - start
@ -212,14 +218,15 @@ def main(args: argparse.Namespace):
args.output_len) args.output_len)
if args.backend == "vllm": if args.backend == "vllm":
elapsed_time = run_vllm(requests, args.model, args.tokenizer, elapsed_time = run_vllm(
args.quantization, args.tensor_parallel_size, requests, args.model, args.tokenizer, args.quantization,
args.seed, args.n, args.use_beam_search, args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.trust_remote_code, args.dtype, args.max_model_len,
args.max_model_len, args.enforce_eager, args.enforce_eager, args.kv_cache_dtype,
args.kv_cache_dtype, args.device, args.quantization_param_path, args.device,
args.enable_prefix_caching, args.enable_prefix_caching, args.enable_chunked_prefill,
args.gpu_memory_utilization, args.download_dir) args.max_num_batched_tokens, args.gpu_memory_utilization,
args.download_dir)
elif args.backend == "hf": elif args.backend == "hf":
assert args.tensor_parallel_size == 1 assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n, elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@ -235,6 +242,18 @@ def main(args: argparse.Namespace):
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} tokens/s") f"{total_num_tokens / elapsed_time:.2f} tokens/s")
# Output JSON results if specified
if args.output_json:
results = {
"elapsed_time": elapsed_time,
"num_requests": len(requests),
"total_num_tokens": total_num_tokens,
"requests_per_second": len(requests) / elapsed_time,
"tokens_per_second": total_num_tokens / elapsed_time,
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark the throughput.") parser = argparse.ArgumentParser(description="Benchmark the throughput.")
@ -259,7 +278,7 @@ if __name__ == "__main__":
parser.add_argument("--tokenizer", type=str, default=None) parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization', parser.add_argument('--quantization',
'-q', '-q',
choices=['awq', 'gptq', 'squeezellm', None], choices=[*QUANTIZATION_METHODS, None],
default=None) default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n", parser.add_argument("--n",
@ -304,27 +323,51 @@ if __name__ == "__main__":
action="store_true", action="store_true",
help="enforce eager execution") help="enforce eager execution")
parser.add_argument( parser.add_argument(
"--kv-cache-dtype", '--kv-cache-dtype',
type=str, type=str,
choices=["auto", "fp8_e5m2"], choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default="auto", default="auto",
help= help='Data type for kv cache storage. If "auto", will use model '
'Data type for kv cache storage. If "auto", will use model data type.') 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=str,
default=None,
help='Path to the JSON file containing the KV cache scaling factors. '
'This should generally be supplied, when KV cache dtype is FP8. '
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.')
parser.add_argument( parser.add_argument(
"--device", "--device",
type=str, type=str,
default="cuda", default="cuda",
choices=["cuda"], choices=["cuda", "cpu"],
help='device type for vLLM execution, supporting CUDA only currently.') help='device type for vLLM execution, supporting CUDA and CPU.')
parser.add_argument( parser.add_argument(
"--enable-prefix-caching", "--enable-prefix-caching",
action='store_true', action='store_true',
help="enable automatic prefix caching for vLLM backend.") help="enable automatic prefix caching for vLLM backend.")
parser.add_argument("--enable-chunked-prefill",
action='store_true',
help="enable chunked prefill for vLLM backend.")
parser.add_argument('--max-num-batched-tokens',
type=int,
default=None,
help='maximum number of batched tokens per '
'iteration')
parser.add_argument('--download-dir', parser.add_argument('--download-dir',
type=str, type=str,
default=None, default=None,
help='directory to download and load the weights, ' help='directory to download and load the weights, '
'default to the default cache dir of huggingface') 'default to the default cache dir of huggingface')
parser.add_argument(
'--output-json',
type=str,
default=None,
help='Path to save the throughput results in JSON format.')
args = parser.parse_args() args = parser.parse_args()
if args.tokenizer is None: if args.tokenizer is None:
args.tokenizer = args.model args.tokenizer = args.model

View File

@ -0,0 +1,352 @@
import argparse
import copy
import itertools
import pickle as pkl
import time
from typing import Callable, Iterable, List, Tuple
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]
# helpers
def to_fp8(tensor: torch.tensor) -> torch.tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
return torch.round(tensor.clamp(
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
def to_int8(tensor: torch.tensor) -> torch.tensor:
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
k: int) -> Tuple[torch.tensor, torch.tensor]:
a = torch.randn((m, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda').t() * 5
if dtype == torch.int8:
return to_int8(a), to_int8(b)
if dtype == torch.float8_e4m3fn:
return to_fp8(a), to_fp8(b)
raise ValueError("unsupported dtype")
# impl
def pytorch_i8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
scale_b: torch.tensor,
out_dtype: torch.dtype) -> torch.tensor:
return torch.mm(a, b)
def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
scale_b: torch.tensor,
out_dtype: torch.dtype) -> torch.tensor:
return torch._scaled_mm(a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype)
def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
scale_a: torch.tensor, scale_b: torch.tensor,
out_dtype: torch.dtype) -> torch.tensor:
return torch._scaled_mm(a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
use_fast_accum=True)
def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
scale_b: torch.tensor,
out_dtype: torch.dtype) -> torch.tensor:
return ops.cutlass_scaled_mm_dq(a,
b,
scale_a,
scale_b,
out_dtype=out_dtype)
# bench
def bench_fn(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
scale_b: torch.tensor, out_dtype: torch.dtype, label: str,
sub_label: str, fn: Callable, description: str) -> TMeasurement:
min_run_time = 1
globals = {
"a": a,
"b": b,
"scale_a": scale_a,
"scale_b": scale_b,
"out_dtype": out_dtype,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(a, b, scale_a, scale_b, out_dtype)",
globals=globals,
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)
def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
assert dtype == torch.int8
a, b = make_rand_tensors(torch.int8, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
timers = []
# pytorch impl
timers.append(
bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
torch.bfloat16, label, sub_label, pytorch_i8_impl,
"pytorch_bf16_bf16_bf16_matmul-no-scales"))
# cutlass impl
timers.append(
bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
torch.bfloat16, label, sub_label, cutlass_impl,
"cutlass_i8_i8_bf16_scaled_mm"))
return timers
def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
assert dtype == torch.float8_e4m3fn
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
timers = []
# pytorch impl: bf16 output, without fp8 fast accum
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
pytorch_fp8_impl, "pytorch_fp8_fp8_bf16_scaled_mm"))
# pytorch impl: bf16 output, with fp8 fast accum
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
pytorch_fp8_impl_fast_accum,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum"))
# pytorch impl: fp16 output, without fp8 fast accum
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
pytorch_fp8_impl, "pytorch_fp8_fp8_fp16_scaled_mm"))
# pytorch impl: fp16 output, with fp8 fast accum
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
pytorch_fp8_impl_fast_accum,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum"))
# cutlass impl: bf16 output
timers.append(
bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
torch.bfloat16, label, sub_label, cutlass_impl,
"cutlass_fp8_fp8_bf16_scaled_mm"))
# cutlass impl: fp16 output
timers.append(
bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
torch.float16, label, sub_label, cutlass_impl,
"cutlass_fp8_fp8_fp16_scaled_mm"))
return timers
def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label)
if dtype == torch.float8_e4m3fn:
return bench_fp8(dtype, m, k, n, label, sub_label)
raise ValueError("unsupported type")
# runner
def print_timers(timers: Iterable[TMeasurement]):
compare = TBenchmark.Compare(timers)
compare.print()
def run(dtype: torch.dtype,
MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
results = []
for m, k, n in MKNs:
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
f"MKN=({m}x{k}x{n})")
print_timers(timers)
results.extend(timers)
return results
# output makers
def make_output(data: Iterable[TMeasurement],
MKNs: Iterable[Tuple[int, int, int]],
base_description: str,
timestamp=None):
print(f"== All Results {base_description} ====")
print_timers(data)
# pickle all the results
timestamp = int(time.time()) if timestamp is None else timestamp
with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
pkl.dump(data, f)
# argparse runners
def run_square_bench(args):
dim_sizes = list(
range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, MKNs)
make_output(data, MKNs, f"square_bench-{args.dtype}")
def run_range_bench(args):
dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
n = len(dim_sizes)
Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
MKNs = list(zip(Ms, Ks, Ns))
data = run(args.dtype, MKNs)
make_output(data, MKNs, f"range_bench-{args.dtype}")
def run_model_bench(args):
print("Benchmarking models:")
for i, model in enumerate(args.models):
print(f"[{i}] {model}")
def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
KNs = []
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
KNs.append(KN)
return KNs
model_bench_data = []
models_tps = list(itertools.product(args.models, args.tp_sizes))
for model, tp_size in models_tps:
Ms = args.batch_sizes
KNs = model_shapes(model, tp_size)
MKNs = []
for m in Ms:
for k, n in KNs:
MKNs.append((m, k, n))
data = run(args.dtype, MKNs)
model_bench_data.append(data)
# Print all results
for data, model_tp in zip(model_bench_data, models_tps):
model, tp_size = model_tp
print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
print_timers(data)
timestamp = int(time.time())
all_data = []
for d in model_bench_data:
all_data.extend(d)
# pickle all data
with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
pkl.dump(all_data, f)
if __name__ == '__main__':
def to_torch_dtype(dt):
if dt == "int8":
return torch.int8
if dt == "fp8":
return torch.float8_e4m3fn
raise ValueError("unsupported dtype")
parser = argparse.ArgumentParser(
description="""
Benchmark Cutlass GEMM.
To run square GEMMs:
python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
To run constant N and K and sweep M:
python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
To run dimensions from a model:
python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
Output:
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
""", # noqa: E501
formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--dtype",
type=to_torch_dtype,
required=True,
help="Available options are ['int8', 'fp8']")
subparsers = parser.add_subparsers(dest="cmd")
square_parser = subparsers.add_parser("square_bench")
square_parser.add_argument("--dim-start", type=int, required=True)
square_parser.add_argument("--dim-end", type=int, required=True)
square_parser.add_argument("--dim-increment", type=int, required=True)
square_parser.set_defaults(func=run_square_bench)
range_parser = subparsers.add_parser("range_bench")
range_parser.add_argument("--dim-start", type=int, required=True)
range_parser.add_argument("--dim-end", type=int, required=True)
range_parser.add_argument("--dim-increment", type=int, required=True)
range_parser.add_argument("--m-constant", type=int, default=None)
range_parser.add_argument("--n-constant", type=int, default=None)
range_parser.add_argument("--k-constant", type=int, default=None)
range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench")
model_parser.add_argument("--models",
nargs="+",
type=str,
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys())
model_parser.add_argument("--tp-sizes",
nargs="+",
type=int,
default=DEFAULT_TP_SIZES)
model_parser.add_argument("--batch-sizes",
nargs="+",
type=int,
default=DEFAULT_BATCH_SIZES)
model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args()
args.func(args)

View File

@ -0,0 +1,37 @@
# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
# - TP1 : K = 14336, N = 4096
# - TP2 : K = 7168, N = 4096
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
# - TP1 : K = 4096, N = 6144
# - TP4 : K = 4096, N = 1536
# TP1 shapes
WEIGHT_SHAPES = {
"mistralai/Mistral-7B-v0.1": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-2-7b-hf": [
([4096, 12288], 1),
([4096, 4096], 0),
([4096, 22016], 1),
([11008, 4096], 0),
],
"meta-llama/Llama-2-13b-hf": [
([5120, 15360], 1),
([5120, 5120], 0),
([5120, 27648], 1),
([13824, 5120], 0),
],
"meta-llama/Llama-2-70b-hf": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 57344], 1),
([28672, 8192], 0),
],
}

View File

@ -0,0 +1,302 @@
import argparse
import os
import sys
from typing import Optional
import torch
import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.aqlm import (
dequantize_weight, generic_dequantize_gemm, get_int_dtype,
optimized_dequantize_gemm)
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def torch_mult(
input: torch.Tensor, # [..., in_features]
weights: torch.Tensor,
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
) -> torch.Tensor:
output = F.linear(input, weights)
return output
def dequant_out_scale(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
if bias is None:
output = F.linear(input, weights, bias)
orig_shape = output.shape
flattened_output = output.view(-1, output.size(-1))
f_scales = scales.view(-1, scales.shape[0])
b_scales = f_scales.expand(flattened_output.shape[0], -1)
flattened_output *= b_scales
return flattened_output.view(orig_shape)
else:
b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
-1, weights.shape[1])
weights *= b_scales
return F.linear(input, weights, bias)
def dequant_weight_scale(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
-1, weights.shape[1])
weights *= b_scales
return F.linear(input, weights, bias)
def dequant_no_scale(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
return F.linear(input, weights, bias)
# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against
# the generic pytorch version.
# Just visual comparison.
def dequant_test(k: int, parts: torch.tensor, nbooks: int, bits: int) -> None:
n = parts.sum().item()
device = torch.device('cuda:0')
code_range = (1 << bits) // 2
ingroups = 8
codes = torch.randint(-code_range,
code_range,
size=(n, k // ingroups, nbooks),
dtype=get_int_dtype(bits),
device=device)
codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
dtype=torch.float16,
device=device)
count = 0
for index in range(16):
for i in range(8):
for book in range(nbooks):
codebooks[book, index, 0, i] = count * (10**book)
count += 1
print("codes shape", codes.shape)
for i in range(16):
for book in range(nbooks):
codes[0, i, book] = i
codes[0, -i, book] = i
weights = dequantize_weight(codes, codebooks, None)
weights2 = ops.aqlm_dequant(codes, codebooks, parts)
print("weights shape:", weights.shape)
print("weights2 shape:", weights2.shape)
print("weights are:", weights)
print("weights2 are:", weights2)
print("first 128 weights are", weights[0, 0:128].to(torch.int32))
print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32))
print("last 128 weights are", weights[0, -128:])
print("last 128 weights2 are:", weights2[0, -128:])
def main():
parser = argparse.ArgumentParser(description="Benchmark aqlm performance.")
# Add arguments
parser.add_argument("--nbooks",
type=int,
default=1,
help="Number of codebooks (default: 1)")
parser.add_argument("--bits",
type=int,
default=16,
help="Number of bits per code element (default: 16)")
parser.add_argument(
"--test",
type=bool,
default=False,
help="Run the decompression/dequant tester rather than benchmarking "
"(default: False)")
# Parse the arguments
args = parser.parse_args()
# Extract values
nbooks = args.nbooks
bits = args.bits
if args.test:
dequant_test(4096, torch.tensor((4096, )), nbooks, bits)
return
# Otherwise, benchmark.
methods = [
ops.aqlm_gemm,
dequant_out_scale,
generic_dequantize_gemm,
optimized_dequantize_gemm,
dequant_weight_scale,
torch_mult,
dequant_no_scale,
]
filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv"
print(f"writing benchmarks to file {filename}")
with open(filename, "w") as f:
sys.stdout = f
print('m | k | n | n parts', end='')
for method in methods:
print(f" | {method.__name__.replace('_', ' ')} (µs)", end='')
print('')
# These are reasonable prefill sizes.
ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )),
(4096, (11008, 11008)), (11008, (4096, )))
# reasonable ranges for m.
for m in [
1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112,
128, 256, 512, 1024, 1536, 2048, 3072, 4096
]:
print(f'{m}', file=sys.__stdout__)
for ksp in ksandpartions:
run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits,
methods)
sys.stdout = sys.__stdout__
def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int,
methods):
# I didn't see visible improvements from increasing these, but feel free :)
num_warmup_trials = 1
num_trials = 1
num_calls = 100
# warmup.
for method in methods:
for _ in range(num_warmup_trials):
run_timing(
num_calls=num_calls,
m=m,
k=k,
parts=parts,
nbooks=nbooks,
bits=bits,
method=method,
)
n = parts.sum().item()
print(f'{m} | {k} | {n} | {parts.tolist()}', end='')
for method in methods:
best_time_us = 1e20
for _ in range(num_trials):
kernel_dur_ms = run_timing(
num_calls=num_calls,
m=m,
k=k,
parts=parts,
nbooks=nbooks,
bits=bits,
method=method,
)
kernel_dur_us = 1000 * kernel_dur_ms
if kernel_dur_us < best_time_us:
best_time_us = kernel_dur_us
print(f' | {kernel_dur_us:.0f}', end='')
print('')
def run_timing(num_calls: int, m: int, k: int, parts: torch.tensor,
nbooks: int, bits: int, method) -> float:
n = parts.sum().item()
device = torch.device('cuda:0')
input = torch.randn((1, m, k), dtype=torch.float16, device=device)
code_range = (1 << bits) // 2
ingroups = 8
codes = torch.randint(-code_range,
code_range,
size=(n, k // ingroups, nbooks),
dtype=get_int_dtype(bits),
device=device)
codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
dtype=torch.float16,
device=device)
scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)
# for comparison to just a pytorch mult.
weights = torch.randn((n, k), dtype=torch.float16, device=device)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
if method is torch_mult:
for i in range(num_calls):
torch_mult(input, weights, scales)
else:
for i in range(num_calls):
method(input, codes, codebooks, scales, parts, None)
end_event.record()
end_event.synchronize()
dur_ms = start_event.elapsed_time(end_event) / num_calls
return dur_ms
if __name__ == "__main__":
sys.exit(main())

View File

@ -0,0 +1,233 @@
import argparse
import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MarlinWorkspace, marlin_24_quantize, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, quantize_weights, sort_weights)
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
def bench_run(results, model, act_order, is_k_full, num_bits, group_size,
size_m, size_k, size_n):
label = "Quant Matmul"
sub_label = ("{}, act={} k_full={}, b={}, g={}, "
"MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits,
group_size, size_m, size_k, size_n))
print(f"Testing: {sub_label}")
a = torch.randn(size_m, size_k).to(torch.half).cuda()
b = torch.rand(size_k, size_n).to(torch.half).cuda()
a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda())
# Marlin quant
(
marlin_w_ref,
marlin_q_w,
marlin_s,
marlin_g_idx,
marlin_sort_indices,
marlin_rand_perm,
) = marlin_quantize(b, num_bits, group_size, act_order)
# Marlin_24 quant
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s) = marlin_24_quantize(b, num_bits, group_size)
# GPTQ quant
(w_ref, q_w, s, g_idx,
rand_perm) = quantize_weights(b, num_bits, group_size, act_order)
q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)
# For act_order, sort the "weights" and "g_idx"
# so that group ids are increasing
repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
if act_order:
(q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
# Prepare
marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_MAX_PARALLEL)
globals = {
# Gen params
"num_bits": num_bits,
"group_size": group_size,
"size_m": size_m,
"size_n": size_n,
"size_k": size_k,
"a": a,
"a_tmp": a_tmp,
# Marlin params
"marlin_w_ref": marlin_w_ref,
"marlin_q_w": marlin_q_w,
"marlin_s": marlin_s,
"marlin_g_idx": marlin_g_idx,
"marlin_sort_indices": marlin_sort_indices,
"marlin_rand_perm": marlin_rand_perm,
"marlin_workspace": marlin_workspace,
"is_k_full": is_k_full,
# Marlin_24 params
"marlin_24_w_ref": marlin_24_w_ref,
"marlin_24_q_w_comp": marlin_24_q_w_comp,
"marlin_24_meta": marlin_24_meta,
"marlin_24_s": marlin_24_s,
"marlin_24_workspace": marlin_24_workspace,
# GPTQ params
"q_w_gptq": q_w_gptq,
"repack_sort_indices": repack_sort_indices,
# Kernels
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
"gptq_marlin_repack": ops.gptq_marlin_repack,
}
min_run_time = 1
# Warmup pytorch
for i in range(5):
torch.matmul(a, marlin_w_ref)
results.append(
benchmark.Timer(
stmt="torch.matmul(a, marlin_w_ref)",
globals=globals,
label=label,
sub_label=sub_label,
description="pytorch_gemm",
).blocked_autorange(min_run_time=min_run_time))
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_gemm",
).blocked_autorange(min_run_time=min_run_time))
if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
results.append(
benchmark.Timer(
stmt=
"output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_24_gemm",
).blocked_autorange(min_run_time=min_run_time))
results.append(
benchmark.Timer(
stmt=
"q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
description="gptq_marlin_repack",
).blocked_autorange(min_run_time=min_run_time))
def main(args):
print("Benchmarking models:")
for i, model in enumerate(args.models):
print(f"[{i}] {model}")
results = []
for model in args.models:
for layer in WEIGHT_SHAPES[model]:
size_k = layer[0]
size_n = layer[1]
if len(args.limit_k) > 0 and size_k not in args.limit_k:
continue
if len(args.limit_n) > 0 and size_n not in args.limit_n:
continue
for act_order in ACT_ORDER_OPTS:
if len(args.limit_act_order
) > 0 and act_order not in args.limit_act_order:
continue
for is_k_full in K_FULL_OPTS:
if len(args.limit_k_full
) > 0 and is_k_full not in args.limit_k_full:
continue
for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
if len(args.limit_num_bits
) > 0 and num_bits not in args.limit_num_bits:
continue
for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
if len(
args.limit_group_size
) > 0 and group_size not in args.limit_group_size:
continue
# For act_order, the group_size must be less than
# size_k
if act_order and (group_size == size_k
or group_size == -1):
continue
for size_m in args.batch_sizes:
bench_run(results, model, act_order, is_k_full,
num_bits, group_size, size_m, size_k,
size_n)
compare = benchmark.Compare(results)
compare.print()
# For quick benchmarking use:
# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501
#
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Benchmark Marlin across specified models/shapes/batches")
parser.add_argument(
"--models",
nargs="+",
type=str,
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys(),
)
parser.add_argument("--batch-sizes",
nargs="+",
type=int,
default=DEFAULT_BATCH_SIZES)
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])
parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[])
parser.add_argument("--limit-act-order", nargs="+", type=int, default=[])
parser.add_argument("--limit-k-full", nargs="+", type=int, default=[])
args = parser.parse_args()
main(args)

View File

@ -1,3 +1,4 @@
import argparse
import json import json
import os import os
import sys import sys
@ -5,68 +6,70 @@ import sys
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import triton import triton
from tqdm import tqdm
from vllm.model_executor.layers.fused_moe import (fused_moe, from vllm.model_executor.layers.fused_moe import (fused_moe,
get_config_file_name) get_config_file_name)
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def main(model, tp_size, gpu, dtype: str):
def main(): os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
method = fused_moe method = fused_moe
for bs in [ for bs in [
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
2048, 3072, 4096 2048, 3072, 4096
]: ]:
run_grid(bs, method=method) run_grid(bs,
model=model,
method=method,
gpu=gpu,
tp_size=tp_size,
dtype=dtype)
def run_grid(bs, method): def run_grid(bs, model, method, gpu, tp_size, dtype: str):
d_model = 4096 if model == '8x7B':
d_model = 4096
model_intermediate_size = 14336
num_layers = 32
elif model == '8x22B':
d_model = 6144
model_intermediate_size = 16384
num_layers = 56
else:
raise ValueError(f'Unsupported Mixtral model {model}')
num_total_experts = 8 num_total_experts = 8
top_k = 2 top_k = 2
tp_size = 2 # tp_size = 2
model_intermediate_size = 14336
num_layers = 32
num_calls = 100 num_calls = 100
num_warmup_trials = 1 num_warmup_trials = 1
num_trials = 1 num_trials = 1
configs = [] configs = []
if bs <= 16:
BLOCK_SIZES_M = [16]
elif bs <= 32:
BLOCK_SIZES_M = [16, 32]
elif bs <= 64:
BLOCK_SIZES_M = [16, 32, 64]
elif bs <= 128:
BLOCK_SIZES_M = [16, 32, 64, 128]
else:
BLOCK_SIZES_M = [16, 32, 64, 128, 256]
for block_size_n in [32, 64, 128, 256]: for block_size_n in [32, 64, 128, 256]:
for block_size_m in BLOCK_SIZES_M: for block_size_m in [16, 32, 64, 128, 256]:
for block_size_k in [64, 128, 256]: for block_size_k in [64, 128, 256]:
for group_size_m in [1, 16, 32, 64]: for group_size_m in [1, 16, 32, 64]:
for num_warps in [4, 8]: for num_warps in [4, 8]:
configs.append({ for num_stages in [2, 3, 4, 5]:
"BLOCK_SIZE_M": block_size_m, configs.append({
"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_K": block_size_k, "BLOCK_SIZE_N": block_size_n,
"GROUP_SIZE_M": group_size_m, "BLOCK_SIZE_K": block_size_k,
"num_warps": num_warps, "GROUP_SIZE_M": group_size_m,
"num_stages": 4, "num_warps": num_warps,
}) "num_stages": num_stages,
})
best_config = None best_config = None
best_time_us = 1e20 best_time_us = 1e20
for config in configs: print(f'{tp_size=} {bs=}')
print(f'{tp_size=} {bs=}')
print(f'{config}') for config in tqdm(configs):
# warmup # warmup
print('warming up')
try: try:
for _ in range(num_warmup_trials): for _ in range(num_warmup_trials):
run_timing( run_timing(
@ -79,12 +82,12 @@ def run_grid(bs, method):
model_intermediate_size=model_intermediate_size, model_intermediate_size=model_intermediate_size,
method=method, method=method,
config=config, config=config,
dtype=dtype,
) )
except triton.runtime.autotuner.OutOfResources: except triton.runtime.autotuner.OutOfResources:
continue continue
# trial # trial
print('benchmarking')
for _ in range(num_trials): for _ in range(num_trials):
kernel_dur_ms = run_timing( kernel_dur_ms = run_timing(
num_calls=num_calls, num_calls=num_calls,
@ -96,6 +99,7 @@ def run_grid(bs, method):
model_intermediate_size=model_intermediate_size, model_intermediate_size=model_intermediate_size,
method=method, method=method,
config=config, config=config,
dtype=dtype,
) )
kernel_dur_us = 1000 * kernel_dur_ms kernel_dur_us = 1000 * kernel_dur_ms
@ -105,16 +109,18 @@ def run_grid(bs, method):
best_config = config best_config = config
best_time_us = kernel_dur_us best_time_us = kernel_dur_us
print(f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}' tqdm.write(
f' {bs=} {tp_size=} {top_k=} {num_total_experts=} ' f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
f'{d_model=} {model_intermediate_size=} {num_layers=}') f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
f'{d_model=} {model_intermediate_size=} {num_layers=}')
print("best_time_us", best_time_us) print("best_time_us", best_time_us)
print("best_config", best_config) print("best_config", best_config)
# holds Dict[str, Dict[str, int]] # holds Dict[str, Dict[str, int]]
filename = get_config_file_name(num_total_experts, filename = get_config_file_name(num_total_experts,
model_intermediate_size // tp_size) model_intermediate_size // tp_size,
"float8" if dtype == "float8" else None)
print(f"writing config to file {filename}") print(f"writing config to file {filename}")
existing_content = {} existing_content = {}
if os.path.exists(filename): if os.path.exists(filename):
@ -128,27 +134,48 @@ def run_grid(bs, method):
def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
top_k: int, tp_size: int, model_intermediate_size: int, method, top_k: int, tp_size: int, model_intermediate_size: int, method,
config) -> float: config, dtype: str) -> float:
shard_intermediate_size = model_intermediate_size // tp_size shard_intermediate_size = model_intermediate_size // tp_size
hidden_states = torch.rand( hidden_states = torch.rand(
(bs, d_model), (bs, d_model),
device="cuda:0", device="cuda:0",
dtype=torch.bfloat16, dtype=torch.float16,
) )
ws = torch.rand( w1 = torch.rand(
(num_total_experts, 2 * shard_intermediate_size, d_model), (num_total_experts, 2 * shard_intermediate_size, d_model),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
) )
w2s = torch.rand( w2 = torch.rand(
(num_total_experts, d_model, shard_intermediate_size), (num_total_experts, d_model, shard_intermediate_size),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
) )
w1_scale = None
w2_scale = None
a1_scale = None
a2_scale = None
if dtype == "float8":
w1 = w1.to(torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fn)
w1_scale = torch.ones(num_total_experts,
device=hidden_states.device,
dtype=torch.float32)
w2_scale = torch.ones(num_total_experts,
device=hidden_states.device,
dtype=torch.float32)
a1_scale = torch.ones(1,
device=hidden_states.device,
dtype=torch.float32)
a2_scale = torch.ones(1,
device=hidden_states.device,
dtype=torch.float32)
gating_output = F.softmax(torch.rand( gating_output = F.softmax(torch.rand(
(num_calls, bs, num_total_experts), (num_calls, bs, num_total_experts),
device=hidden_states.device, device=hidden_states.device,
@ -163,13 +190,18 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
for i in range(num_calls): for i in range(num_calls):
hidden_states = method( hidden_states = method(
hidden_states=hidden_states, hidden_states=hidden_states,
w1=ws, w1=w1,
w2=w2s, w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
gating_output=gating_output[i], gating_output=gating_output[i],
topk=2, topk=2,
renormalize=True, renormalize=True,
inplace=True, inplace=True,
override_config=config, override_config=config,
use_fp8=dtype == "float8",
) )
end_event.record() end_event.record()
end_event.synchronize() end_event.synchronize()
@ -179,4 +211,29 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
if __name__ == "__main__": if __name__ == "__main__":
sys.exit(main()) parser = argparse.ArgumentParser(
prog='benchmark_mixtral_moe',
description='Benchmark and tune the fused_moe kernel',
)
parser.add_argument(
'--dtype',
type=str,
default='auto',
choices=['float8', 'float16'],
help='Data type used for fused_moe kernel computations',
)
parser.add_argument('--model',
type=str,
default='8x7B',
choices=['8x7B', '8x22B'],
help='The Mixtral model to benchmark')
parser.add_argument('--tp-size',
type=int,
default=2,
help='Tensor paralleli size')
parser.add_argument('--gpu',
type=int,
default=0,
help="GPU ID for benchmarking")
args = parser.parse_args()
sys.exit(main(args.model, args.tp_size, args.gpu, args.dtype))

View File

@ -5,7 +5,7 @@ from typing import Optional
import torch import torch
from vllm._C import ops from vllm import _custom_ops as ops
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
NUM_BLOCKS = 1024 NUM_BLOCKS = 1024
@ -16,7 +16,7 @@ PARTITION_SIZE = 512
def main( def main(
version: str, version: str,
num_seqs: int, num_seqs: int,
context_len: int, seq_len: int,
num_query_heads: int, num_query_heads: int,
num_kv_heads: int, num_kv_heads: int,
head_size: int, head_size: int,
@ -48,12 +48,12 @@ def main(
dtype=torch.float, dtype=torch.float,
device=device) device=device)
context_lens = [context_len for _ in range(num_seqs)] seq_lens = [seq_len for _ in range(num_seqs)]
max_context_len = max(context_lens) max_seq_len = max(seq_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device=device) seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device)
# Create the block tables. # Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = [] block_tables = []
for _ in range(num_seqs): for _ in range(num_seqs):
block_table = [ block_table = [
@ -77,8 +77,7 @@ def main(
# Prepare for the paged attention kernel. # Prepare for the paged attention kernel.
output = torch.empty_like(query) output = torch.empty_like(query)
if version == "v2": if version == "v2":
num_partitions = ((max_context_len + PARTITION_SIZE - 1) // num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
PARTITION_SIZE)
tmp_output = torch.empty( tmp_output = torch.empty(
size=(num_seqs, num_query_heads, num_partitions, head_size), size=(num_seqs, num_query_heads, num_partitions, head_size),
dtype=output.dtype, dtype=output.dtype,
@ -97,6 +96,9 @@ def main(
torch.cuda.cudart().cudaProfilerStart() torch.cuda.cudart().cudaProfilerStart()
start_time = time.perf_counter() start_time = time.perf_counter()
# Using default kv_scale
kv_scale = 1.0
for _ in range(num_iters): for _ in range(num_iters):
if version == "v1": if version == "v1":
ops.paged_attention_v1( ops.paged_attention_v1(
@ -107,11 +109,12 @@ def main(
num_kv_heads, num_kv_heads,
scale, scale,
block_tables, block_tables,
context_lens, seq_lens,
block_size, block_size,
max_context_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale,
) )
elif version == "v2": elif version == "v2":
ops.paged_attention_v2( ops.paged_attention_v2(
@ -125,11 +128,12 @@ def main(
num_kv_heads, num_kv_heads,
scale, scale,
block_tables, block_tables,
context_lens, seq_lens,
block_size, block_size,
max_context_len, max_seq_len,
alibi_slopes, alibi_slopes,
kv_cache_dtype, kv_cache_dtype,
kv_scale,
) )
else: else:
raise ValueError(f"Invalid version: {version}") raise ValueError(f"Invalid version: {version}")
@ -161,12 +165,12 @@ if __name__ == '__main__':
choices=["v1", "v2"], choices=["v1", "v2"],
default="v2") default="v2")
parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--context-len", type=int, default=4096) parser.add_argument("--seq_len", type=int, default=4096)
parser.add_argument("--num-query-heads", type=int, default=64) parser.add_argument("--num-query-heads", type=int, default=64)
parser.add_argument("--num-kv-heads", type=int, default=8) parser.add_argument("--num-kv-heads", type=int, default=8)
parser.add_argument("--head-size", parser.add_argument("--head-size",
type=int, type=int,
choices=[64, 80, 96, 112, 128, 256], choices=[64, 80, 96, 112, 128, 192, 256],
default=128) default=128)
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
parser.add_argument("--use-alibi", action="store_true") parser.add_argument("--use-alibi", action="store_true")
@ -179,11 +183,11 @@ if __name__ == '__main__':
parser.add_argument( parser.add_argument(
"--kv-cache-dtype", "--kv-cache-dtype",
type=str, type=str,
choices=["auto", "fp8_e5m2"], choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
default="auto", default="auto",
help= help="Data type for kv cache storage. If 'auto', will use model "
'Data type for kv cache storage. If "auto", will use model data type.') "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
parser.add_argument("--device", type=str, choices=["cuda"], default="cuda") "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
@ -192,7 +196,7 @@ if __name__ == '__main__':
main( main(
version=args.version, version=args.version,
num_seqs=args.batch_size, num_seqs=args.batch_size,
context_len=args.context_len, seq_len=args.seq_len,
num_query_heads=args.num_query_heads, num_query_heads=args.num_query_heads,
num_kv_heads=args.num_kv_heads, num_kv_heads=args.num_kv_heads,
head_size=args.head_size, head_size=args.head_size,

View File

@ -93,7 +93,7 @@ if __name__ == '__main__':
parser.add_argument("--num-heads", type=int, default=8) parser.add_argument("--num-heads", type=int, default=8)
parser.add_argument("--head-size", parser.add_argument("--head-size",
type=int, type=int,
choices=[64, 80, 96, 112, 128, 256], choices=[64, 80, 96, 112, 128, 192, 256],
default=128) default=128)
parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32) parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
parser.add_argument("--dtype", parser.add_argument("--dtype",

View File

@ -0,0 +1,75 @@
WEIGHT_SHAPES = {
"ideal": [[4 * 256 * 32, 256 * 32]],
"mistralai/Mistral-7B-v0.1/TP1": [
[4096, 6144],
[4096, 4096],
[4096, 28672],
[14336, 4096],
],
"mistralai/Mistral-7B-v0.1/TP2": [
[4096, 3072],
[2048, 4096],
[4096, 14336],
[7168, 4096],
],
"mistralai/Mistral-7B-v0.1/TP4": [
[4096, 1536],
[1024, 4096],
[4096, 7168],
[3584, 4096],
],
"meta-llama/Llama-2-7b-hf/TP1": [
[4096, 12288],
[4096, 4096],
[4096, 22016],
[11008, 4096],
],
"meta-llama/Llama-2-7b-hf/TP2": [
[4096, 6144],
[2048, 4096],
[4096, 11008],
[5504, 4096],
],
"meta-llama/Llama-2-7b-hf/TP4": [
[4096, 3072],
[1024, 4096],
[4096, 5504],
[2752, 4096],
],
"meta-llama/Llama-2-13b-hf/TP1": [
[5120, 15360],
[5120, 5120],
[5120, 27648],
[13824, 5120],
],
"meta-llama/Llama-2-13b-hf/TP2": [
[5120, 7680],
[2560, 5120],
[5120, 13824],
[6912, 5120],
],
"meta-llama/Llama-2-13b-hf/TP4": [
[5120, 3840],
[1280, 5120],
[5120, 6912],
[3456, 5120],
],
"meta-llama/Llama-2-70b-hf/TP1": [
[8192, 10240],
[8192, 8192],
[8192, 57344],
[28672, 8192],
],
"meta-llama/Llama-2-70b-hf/TP2": [
[8192, 5120],
[4096, 8192],
[8192, 28672],
[14336, 8192],
],
"meta-llama/Llama-2-70b-hf/TP4": [
[8192, 2560],
[2048, 8192],
[8192, 14336],
[7168, 8192],
],
}

View File

@ -4,7 +4,7 @@ PORT=8000
MODEL=$1 MODEL=$1
TOKENS=$2 TOKENS=$2
docker run --gpus all --shm-size 1g -p $PORT:80 \ docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \
-v $PWD/data:/data \ -v $PWD/data:/data \
ghcr.io/huggingface/text-generation-inference:1.4.0 \ ghcr.io/huggingface/text-generation-inference:1.4.0 \
--model-id $MODEL \ --model-id $MODEL \

View File

@ -0,0 +1,63 @@
import argparse
import cProfile
import pstats
from vllm import LLM, SamplingParams
# A very long prompt, total number of tokens is about 15k.
LONG_PROMPT = ["You are an expert in large language models, aren't you?"
] * 1000
LONG_PROMPT = ' '.join(LONG_PROMPT)
def main(args):
llm = LLM(
model=args.model,
enforce_eager=True,
enable_prefix_caching=True,
tensor_parallel_size=args.tensor_parallel_size,
use_v2_block_manager=args.use_v2_block_manager,
)
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
profiler = cProfile.Profile()
print("------warm up------")
for i in range(3):
output = llm.generate(LONG_PROMPT, sampling_params)
print(output[0].outputs[0].text)
print("------start generating------")
for i in range(3):
profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)',
globals(), locals())
# analyze the runtime of hashing function
stats = pstats.Stats(profiler)
stats.sort_stats('cumulative')
total_time = 0
total_calls = 0
for func in stats.stats:
if 'hash_of_block' in func[2]:
total_time = stats.stats[func][3]
total_calls = stats.stats[func][0]
percentage = (total_time / stats.total_tt) * 100
print(f"Hashing took {total_time:.2f} seconds,"
f"{percentage:.2f}% of the total runtime.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='Benchmark the performance of hashing function in'
'automatic prefix caching.')
parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--output-len', type=int, default=10)
parser.add_argument('--enable-prefix-caching',
action='store_true',
help='enable prefix caching')
parser.add_argument('--use-v2-block-manager',
action='store_true',
help='Use BlockSpaceMangerV2')
args = parser.parse_args()
main(args)

View File

@ -99,7 +99,9 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
"Failed to determine torch nvcc compiler flags") "Failed to determine torch nvcc compiler flags")
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2") list(APPEND GPU_FLAGS "-DENABLE_FP8")
endif()
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
list(REMOVE_ITEM GPU_FLAGS list(REMOVE_ITEM GPU_FLAGS
"-D__CUDA_NO_HALF_OPERATORS__" "-D__CUDA_NO_HALF_OPERATORS__"
"-D__CUDA_NO_HALF_CONVERSIONS__" "-D__CUDA_NO_HALF_CONVERSIONS__"
@ -117,6 +119,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
list(APPEND GPU_FLAGS list(APPEND GPU_FLAGS
"-DUSE_ROCM" "-DUSE_ROCM"
"-DENABLE_FP8"
"-U__HIP_NO_HALF_CONVERSIONS__" "-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__" "-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc") "-fno-gpu-rdc")

View File

@ -63,6 +63,7 @@ DEFAULT_CONDA_PATTERNS = {
"magma", "magma",
"triton", "triton",
"optree", "optree",
"nccl",
} }
DEFAULT_PIP_PATTERNS = { DEFAULT_PIP_PATTERNS = {
@ -73,6 +74,7 @@ DEFAULT_PIP_PATTERNS = {
"triton", "triton",
"optree", "optree",
"onnx", "onnx",
"nccl",
} }

View File

@ -10,11 +10,11 @@
namespace vllm { namespace vllm {
// Activation and gating kernel template. // Activation and gating kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)> template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel( __global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d] scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d] const scalar_t* __restrict__ input, // [..., 2, d]
const int d) { const int d) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
@ -23,72 +23,66 @@ __global__ void act_and_mul_kernel(
} }
} }
template<typename T> template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) { __device__ __forceinline__ T silu_kernel(const T& x) {
// x * sigmoid(x) // x * sigmoid(x)
return (T) (((float) x) / (1.0f + expf((float) -x))); return (T)(((float)x) / (1.0f + expf((float)-x)));
} }
template<typename T> template <typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) { __device__ __forceinline__ T gelu_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'none' approximation. // Equivalent to PyTorch GELU with 'none' approximation.
// Refer to: // Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38 // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
const float f = (float) x; const float f = (float)x;
constexpr float ALPHA = M_SQRT1_2; constexpr float ALPHA = M_SQRT1_2;
return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA))); return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
} }
template<typename T> template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) { __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'tanh' approximation. // Equivalent to PyTorch GELU with 'tanh' approximation.
// Refer to: // Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30 // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
const float f = (float) x; const float f = (float)x;
constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f; constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
constexpr float KAPPA = 0.044715; constexpr float KAPPA = 0.044715;
float x_cube = f * f * f; float x_cube = f * f * f;
float inner = BETA * (f + KAPPA * x_cube); float inner = BETA * (f + KAPPA * x_cube);
return (T) (0.5f * f * (1.0f + ::tanhf(inner))); return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
} }
} // namespace vllm } // namespace vllm
// Launch activation and gating kernel. // Launch activation and gating kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ #define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \ int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \ int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \ dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \ dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \ input.scalar_type(), "act_and_mul_kernel", [&] { \
"act_and_mul_kernel", \ vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
[&] { \ <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \ input.data_ptr<scalar_t>(), d); \
out.data_ptr<scalar_t>(), \ });
input.data_ptr<scalar_t>(), \
d); \
});
void silu_and_mul( void silu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d]
torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
} }
void gelu_and_mul( void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d]
torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
} }
void gelu_tanh_and_mul( void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d]
torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
} }
@ -96,11 +90,11 @@ void gelu_tanh_and_mul(
namespace vllm { namespace vllm {
// Element-wise activation kernel template. // Element-wise activation kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)> template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel( __global__ void activation_kernel(
scalar_t* __restrict__ out, // [..., d] scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., d] const scalar_t* __restrict__ input, // [..., d]
const int d) { const int d) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
@ -108,54 +102,49 @@ __global__ void activation_kernel(
} }
} }
} // namespace vllm } // namespace vllm
// Launch element-wise activation kernel. // Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ #define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int d = input.size(-1); \ int d = input.size(-1); \
int64_t num_tokens = input.numel() / d; \ int64_t num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \ dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \ dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
input.scalar_type(), \ vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
"activation_kernel", \ <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
[&] { \ input.data_ptr<scalar_t>(), d); \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \ });
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
});
namespace vllm { namespace vllm {
template<typename T> template <typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) { __device__ __forceinline__ T gelu_new_kernel(const T& x) {
const float x3 = (float) (x * x * x); const float x3 = (float)(x * x * x);
const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3)))); const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
return ((T) 0.5) * x * (((T) 1.0) + t); return ((T)0.5) * x * (((T)1.0) + t);
} }
template<typename T> template <typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) { __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
const float f = (float) x; const float f = (float)x;
const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x)); const T t =
return ((T) 0.5) * x * (((T) 1.0) + t); (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
return ((T)0.5) * x * (((T)1.0) + t);
} }
} // namespace vllm } // namespace vllm
void gelu_new( void gelu_new(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., d]
torch::Tensor& input) // [..., d]
{ {
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
} }
void gelu_fast( void gelu_fast(torch::Tensor& out, // [..., d]
torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., d]
torch::Tensor& input) // [..., d]
{ {
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
} }

View File

@ -4,4 +4,4 @@
#include "dtype_float16.cuh" #include "dtype_float16.cuh"
#include "dtype_float32.cuh" #include "dtype_float32.cuh"
#include "dtype_bfloat16.cuh" #include "dtype_bfloat16.cuh"
#include "dtype_fp8_e5m2.cuh" #include "dtype_fp8.cuh"

View File

@ -1,5 +1,6 @@
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
@ -22,31 +23,31 @@
namespace vllm { namespace vllm {
// A vector type to store Q, K, V elements. // A vector type to store Q, K, V elements.
template<typename T, int VEC_SIZE> template <typename T, int VEC_SIZE>
struct Vec {}; struct Vec {};
// A vector type to store FP32 accumulators. // A vector type to store FP32 accumulators.
template<typename T> template <typename T>
struct FloatVec {}; struct FloatVec {};
// Template vector operations. // Template vector operations.
template<typename Acc, typename A, typename B> template <typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b); inline __device__ Acc mul(A a, B b);
template<typename T> template <typename T>
inline __device__ float sum(T v); inline __device__ float sum(T v);
template<typename T> template <typename T>
inline __device__ float dot(T a, T b) { inline __device__ float dot(T a, T b) {
return sum(mul<T, T, T>(a, b)); return sum(mul<T, T, T>(a, b));
} }
template<typename A, typename T> template <typename A, typename T>
inline __device__ float dot(T a, T b) { inline __device__ float dot(T a, T b) {
return sum(mul<A, T, T>(a, b)); return sum(mul<A, T, T>(a, b));
} }
template<typename T> template <typename T>
inline __device__ void zero(T& dst) { inline __device__ void zero(T& dst) {
constexpr int WORDS = sizeof(T) / 4; constexpr int WORDS = sizeof(T) / 4;
union { union {
@ -61,4 +62,4 @@ inline __device__ void zero(T& dst) {
dst = tmp.raw; dst = tmp.raw;
} }
} // namespace vllm } // namespace vllm

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,6 @@
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
@ -26,7 +27,7 @@
namespace vllm { namespace vllm {
// Q*K^T operation. // Q*K^T operation.
template<int THREAD_GROUP_SIZE, typename Vec, int N> template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
using A_vec = typename FloatVec<Vec>::Type; using A_vec = typename FloatVec<Vec>::Type;
// Compute the parallel products for Q*K^T (treat vector lanes separately). // Compute the parallel products for Q*K^T (treat vector lanes separately).
@ -45,12 +46,12 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
return qk; return qk;
} }
template<typename T, int THREAD_GROUP_SIZE> template <typename T, int THREAD_GROUP_SIZE>
struct Qk_dot { struct Qk_dot {
template<typename Vec, int N> template <typename Vec, int N>
static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
return qk_dot_<THREAD_GROUP_SIZE>(q, k); return qk_dot_<THREAD_GROUP_SIZE>(q, k);
} }
}; };
} // namespace vllm } // namespace vllm

View File

@ -1,6 +1,8 @@
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
@ -28,8 +30,8 @@
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
typedef __hip_bfloat162 __nv_bfloat162; typedef __hip_bfloat162 __nv_bfloat162;
typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16 __nv_bfloat16;
#endif #endif
#include <stdint.h> #include <stdint.h>
@ -50,37 +52,37 @@ struct bf16_8_t {
}; };
// BF16 vector types for Q, K, V. // BF16 vector types for Q, K, V.
template<> template <>
struct Vec<__nv_bfloat16, 1> { struct Vec<__nv_bfloat16, 1> {
using Type = __nv_bfloat16; using Type = __nv_bfloat16;
}; };
template<> template <>
struct Vec<__nv_bfloat16, 2> { struct Vec<__nv_bfloat16, 2> {
using Type = __nv_bfloat162; using Type = __nv_bfloat162;
}; };
template<> template <>
struct Vec<__nv_bfloat16, 4> { struct Vec<__nv_bfloat16, 4> {
using Type = bf16_4_t; using Type = bf16_4_t;
}; };
template<> template <>
struct Vec<__nv_bfloat16, 8> { struct Vec<__nv_bfloat16, 8> {
using Type = bf16_8_t; using Type = bf16_8_t;
}; };
// FP32 accumulator vector types corresponding to Vec. // FP32 accumulator vector types corresponding to Vec.
template<> template <>
struct FloatVec<__nv_bfloat16> { struct FloatVec<__nv_bfloat16> {
using Type = float; using Type = float;
}; };
template<> template <>
struct FloatVec<__nv_bfloat162> { struct FloatVec<__nv_bfloat162> {
using Type = float2; using Type = float2;
}; };
template<> template <>
struct FloatVec<bf16_4_t> { struct FloatVec<bf16_4_t> {
using Type = Float4_; using Type = Float4_;
}; };
template<> template <>
struct FloatVec<bf16_8_t> { struct FloatVec<bf16_8_t> {
using Type = Float8_; using Type = Float8_;
}; };
@ -108,9 +110,9 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
assert(false); assert(false);
#else #else
#ifndef USE_ROCM #ifndef USE_ROCM
return a + b; return a + b;
#else #else
return __hadd(a, b); return __hadd(a, b);
#endif #endif
#endif #endif
} }
@ -161,7 +163,7 @@ inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
} }
// Vector multiplication. // Vector multiplication.
template<> template <>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false); assert(false);
@ -170,7 +172,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#endif #endif
} }
template<> template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false); assert(false);
@ -179,12 +181,12 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#endif #endif
} }
template<> template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) { inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
} }
template<> template <>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
bf16_4_t c; bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
@ -192,7 +194,7 @@ inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
return c; return c;
} }
template<> template <>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a); __nv_bfloat162 s = bf162bf162(a);
bf16_4_t c; bf16_4_t c;
@ -201,7 +203,7 @@ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
return c; return c;
} }
template<> template <>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
bf16_8_t c; bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
@ -211,7 +213,7 @@ inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
return c; return c;
} }
template<> template <>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a); __nv_bfloat162 s = bf162bf162(a);
bf16_8_t c; bf16_8_t c;
@ -222,26 +224,26 @@ inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
return c; return c;
} }
template<> template <>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) { inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
float fa = __bfloat162float(a); float fa = __bfloat162float(a);
float fb = __bfloat162float(b); float fb = __bfloat162float(b);
return fa * fb; return fa * fb;
} }
template<> template <>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) { inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
float2 fa = bf1622float2(a); float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b); float2 fb = bf1622float2(b);
return mul<float2, float2, float2>(fa, fb); return mul<float2, float2, float2>(fa, fb);
} }
template<> template <>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) { inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
} }
template<> template <>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
Float4_ fc; Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
@ -249,7 +251,7 @@ inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a); __nv_bfloat162 s = bf162bf162(a);
Float4_ fc; Float4_ fc;
@ -258,7 +260,7 @@ inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
Float8_ fc; Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
@ -268,7 +270,7 @@ inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a); __nv_bfloat162 s = bf162bf162(a);
Float8_ fc; Float8_ fc;
@ -280,7 +282,8 @@ inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
} }
// Vector fused multiply-add. // Vector fused multiply-add.
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
__nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false); assert(false);
#else #else
@ -288,7 +291,8 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bf
#endif #endif
} }
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) { inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
__nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false); assert(false);
#else #else
@ -379,23 +383,23 @@ inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
} }
// Vector sum. // Vector sum.
template<> template <>
inline __device__ float sum(__nv_bfloat16 v) { inline __device__ float sum(__nv_bfloat16 v) {
return __bfloat162float(v); return __bfloat162float(v);
} }
template<> template <>
inline __device__ float sum(__nv_bfloat162 v) { inline __device__ float sum(__nv_bfloat162 v) {
float2 vf = bf1622float2(v); float2 vf = bf1622float2(v);
return vf.x + vf.y; return vf.x + vf.y;
} }
template<> template <>
inline __device__ float sum(bf16_4_t v) { inline __device__ float sum(bf16_4_t v) {
return sum(v.x) + sum(v.y); return sum(v.x) + sum(v.y);
} }
template<> template <>
inline __device__ float sum(bf16_8_t v) { inline __device__ float sum(bf16_8_t v) {
return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
} }
@ -448,4 +452,4 @@ inline __device__ void zero(__nv_bfloat16& dst) {
#endif #endif
} }
} // namespace vllm } // namespace vllm

View File

@ -1,6 +1,8 @@
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
@ -30,37 +32,37 @@
namespace vllm { namespace vllm {
// FP16 vector types for Q, K, V. // FP16 vector types for Q, K, V.
template<> template <>
struct Vec<uint16_t, 1> { struct Vec<uint16_t, 1> {
using Type = uint16_t; using Type = uint16_t;
}; };
template<> template <>
struct Vec<uint16_t, 2> { struct Vec<uint16_t, 2> {
using Type = uint32_t; using Type = uint32_t;
}; };
template<> template <>
struct Vec<uint16_t, 4> { struct Vec<uint16_t, 4> {
using Type = uint2; using Type = uint2;
}; };
template<> template <>
struct Vec<uint16_t, 8> { struct Vec<uint16_t, 8> {
using Type = uint4; using Type = uint4;
}; };
// FP32 accumulator vector types corresponding to Vec. // FP32 accumulator vector types corresponding to Vec.
template<> template <>
struct FloatVec<uint16_t> { struct FloatVec<uint16_t> {
using Type = float; using Type = float;
}; };
template<> template <>
struct FloatVec<uint32_t> { struct FloatVec<uint32_t> {
using Type = float2; using Type = float2;
}; };
template<> template <>
struct FloatVec<uint2> { struct FloatVec<uint2> {
using Type = Float4_; using Type = Float4_;
}; };
template<> template <>
struct FloatVec<uint4> { struct FloatVec<uint4> {
using Type = Float8_; using Type = Float8_;
}; };
@ -73,8 +75,8 @@ inline __device__ uint32_t h0_h0(uint16_t a) {
return b; return b;
#else #else
union { union {
uint32_t u32; uint32_t u32;
uint16_t u16[2]; uint16_t u16[2];
} tmp; } tmp;
tmp.u16[0] = a; tmp.u16[0] = a;
tmp.u16[1] = a; tmp.u16[1] = a;
@ -130,10 +132,12 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
} tmp; } tmp;
#ifndef USE_ROCM #ifndef USE_ROCM
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n"
: "=r"(tmp.u32)
: "f"(f.y), "f"(f.x));
#else #else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif #endif
#else #else
tmp.u16[0] = float_to_half(f.x); tmp.u16[0] = float_to_half(f.x);
@ -201,7 +205,7 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
} }
// Vector multiplication. // Vector multiplication.
template<> template <>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) { inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
uint16_t c; uint16_t c;
#ifndef USE_ROCM #ifndef USE_ROCM
@ -212,7 +216,7 @@ inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
return c; return c;
} }
template<> template <>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) { inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
uint32_t c; uint32_t c;
#ifndef USE_ROCM #ifndef USE_ROCM
@ -223,12 +227,12 @@ inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
return c; return c;
} }
template<> template <>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) { inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b); return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
} }
template<> template <>
inline __device__ uint2 mul(uint2 a, uint2 b) { inline __device__ uint2 mul(uint2 a, uint2 b) {
uint2 c; uint2 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x); c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
@ -236,7 +240,7 @@ inline __device__ uint2 mul(uint2 a, uint2 b) {
return c; return c;
} }
template<> template <>
inline __device__ uint2 mul(uint16_t a, uint2 b) { inline __device__ uint2 mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a); uint32_t s = h0_h0(a);
uint2 c; uint2 c;
@ -245,7 +249,7 @@ inline __device__ uint2 mul(uint16_t a, uint2 b) {
return c; return c;
} }
template<> template <>
inline __device__ uint4 mul(uint4 a, uint4 b) { inline __device__ uint4 mul(uint4 a, uint4 b) {
uint4 c; uint4 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x); c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
@ -255,7 +259,7 @@ inline __device__ uint4 mul(uint4 a, uint4 b) {
return c; return c;
} }
template<> template <>
inline __device__ uint4 mul(uint16_t a, uint4 b) { inline __device__ uint4 mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a); uint32_t s = h0_h0(a);
uint4 c; uint4 c;
@ -266,26 +270,26 @@ inline __device__ uint4 mul(uint16_t a, uint4 b) {
return c; return c;
} }
template<> template <>
inline __device__ float mul(uint16_t a, uint16_t b) { inline __device__ float mul(uint16_t a, uint16_t b) {
float fa = half_to_float(a); float fa = half_to_float(a);
float fb = half_to_float(b); float fb = half_to_float(b);
return fa * fb; return fa * fb;
} }
template<> template <>
inline __device__ float2 mul(uint32_t a, uint32_t b) { inline __device__ float2 mul(uint32_t a, uint32_t b) {
float2 fa = half2_to_float2(a); float2 fa = half2_to_float2(a);
float2 fb = half2_to_float2(b); float2 fb = half2_to_float2(b);
return mul<float2, float2, float2>(fa, fb); return mul<float2, float2, float2>(fa, fb);
} }
template<> template <>
inline __device__ float2 mul(uint16_t a, uint32_t b) { inline __device__ float2 mul(uint16_t a, uint32_t b) {
return mul<float2, uint32_t, uint32_t>(h0_h0(a), b); return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
} }
template<> template <>
inline __device__ Float4_ mul(uint2 a, uint2 b) { inline __device__ Float4_ mul(uint2 a, uint2 b) {
Float4_ fc; Float4_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x); fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
@ -293,7 +297,7 @@ inline __device__ Float4_ mul(uint2 a, uint2 b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float4_ mul(uint16_t a, uint2 b) { inline __device__ Float4_ mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a); uint32_t s = h0_h0(a);
Float4_ fc; Float4_ fc;
@ -302,7 +306,7 @@ inline __device__ Float4_ mul(uint16_t a, uint2 b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float8_ mul(uint4 a, uint4 b) { inline __device__ Float8_ mul(uint4 a, uint4 b) {
Float8_ fc; Float8_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x); fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
@ -312,7 +316,7 @@ inline __device__ Float8_ mul(uint4 a, uint4 b) {
return fc; return fc;
} }
template<> template <>
inline __device__ Float8_ mul(uint16_t a, uint4 b) { inline __device__ Float8_ mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a); uint32_t s = h0_h0(a);
Float8_ fc; Float8_ fc;
@ -327,9 +331,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d; uint32_t d;
#ifndef USE_ROCM #ifndef USE_ROCM
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(d)
: "r"(a), "r"(b), "r"(c));
#else #else
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n"
: "=v"(d)
: "v"(a), "v"(b), "v"(c));
#endif #endif
return d; return d;
} }
@ -423,24 +431,24 @@ inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
} }
// Vector sum. // Vector sum.
template<> template <>
inline __device__ float sum(uint16_t v) { inline __device__ float sum(uint16_t v) {
return half_to_float(v); return half_to_float(v);
} }
template<> template <>
inline __device__ float sum(uint32_t v) { inline __device__ float sum(uint32_t v) {
float2 tmp = half2_to_float2(v); float2 tmp = half2_to_float2(v);
return tmp.x + tmp.y; return tmp.x + tmp.y;
} }
template<> template <>
inline __device__ float sum(uint2 v) { inline __device__ float sum(uint2 v) {
uint32_t c = add(v.x, v.y); uint32_t c = add(v.x, v.y);
return sum(c); return sum(c);
} }
template<> template <>
inline __device__ float sum(uint4 v) { inline __device__ float sum(uint4 v) {
uint32_t c = add(v.x, v.y); uint32_t c = add(v.x, v.y);
c = add(c, v.z); c = add(c, v.z);
@ -470,13 +478,9 @@ inline __device__ void from_float(uint4& dst, Float8_ src) {
} }
// From float16 to float32. // From float16 to float32.
inline __device__ float to_float(uint16_t u) { inline __device__ float to_float(uint16_t u) { return half_to_float(u); }
return half_to_float(u);
}
inline __device__ float2 to_float(uint32_t u) { inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); }
return half2_to_float2(u);
}
inline __device__ Float4_ to_float(uint2 u) { inline __device__ Float4_ to_float(uint2 u) {
Float4_ tmp; Float4_ tmp;
@ -495,8 +499,6 @@ inline __device__ Float8_ to_float(uint4 u) {
} }
// Zero-out a variable. // Zero-out a variable.
inline __device__ void zero(uint16_t& dst) { inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }
dst = uint16_t(0);
}
} // namespace vllm } // namespace vllm

View File

@ -1,6 +1,8 @@
/* /*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp * Adapted from
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team. * Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* *
@ -38,37 +40,35 @@ struct Float8_ {
}; };
// FP32 vector types for Q, K, V. // FP32 vector types for Q, K, V.
template<> template <>
struct Vec<float, 1> { struct Vec<float, 1> {
using Type = float; using Type = float;
}; };
template<> template <>
struct Vec<float, 2> { struct Vec<float, 2> {
using Type = float2; using Type = float2;
}; };
template<> template <>
struct Vec<float, 4> { struct Vec<float, 4> {
using Type = float4; using Type = float4;
}; };
// FP32 accumulator vector types corresponding to Vec. // FP32 accumulator vector types corresponding to Vec.
template<> template <>
struct FloatVec<float> { struct FloatVec<float> {
using Type = float; using Type = float;
}; };
template<> template <>
struct FloatVec<float2> { struct FloatVec<float2> {
using Type = float2; using Type = float2;
}; };
template<> template <>
struct FloatVec<float4> { struct FloatVec<float4> {
using Type = float4; using Type = float4;
}; };
// Vector addition. // Vector addition.
inline __device__ float add(float a, float b) { inline __device__ float add(float a, float b) { return a + b; }
return a + b;
}
inline __device__ float2 add(float2 a, float2 b) { inline __device__ float2 add(float2 a, float2 b) {
float2 c; float2 c;
@ -87,12 +87,12 @@ inline __device__ float4 add(float4 a, float4 b) {
} }
// Vector multiplication. // Vector multiplication.
template<> template <>
inline __device__ float mul<float, float>(float a, float b) { inline __device__ float mul<float, float>(float a, float b) {
return a * b; return a * b;
} }
template<> template <>
inline __device__ float2 mul(float2 a, float2 b) { inline __device__ float2 mul(float2 a, float2 b) {
float2 c; float2 c;
c.x = a.x * b.x; c.x = a.x * b.x;
@ -100,7 +100,7 @@ inline __device__ float2 mul(float2 a, float2 b) {
return c; return c;
} }
template<> template <>
inline __device__ float2 mul(float a, float2 b) { inline __device__ float2 mul(float a, float2 b) {
float2 c; float2 c;
c.x = a * b.x; c.x = a * b.x;
@ -108,7 +108,7 @@ inline __device__ float2 mul(float a, float2 b) {
return c; return c;
} }
template<> template <>
inline __device__ float4 mul(float4 a, float4 b) { inline __device__ float4 mul(float4 a, float4 b) {
float4 c; float4 c;
c.x = a.x * b.x; c.x = a.x * b.x;
@ -118,7 +118,7 @@ inline __device__ float4 mul(float4 a, float4 b) {
return c; return c;
} }
template<> template <>
inline __device__ float4 mul(float a, float4 b) { inline __device__ float4 mul(float a, float4 b) {
float4 c; float4 c;
c.x = a * b.x; c.x = a * b.x;
@ -129,9 +129,7 @@ inline __device__ float4 mul(float a, float4 b) {
} }
// Vector fused multiply-add. // Vector fused multiply-add.
inline __device__ float fma(float a, float b, float c) { inline __device__ float fma(float a, float b, float c) { return a * b + c; }
return a * b + c;
}
inline __device__ float2 fma(float2 a, float2 b, float2 c) { inline __device__ float2 fma(float2 a, float2 b, float2 c) {
float2 d; float2 d;
@ -182,35 +180,33 @@ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
} }
// Vector sum. // Vector sum.
template<> template <>
inline __device__ float sum(float v) { inline __device__ float sum(float v) {
return v; return v;
} }
template<> template <>
inline __device__ float sum(float2 v) { inline __device__ float sum(float2 v) {
return v.x + v.y; return v.x + v.y;
} }
template<> template <>
inline __device__ float sum(float4 v) { inline __device__ float sum(float4 v) {
return v.x + v.y + v.z + v.w; return v.x + v.y + v.z + v.w;
} }
template<> template <>
inline __device__ float sum(Float4_ v) { inline __device__ float sum(Float4_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y; return v.x.x + v.x.y + v.y.x + v.y.y;
} }
template<> template <>
inline __device__ float sum(Float8_ v) { inline __device__ float sum(Float8_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
} }
// Vector dot product. // Vector dot product.
inline __device__ float dot(float a, float b) { inline __device__ float dot(float a, float b) { return a * b; }
return a * b;
}
inline __device__ float dot(float2 a, float2 b) { inline __device__ float dot(float2 a, float2 b) {
float2 c = mul<float2, float2, float2>(a, b); float2 c = mul<float2, float2, float2>(a, b);
@ -232,42 +228,24 @@ inline __device__ float dot(Float8_ a, Float8_ b) {
} }
// From float to float. // From float to float.
inline __device__ void from_float(float& dst, float src) { inline __device__ void from_float(float& dst, float src) { dst = src; }
dst = src;
}
inline __device__ void from_float(float2& dst, float2 src) { inline __device__ void from_float(float2& dst, float2 src) { dst = src; }
dst = src;
}
inline __device__ void from_float(float4& dst, float4 src) { inline __device__ void from_float(float4& dst, float4 src) { dst = src; }
dst = src;
}
// From float to float. // From float to float.
inline __device__ float to_float(float u) { inline __device__ float to_float(float u) { return u; }
return u;
}
inline __device__ float2 to_float(float2 u) { inline __device__ float2 to_float(float2 u) { return u; }
return u;
}
inline __device__ float4 to_float(float4 u) { inline __device__ float4 to_float(float4 u) { return u; }
return u;
}
inline __device__ Float4_ to_float(Float4_ u) { inline __device__ Float4_ to_float(Float4_ u) { return u; }
return u;
}
inline __device__ Float8_ to_float(Float8_ u) { inline __device__ Float8_ to_float(Float8_ u) { return u; }
return u;
}
// Zero-out a variable. // Zero-out a variable.
inline __device__ void zero(float& dst) { inline __device__ void zero(float& dst) { dst = 0.f; }
dst = 0.f;
}
} // namespace vllm } // namespace vllm

View File

@ -0,0 +1,41 @@
#pragma once
#include "attention_generic.cuh"
#include <stdint.h>
#ifdef ENABLE_FP8
#ifndef USE_ROCM
#include <cuda_fp8.h>
#endif // USE_ROCM
#endif // ENABLE_FP8
namespace vllm {
enum class Fp8KVCacheDataType {
kAuto = 0,
kFp8E4M3 = 1,
kFp8E5M2 = 2,
};
// fp8 vector types for quantization of kv cache
template <>
struct Vec<uint8_t, 1> {
using Type = uint8_t;
};
template <>
struct Vec<uint8_t, 2> {
using Type = uint16_t;
};
template <>
struct Vec<uint8_t, 4> {
using Type = uint32_t;
};
template <>
struct Vec<uint8_t, 8> {
using Type = uint2;
};
} // namespace vllm

View File

@ -1,35 +0,0 @@
#pragma once
#include "attention_generic.cuh"
#include <stdint.h>
#ifdef ENABLE_FP8_E5M2
#include <cuda_fp8.h>
#endif
namespace vllm {
#ifdef ENABLE_FP8_E5M2
// fp8 vector types for quantization of kv cache
template<>
struct Vec<uint8_t, 1> {
using Type = uint8_t;
};
template<>
struct Vec<uint8_t, 2> {
using Type = uint16_t;
};
template<>
struct Vec<uint8_t, 4> {
using Type = uint32_t;
};
template<>
struct Vec<uint8_t, 8> {
using Type = uint2;
};
#endif // ENABLE_FP8_E5M2
} // namespace vllm

View File

@ -5,25 +5,24 @@
#include <map> #include <map>
#include <vector> #include <vector>
void swap_blocks( void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
torch::Tensor& src, const torch::Tensor& block_mapping);
torch::Tensor& dst,
const std::map<int64_t, int64_t>& block_mapping);
void copy_blocks( void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& key_caches, std::vector<torch::Tensor>& value_caches,
std::vector<torch::Tensor>& value_caches, const torch::Tensor& block_mapping);
const std::map<int64_t, std::vector<int64_t>>& block_mapping);
void reshape_and_cache( void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key, torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& value, torch::Tensor& slot_mapping,
torch::Tensor& key_cache, const std::string& kv_cache_dtype, const float kv_scale);
torch::Tensor& value_cache,
torch::Tensor& slot_mapping, void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
const std::string& kv_cache_dtype); torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype);
// Just for unittest // Just for unittest
void convert_fp8_e5m2( void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
torch::Tensor& src_cache, const float scale, const std::string& kv_cache_dtype);
torch::Tensor& dst_cache);

View File

@ -4,8 +4,11 @@
#include "cuda_compat.h" #include "cuda_compat.h"
#include "dispatch_utils.h" #include "dispatch_utils.h"
#ifdef ENABLE_FP8_E5M2
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh" #ifdef USE_ROCM
#include "quantization/fp8/amd/quant_utils.cuh"
#else
#include "quantization/fp8/nvidia/quant_utils.cuh"
#endif #endif
#include <algorithm> #include <algorithm>
@ -15,20 +18,17 @@
#ifdef USE_ROCM #ifdef USE_ROCM
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16 __nv_bfloat16;
#endif #endif
void swap_blocks( void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
torch::Tensor& src, const torch::Tensor& block_mapping) {
torch::Tensor& dst,
const std::map<int64_t, int64_t>& block_mapping) {
torch::Device src_device = src.device(); torch::Device src_device = src.device();
torch::Device dst_device = dst.device(); torch::Device dst_device = dst.device();
cudaMemcpyKind memcpy_type; cudaMemcpyKind memcpy_type;
if (src_device.is_cuda() && dst_device.is_cuda()) { if (src_device.is_cuda() && dst_device.is_cuda()) {
TORCH_CHECK( TORCH_CHECK(src_device.index() == dst_device.index(),
src_device.index() == dst_device.index(), "src and dst must be on the same GPU");
"src and dst must be on the same GPU");
memcpy_type = cudaMemcpyDeviceToDevice; memcpy_type = cudaMemcpyDeviceToDevice;
} else if (src_device.is_cuda() && dst_device.is_cpu()) { } else if (src_device.is_cuda() && dst_device.is_cpu()) {
memcpy_type = cudaMemcpyDeviceToHost; memcpy_type = cudaMemcpyDeviceToHost;
@ -38,41 +38,44 @@ void swap_blocks(
TORCH_CHECK(false, "Invalid device combination"); TORCH_CHECK(false, "Invalid device combination");
} }
char *src_ptr = static_cast<char*>(src.data_ptr()); // NOTE(youkaichao): keep in mind that `block_mapping` should be
char *dst_ptr = static_cast<char*>(dst.data_ptr()); // a cpu tensor, otherwise every `item` call will require a gpu-cpu
// synchronization.
TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
char* src_ptr = static_cast<char*>(src.data_ptr());
char* dst_ptr = static_cast<char*>(dst.data_ptr());
const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device); const at::cuda::OptionalCUDAGuard device_guard(
src_device.is_cuda() ? src_device : dst_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// NOTE(woosuk): This can be slow if the number of blocks is large. // NOTE(woosuk): This can be slow if the number of blocks is large.
for (const auto& pair : block_mapping) { const int64_t num_blocks = block_mapping.size(0);
int64_t src_block_number = pair.first; for (size_t i = 0; i < num_blocks; i++) {
int64_t dst_block_number = pair.second; int64_t src_block_number = block_mapping[i][0].item<int64_t>();
int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
int64_t src_offset = src_block_number * block_size_in_bytes; int64_t src_offset = src_block_number * block_size_in_bytes;
int64_t dst_offset = dst_block_number * block_size_in_bytes; int64_t dst_offset = dst_block_number * block_size_in_bytes;
cudaMemcpyAsync( cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
dst_ptr + dst_offset, block_size_in_bytes, memcpy_type, stream);
src_ptr + src_offset,
block_size_in_bytes,
memcpy_type,
stream);
} }
} }
namespace vllm { namespace vllm {
// Grid: (num_layers, num_pairs) // Grid: (num_layers, num_pairs)
template<typename scalar_t> template <typename scalar_t>
__global__ void copy_blocks_kernel( __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs,
int64_t* key_cache_ptrs, int64_t* value_cache_ptrs,
int64_t* value_cache_ptrs, const int64_t* __restrict__ block_mapping,
const int64_t* __restrict__ block_mapping, const int numel_per_block) {
const int numel_per_block) {
const int layer_idx = blockIdx.x; const int layer_idx = blockIdx.x;
const int pair_idx = blockIdx.y; const int pair_idx = blockIdx.y;
scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]); scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]); scalar_t* value_cache =
reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
int64_t src_block_number = block_mapping[2 * pair_idx]; int64_t src_block_number = block_mapping[2 * pair_idx];
int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
@ -90,12 +93,11 @@ __global__ void copy_blocks_kernel(
} }
} }
} // namespace vllm } // namespace vllm
void copy_blocks( void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& key_caches, std::vector<torch::Tensor>& value_caches,
std::vector<torch::Tensor>& value_caches, const torch::Tensor& block_mapping) {
const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
int num_layers = key_caches.size(); int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size()); TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) { if (num_layers == 0) {
@ -109,29 +111,23 @@ void copy_blocks(
int64_t key_cache_ptrs[num_layers]; int64_t key_cache_ptrs[num_layers];
int64_t value_cache_ptrs[num_layers]; int64_t value_cache_ptrs[num_layers];
for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr()); key_cache_ptrs[layer_idx] =
value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr()); reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
value_cache_ptrs[layer_idx] =
reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
} }
// Create block mapping array.
std::vector<int64_t> block_mapping_vec; // block_mapping is a 2D tensor with shape (num_pairs, 2).
for (const auto& pair : block_mapping) { int num_pairs = block_mapping.size(0);
int64_t src_block_number = pair.first;
for (int64_t dst_block_number : pair.second) {
block_mapping_vec.push_back(src_block_number);
block_mapping_vec.push_back(dst_block_number);
}
}
int64_t* block_mapping_array = block_mapping_vec.data();
int num_pairs = block_mapping_vec.size() / 2;
// Move the data structures to the GPU. // Move the data structures to the GPU.
// NOTE: This synchronizes the CPU and GPU. // NOTE: This synchronizes the CPU and GPU.
torch::Tensor key_cache_ptrs_tensor = torch::from_blob( torch::Tensor key_cache_ptrs_tensor =
key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64)
torch::Tensor value_cache_ptrs_tensor = torch::from_blob( .to(cache_device);
value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); torch::Tensor value_cache_ptrs_tensor =
torch::Tensor block_mapping_tensor = torch::from_blob( torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64)
block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device); .to(cache_device);
// Launch the kernel. // Launch the kernel.
const int numel_per_block = key_caches[0][0].numel(); const int numel_per_block = key_caches[0][0].numel();
@ -140,30 +136,28 @@ void copy_blocks(
const at::cuda::OptionalCUDAGuard device_guard(cache_device); const at::cuda::OptionalCUDAGuard device_guard(cache_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>( vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(), key_cache_ptrs_tensor.data_ptr<int64_t>(),
value_cache_ptrs_tensor.data_ptr<int64_t>(), value_cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping_tensor.data_ptr<int64_t>(), block_mapping.data_ptr<int64_t>(), numel_per_block);
numel_per_block); }));
}));
} }
namespace vllm { namespace vllm {
template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache> template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel( __global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x,
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] // block_size, x]
const int64_t* __restrict__ slot_mapping, // [num_tokens] cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size,
const int key_stride, // block_size]
const int value_stride, const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int num_heads, const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int head_size, const int block_size, const int x,
const int block_size, const float kv_scale) {
const int x) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx]; const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) { if (slot_idx < 0) {
@ -184,55 +178,84 @@ __global__ void reshape_and_cache_kernel(
const int x_idx = head_offset / x; const int x_idx = head_offset / x;
const int x_offset = head_offset % x; const int x_offset = head_offset % x;
const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x const int64_t tgt_key_idx =
+ head_idx * (head_size / x) * block_size * x block_idx * num_heads * (head_size / x) * block_size * x +
+ x_idx * block_size * x head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
+ block_offset * x block_offset * x + x_offset;
+ x_offset; const int64_t tgt_value_idx =
const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size block_idx * num_heads * head_size * block_size +
+ head_idx * head_size * block_size head_idx * head_size * block_size + head_offset * block_size +
+ head_offset * block_size block_offset;
+ block_offset;
scalar_t tgt_key = key[src_key_idx]; scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx]; scalar_t tgt_value = value[src_value_idx];
if constexpr (is_fp8_e5m2_kv_cache) { if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
#ifdef ENABLE_FP8_E5M2
key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
#else
assert(false);
#endif
} else {
key_cache[tgt_key_idx] = tgt_key; key_cache[tgt_key_idx] = tgt_key;
value_cache[tgt_value_idx] = tgt_value; value_cache[tgt_value_idx] = tgt_value;
} else {
key_cache[tgt_key_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
value_cache[tgt_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
} }
} }
} }
} // namespace vllm template <typename scalar_t>
__global__ void reshape_and_cache_flash_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads,
// head_size]
scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads,
// head_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, const int key_stride, const int value_stride,
const int num_heads, const int head_size, const int block_size) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int64_t src_key_idx = token_idx * key_stride + i;
const int64_t src_value_idx = token_idx * value_stride + i;
const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int64_t tgt_value_idx = block_idx * block_stride +
block_offset * num_heads * head_size +
head_idx * head_size + head_offset;
k_cache[tgt_value_idx] = key[src_key_idx];
v_cache[tgt_value_idx] = value[src_value_idx];
}
}
} // namespace vllm
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \ // KV_T is the stored data type of kv-cache.
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \ // CACHE_T is the data type of key and value tensors.
reinterpret_cast<KV_T*>(key.data_ptr()), \ // KV_DTYPE is the real data type of kv-cache.
reinterpret_cast<KV_T*>(value.data_ptr()), \ #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \ vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \ <<<grid, block, 0, stream>>>( \
slot_mapping.data_ptr<int64_t>(), \ reinterpret_cast<KV_T*>(key.data_ptr()), \
key_stride, \ reinterpret_cast<KV_T*>(value.data_ptr()), \
value_stride, \ reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
num_heads, \ reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
head_size, \ slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
block_size, \ num_heads, head_size, block_size, x, kv_scale);
x);
void reshape_and_cache( void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size] torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size] torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor&
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& slot_mapping, // [num_tokens] torch::Tensor&
const std::string& kv_cache_dtype) value_cache, // [num_blocks, num_heads, head_size, block_size]
{ torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype, const float kv_scale) {
int num_tokens = key.size(0); int num_tokens = key.size(0);
int num_heads = key.size(1); int num_heads = key.size(1);
int head_size = key.size(2); int head_size = key.size(2);
@ -246,57 +269,80 @@ void reshape_and_cache(
dim3 block(std::min(num_heads * head_size, 512)); dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (kv_cache_dtype == "auto") {
if (key.dtype() == at::ScalarType::Float) { DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
CALL_RESHAPE_AND_CACHE(float, float, false); CALL_RESHAPE_AND_CACHE)
} else if (key.dtype() == at::ScalarType::Half) { }
CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
} else if (key.dtype() == at::ScalarType::BFloat16) { void reshape_and_cache_flash(
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false); torch::Tensor& key, // [num_tokens, num_heads, head_size]
} torch::Tensor& value, // [num_tokens, num_heads, head_size]
} else if (kv_cache_dtype == "fp8_e5m2") { torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size]
if (key.dtype() == at::ScalarType::Float) { torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size]
CALL_RESHAPE_AND_CACHE(float, uint8_t, true); torch::Tensor& slot_mapping, // [num_tokens]
} else if (key.dtype() == at::ScalarType::Half) { const std::string& kv_cache_dtype) {
CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true); // FIXME: only support auto datatype, does not support fp8
} else if (key.dtype() == at::ScalarType::BFloat16) { if (kv_cache_dtype != "auto") {
CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
}
} else {
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
} }
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = k_cache.size(1);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
int block_stride = k_cache.stride(0);
TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(), "reshape_and_cache_flash", [&] {
vllm::reshape_and_cache_flash_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
value_stride, num_heads, head_size, block_size);
});
} }
namespace vllm { namespace vllm {
template<typename Tout, typename Tin> template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__global__ void convert_fp8_e5m2_kernel( __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
const Tin* __restrict__ src_cache, Tout* __restrict__ dst_cache,
Tout* __restrict__ dst_cache, const float kv_scale,
const int64_t block_stride) { const int64_t block_stride) {
const int64_t block_idx = blockIdx.x; const int64_t block_idx = blockIdx.x;
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
int64_t idx = block_idx * block_stride + i; int64_t idx = block_idx * block_stride + i;
#ifdef ENABLE_FP8_E5M2 dst_cache[idx] =
dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]); fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
#else
assert(false);
#endif
} }
} }
} // namespace vllm } // namespace vllm
#define CALL_CONVERT_FP8_E5M2(Tout, Tin) \ #define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>( \ vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<Tin*>(src_cache.data_ptr()), \ reinterpret_cast<Tin*>(src_cache.data_ptr()), \
reinterpret_cast<Tout*>(dst_cache.data_ptr()), \ reinterpret_cast<Tout*>(dst_cache.data_ptr()), kv_scale, block_stride);
block_stride);
// Only for testing.
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const float kv_scale, const std::string& kv_cache_dtype) {
torch::Device src_device = src_cache.device();
torch::Device dst_device = dst_cache.device();
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
TORCH_CHECK(src_device.index() == dst_device.index(),
"src and dst must be on the same GPU");
at::cuda::OptionalCUDAGuard device_guard(src_device);
void convert_fp8_e5m2(
torch::Tensor& src_cache,
torch::Tensor& dst_cache)
{
int64_t num_blocks = src_cache.size(0); int64_t num_blocks = src_cache.size(0);
int64_t block_stride = src_cache.stride(0); int64_t block_stride = src_cache.stride(0);
@ -304,17 +350,37 @@ void convert_fp8_e5m2(
dim3 block(std::min(block_stride, int64_t(512))); dim3 block(std::min(block_stride, int64_t(512)));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if (src_cache.dtype() == at::ScalarType::Float) { if (kv_cache_dtype == "auto") {
CALL_CONVERT_FP8_E5M2(uint8_t, float); if (src_cache.dtype() == at::ScalarType::Float) {
} else if (src_cache.dtype() == at::ScalarType::Half) { CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto);
CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t); } else if (src_cache.dtype() == at::ScalarType::Half) {
} else if (src_cache.dtype() == at::ScalarType::BFloat16) { CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);
CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16); } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
} else if (dst_cache.dtype() == at::ScalarType::Float) { CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);
CALL_CONVERT_FP8_E5M2(float, uint8_t); } else if (dst_cache.dtype() == at::ScalarType::Float) {
} else if (dst_cache.dtype() == at::ScalarType::Half) { CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t); } else if (dst_cache.dtype() == at::ScalarType::Half) {
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) { CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t); } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto);
}
} else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (src_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (src_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(uint8_t, __nv_bfloat16,
vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::Half) {
CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3);
}
} else {
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
} }
} }

View File

@ -1,10 +1,10 @@
#include "cpu_types.hpp" #include "cpu_types.hpp"
namespace { namespace {
template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8 &), template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8&),
bool is_gated> bool is_gated>
void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input, void activation_kernel(int num_tokens, int d, scalar_t* __restrict__ input,
scalar_t *__restrict__ output) { scalar_t* __restrict__ output) {
using scalar_vec_t = vec_op::vec_t<scalar_t>; using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
@ -34,13 +34,13 @@ void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input,
} }
} }
FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 zeros(0.0); const vec_op::FP32Vec8 zeros(0.0);
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
return x / (ones + (zeros - x).exp()); return x / (ones + (zeros - x).exp());
} }
FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(0.79788456f); const vec_op::FP32Vec8 w1(0.79788456f);
const vec_op::FP32Vec8 w2(0.044715f); const vec_op::FP32Vec8 w2(0.044715f);
@ -50,7 +50,7 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) {
return w3 * x * (ones + t); return w3 * x * (ones + t);
} }
FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(0.79788456f); const vec_op::FP32Vec8 w1(0.79788456f);
const vec_op::FP32Vec8 w2(0.044715f); const vec_op::FP32Vec8 w2(0.044715f);
@ -59,14 +59,14 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) {
return w3 * x * (ones + t); return w3 * x * (ones + t);
} }
FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT1_2); const vec_op::FP32Vec8 w1(M_SQRT1_2);
const vec_op::FP32Vec8 w2(0.5); const vec_op::FP32Vec8 w2(0.5);
return x * w2 * (ones + (x * w1).er()); return x * w2 * (ones + (x * w1).er());
} }
FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) { FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) {
const vec_op::FP32Vec8 ones(1.0); const vec_op::FP32Vec8 ones(1.0);
const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5); const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5);
const vec_op::FP32Vec8 w2(0.5); const vec_op::FP32Vec8 w2(0.5);
@ -75,40 +75,36 @@ FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) {
const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3); const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3);
return x * w2 * (ones + inner.tanh()); return x * w2 * (ones + inner.tanh());
} }
}; // namespace }; // namespace
void silu_and_mul(torch::Tensor &out, torch::Tensor &input) { void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2; int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_impl", [&] {
input.scalar_type(), "silu_and_mul_impl", [&] { CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
CPU_KERNEL_GUARD_IN(silu_and_mul_impl) activation_kernel<scalar_t, silu_act, true>(
activation_kernel<scalar_t, silu_act, true>(num_tokens, d, num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
input.data_ptr<scalar_t>(), CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
out.data_ptr<scalar_t>()); });
CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
});
} }
void gelu_and_mul(torch::Tensor &out, // [..., d] void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor &input) // [..., 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2; int d = input.size(-1) / 2;
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul_impl", [&] {
input.scalar_type(), "gelu_and_mul_impl", [&] { CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
CPU_KERNEL_GUARD_IN(gelu_and_mul_impl) activation_kernel<scalar_t, gelu_act, true>(
activation_kernel<scalar_t, gelu_act, true>(num_tokens, d, num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
input.data_ptr<scalar_t>(), CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
out.data_ptr<scalar_t>()); });
CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
});
} }
void gelu_tanh_and_mul(torch::Tensor &out, // [..., d] void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor &input) // [..., 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1) / 2; int d = input.size(-1) / 2;
@ -123,7 +119,7 @@ void gelu_tanh_and_mul(torch::Tensor &out, // [..., d]
}); });
} }
void gelu_new(torch::Tensor &out, torch::Tensor &input) { void gelu_new(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1); int d = input.size(-1);
@ -135,7 +131,7 @@ void gelu_new(torch::Tensor &out, torch::Tensor &input) {
}); });
} }
void gelu_fast(torch::Tensor &out, torch::Tensor &input) { void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
int num_tokens = input.numel() / input.size(-1); int num_tokens = input.numel() / input.size(-1);
int d = input.size(-1); int d = input.size(-1);

View File

@ -2,7 +2,8 @@
namespace { namespace {
template <typename scalar_t> struct KernelVecType { template <typename scalar_t>
struct KernelVecType {
using q_load_vec_type = void; using q_load_vec_type = void;
using q_vec_type = void; using q_vec_type = void;
using k_load_vec_type = void; using k_load_vec_type = void;
@ -11,7 +12,8 @@ template <typename scalar_t> struct KernelVecType {
using v_load_vec_type = void; using v_load_vec_type = void;
}; };
template <> struct KernelVecType<float> { template <>
struct KernelVecType<float> {
using q_load_vec_type = vec_op::FP32Vec4; using q_load_vec_type = vec_op::FP32Vec4;
using q_vec_type = vec_op::FP32Vec16; using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::FP32Vec16; using k_load_vec_type = vec_op::FP32Vec16;
@ -21,7 +23,8 @@ template <> struct KernelVecType<float> {
}; };
#ifdef __AVX512BF16__ #ifdef __AVX512BF16__
template <> struct KernelVecType<c10::BFloat16> { template <>
struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8; using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::BF16Vec32; using q_vec_type = vec_op::BF16Vec32;
using k_load_vec_type = vec_op::BF16Vec32; using k_load_vec_type = vec_op::BF16Vec32;
@ -30,7 +33,8 @@ template <> struct KernelVecType<c10::BFloat16> {
using v_load_vec_type = vec_op::BF16Vec16; using v_load_vec_type = vec_op::BF16Vec16;
}; };
#else #else
template <> struct KernelVecType<c10::BFloat16> { template <>
struct KernelVecType<c10::BFloat16> {
using q_load_vec_type = vec_op::BF16Vec8; using q_load_vec_type = vec_op::BF16Vec8;
using q_vec_type = vec_op::FP32Vec16; using q_vec_type = vec_op::FP32Vec16;
using k_load_vec_type = vec_op::BF16Vec16; using k_load_vec_type = vec_op::BF16Vec16;
@ -41,7 +45,7 @@ template <> struct KernelVecType<c10::BFloat16> {
#endif #endif
template <typename T> template <typename T>
FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size, FORCE_INLINE std::pair<T, T> reduceSoftmax(T* data, const int size,
const int capacity) { const int capacity) {
T max = data[0]; T max = data[0];
for (int i = 1; i < size; ++i) { for (int i = 1; i < size; ++i) {
@ -67,14 +71,15 @@ FORCE_INLINE std::pair<T, T> reduceSoftmax(T *data, const int size,
} }
template <typename T> template <typename T>
FORCE_INLINE std::pair<T, T> FORCE_INLINE std::pair<T, T> reduceSoftmaxAlibi(T* data, const int size,
reduceSoftmaxAlibi(T *data, const int size, const int capacity, const int capacity,
const float alibi_slope, const int start_index, const float alibi_slope,
const int context_len) { const int start_index,
data[0] += alibi_slope * (start_index - context_len + 1); const int seq_len) {
data[0] += alibi_slope * (start_index - seq_len + 1);
T max = data[0]; T max = data[0];
for (int i = 1; i < size; ++i) { for (int i = 1; i < size; ++i) {
T qk = data[i] + alibi_slope * (start_index + i - context_len + 1); T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1);
data[i] = qk; data[i] = qk;
max = max >= qk ? max : qk; max = max >= qk ? max : qk;
} }
@ -98,7 +103,7 @@ reduceSoftmaxAlibi(T *data, const int size, const int capacity,
} }
template <typename T> template <typename T>
FORCE_INLINE void reducePartitonSoftmax(const T *max_data, T *sum_data, FORCE_INLINE void reducePartitonSoftmax(const T* max_data, T* sum_data,
const int size) { const int size) {
T max = max_data[0]; T max = max_data[0];
for (int i = 1; i < size; ++i) { for (int i = 1; i < size; ++i) {
@ -132,9 +137,9 @@ struct reduceQKBlockKernel {
static_assert(k_load_vec_type::get_elem_num() % x == 0); static_assert(k_load_vec_type::get_elem_num() % x == 0);
static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16); static_assert(q_load_vec_type::get_elem_num() * sizeof(scalar_t) == 16);
FORCE_INLINE static void call(const scalar_t *__restrict__ q, FORCE_INLINE static void call(const scalar_t* __restrict__ q,
const scalar_t *__restrict__ k_block, const scalar_t* __restrict__ k_block,
float *__restrict__ logits, float scale, float* __restrict__ logits, float scale,
const int token_num) { const int token_num) {
const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP; const int group_num = (token_num + TOKEN_PER_GROUP - 1) / TOKEN_PER_GROUP;
@ -196,8 +201,8 @@ struct reduceQKBlockKernel {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE,
int HEAD_PARTITION_SIZE, typename acc_t> int HEAD_PARTITION_SIZE, typename acc_t>
FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block, FORCE_INLINE void reduceValueBlock(const float* prob, const scalar_t* v_block,
acc_t &&acc) { acc_t&& acc) {
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type; using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
constexpr int ELEM_NUM = v_load_vec_type::get_elem_num(); constexpr int ELEM_NUM = v_load_vec_type::get_elem_num();
static_assert(BLOCK_SIZE == ELEM_NUM); static_assert(BLOCK_SIZE == ELEM_NUM);
@ -209,66 +214,65 @@ FORCE_INLINE void reduceValueBlock(const float *prob, const scalar_t *v_block,
acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec; acc[head_elem_idx] = acc[head_elem_idx] + prob_vec * fp32_v_vec;
}); });
} }
}; // namespace }; // namespace
// Paged attention v1 // Paged attention v1
namespace { namespace {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE> template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE>
struct paged_attention_v1_impl { struct paged_attention_v1_impl {
static void static void call(
call(scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x] // head_size/x, block_size, x]
const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size] // head_size, block_size]
const int num_kv_heads, const float scale, const int num_kv_heads, const float scale,
const int const int* __restrict__ block_tables, // [num_seqs,
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] // max_num_blocks_per_seq]
const int *__restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float *__restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int num_seqs, const int num_heads) { const int num_seqs, const int num_heads) {
constexpr int x = 16 / sizeof(scalar_t); constexpr int x = 16 / sizeof(scalar_t);
const int num_queries_per_kv = num_heads / num_kv_heads; const int num_queries_per_kv = num_heads / num_kv_heads;
static_assert(BLOCK_SIZE == 16); static_assert(BLOCK_SIZE == 16);
int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE; int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE;
int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0; int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0;
TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0); TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0);
const int parallel_work_item_num = omp_get_max_threads(); const int parallel_work_item_num = omp_get_max_threads();
size_t logits_bytes = size_t logits_bytes =
parallel_work_item_num * max_context_len_padded * sizeof(float); parallel_work_item_num * max_seq_len_padded * sizeof(float);
float *logits = (float *)std::aligned_alloc( float* logits = (float*)std::aligned_alloc(
64, logits_bytes); // Cacheline alignment for each context token. 64, logits_bytes); // Cacheline alignment for each context token.
// [parallel_work_item_num, max_context_len_padded] // [parallel_work_item_num, max_seq_len_padded]
#pragma omp parallel for collapse(2) schedule(dynamic, 1) #pragma omp parallel for collapse(2) schedule(dynamic, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
int context_len = context_lens[seq_idx]; int seq_len = seq_lens[seq_idx];
const int *seq_block_table = const int* seq_block_table =
block_tables + max_num_blocks_per_seq * seq_idx; block_tables + max_num_blocks_per_seq * seq_idx;
const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int64_t kv_head_idx = head_idx / num_queries_per_kv; const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t *__restrict__ q_vec_ptr = const scalar_t* __restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE; q + seq_idx * q_stride + head_idx * HEAD_SIZE;
const int last_block_token_num = const int last_block_token_num = seq_len - (block_num - 1) * BLOCK_SIZE;
context_len - (block_num - 1) * BLOCK_SIZE; float* __restrict__ thread_block_logits =
float *__restrict__ thread_block_logits = logits + omp_get_thread_num() * max_seq_len_padded;
logits + omp_get_thread_num() * max_context_len_padded;
// Compute logits // Compute logits
for (int block_idx = 0; block_idx < block_num; ++block_idx) { for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx]; const int64_t physical_block_idx = seq_block_table[block_idx];
const scalar_t *__restrict__ k_block_cache_ptr = const scalar_t* __restrict__ k_block_cache_ptr =
k_cache + physical_block_idx * kv_block_stride + k_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride; kv_head_idx * kv_head_stride;
float *__restrict__ head_block_logits = float* __restrict__ head_block_logits =
thread_block_logits + block_idx * BLOCK_SIZE; thread_block_logits + block_idx * BLOCK_SIZE;
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call( reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
@ -278,12 +282,11 @@ struct paged_attention_v1_impl {
// Compute softmax // Compute softmax
if (alibi_slopes) { if (alibi_slopes) {
reduceSoftmaxAlibi(thread_block_logits, context_len, reduceSoftmaxAlibi(thread_block_logits, seq_len,
block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0,
context_len); seq_len);
} else { } else {
reduceSoftmax(thread_block_logits, context_len, reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE);
block_num * BLOCK_SIZE);
} }
// Compute value // Compute value
@ -293,14 +296,14 @@ struct paged_attention_v1_impl {
for (int head_part_idx = 0; head_part_idx < head_partition_num; for (int head_part_idx = 0; head_part_idx < head_partition_num;
++head_part_idx) { ++head_part_idx) {
vec_op::FP32Vec16 accums[head_elem_num_per_partition]; vec_op::FP32Vec16 accums[head_elem_num_per_partition];
scalar_t *__restrict__ out_ptr = scalar_t* __restrict__ out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
head_part_idx * head_elem_num_per_partition; head_part_idx * head_elem_num_per_partition;
for (int block_idx = 0; block_idx < block_num; ++block_idx) { for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx]; const int64_t physical_block_idx = seq_block_table[block_idx];
const float *__restrict__ prob_vec_ptr = const float* __restrict__ prob_vec_ptr =
thread_block_logits + block_idx * BLOCK_SIZE; thread_block_logits + block_idx * BLOCK_SIZE;
const scalar_t *__restrict__ v_block_cache_ptr = const scalar_t* __restrict__ v_block_cache_ptr =
v_cache + physical_block_idx * kv_block_stride + v_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride + kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
@ -311,7 +314,7 @@ struct paged_attention_v1_impl {
if (block_idx != block_num - 1) { if (block_idx != block_num - 1) {
const int64_t next_physical_block_idx = const int64_t next_physical_block_idx =
seq_block_table[block_idx + 1]; seq_block_table[block_idx + 1];
const scalar_t *__restrict__ next_v_block_cache_ptr = const scalar_t* __restrict__ next_v_block_cache_ptr =
v_cache + next_physical_block_idx * kv_block_stride + v_cache + next_physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride + kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
@ -340,16 +343,16 @@ struct paged_attention_v1_impl {
#define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ #define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \ paged_attention_v1_impl<T, HEAD_SIZE, BLOCK_SIZE>::call( \
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \
num_heads); num_heads);
template <typename T, int BLOCK_SIZE> template <typename T, int BLOCK_SIZE>
void paged_attention_v1_impl_launcher( void paged_attention_v1_impl_launcher(
torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &context_lens, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes) { const c10::optional<torch::Tensor>& alibi_slopes) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
@ -359,67 +362,73 @@ void paged_attention_v1_impl_launcher(
int kv_head_stride = key_cache.stride(1); int kv_head_stride = key_cache.stride(1);
// NOTE: alibi_slopes is optional. // NOTE: alibi_slopes is optional.
const float *alibi_slopes_ptr = const float* alibi_slopes_ptr =
alibi_slopes alibi_slopes
? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
T *out_ptr = reinterpret_cast<T *>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T *query_ptr = reinterpret_cast<T *>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr()); T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr()); T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int *context_lens_ptr = context_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) { switch (head_size) {
case 64: case 64:
LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break; break;
case 80: case 80:
LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
break; break;
case 96: case 96:
LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
break; break;
case 112: case 112:
LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
break; break;
case 128: case 128:
LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break; break;
case 256: case 192:
LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); LAUNCH_V1_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
break; break;
default: case 256:
TORCH_CHECK(false, "Unsupported head size: ", head_size); LAUNCH_V1_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break; break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
} }
} }
#define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ #define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \ paged_attention_v1_impl_launcher<T, BLOCK_SIZE>( \
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
context_lens, max_context_len, alibi_slopes); seq_lens, max_seq_len, alibi_slopes);
#define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ #define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \ switch (block_size) { \
case 16: \ case 16: \
CALL_V1_KERNEL_LAUNCHER(T, 16); \ CALL_V1_KERNEL_LAUNCHER(T, 16); \
break; \ break; \
default: \ default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \ break; \
} }
} // namespace } // namespace
void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, void paged_attention_v1(
torch::Tensor &key_cache, torch::Tensor &value_cache, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
torch::Tensor &context_lens, int block_size, int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
int max_context_len, const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const c10::optional<torch::Tensor> &alibi_slopes, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const std::string &kv_cache_dtype) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
TORCH_CHECK(kv_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
[&] { [&] {
CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
@ -433,23 +442,24 @@ namespace {
template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE> template <typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, int PARTITION_SIZE>
struct paged_attention_v2_impl { struct paged_attention_v2_impl {
static void call( static void call(
scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
float *__restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads,
float // max_num_partitions]
*__restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads,
scalar_t *__restrict__ tmp_out, // [num_seqs, num_heads, // max_num_partitions]
// max_num_partitions, head_size] scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] // max_num_partitions, head_size]
const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
// head_size/x, block_size, x] const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x]
// head_size, block_size] const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale, const int num_kv_heads, const float scale,
const int const int* __restrict__ block_tables, // [num_seqs,
*__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] // max_num_blocks_per_seq]
const int *__restrict__ context_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float *__restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const int num_seqs, const int num_heads, const int max_num_partitions) { const int num_seqs, const int num_heads, const int max_num_partitions) {
constexpr int x = 16 / sizeof(scalar_t); constexpr int x = 16 / sizeof(scalar_t);
@ -464,27 +474,25 @@ struct paged_attention_v2_impl {
for (int partition_idx = 0; partition_idx < max_num_partitions; for (int partition_idx = 0; partition_idx < max_num_partitions;
++partition_idx) { ++partition_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int context_len = context_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int start_token_idx = partition_idx * PARTITION_SIZE; const int start_token_idx = partition_idx * PARTITION_SIZE;
if (start_token_idx >= context_len) if (start_token_idx >= seq_len) continue;
continue;
const int partition_num = const int partition_num =
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
const bool no_reduce = (partition_num == 1); const bool no_reduce = (partition_num == 1);
const int context_token_num = const int token_num =
(std::min(context_len, start_token_idx + PARTITION_SIZE) - (std::min(seq_len, start_token_idx + PARTITION_SIZE) -
start_token_idx); start_token_idx);
const int block_num = const int block_num = (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
(context_token_num + BLOCK_SIZE - 1) / BLOCK_SIZE;
const int last_block_token_num = const int last_block_token_num =
context_token_num - (block_num - 1) * BLOCK_SIZE; token_num - (block_num - 1) * BLOCK_SIZE;
const int *seq_block_table = block_tables + const int* seq_block_table = block_tables +
max_num_blocks_per_seq * seq_idx + max_num_blocks_per_seq * seq_idx +
start_token_idx / BLOCK_SIZE; start_token_idx / BLOCK_SIZE;
const int64_t kv_head_idx = head_idx / num_queries_per_kv; const int64_t kv_head_idx = head_idx / num_queries_per_kv;
const scalar_t *__restrict__ q_vec_ptr = const scalar_t* __restrict__ q_vec_ptr =
q + seq_idx * q_stride + head_idx * HEAD_SIZE; q + seq_idx * q_stride + head_idx * HEAD_SIZE;
float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0}; float logits[PARTITION_SIZE] __attribute__((aligned(64))) = {0};
@ -492,10 +500,10 @@ struct paged_attention_v2_impl {
// Compute logits // Compute logits
for (int block_idx = 0; block_idx < block_num; ++block_idx) { for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx]; const int64_t physical_block_idx = seq_block_table[block_idx];
const scalar_t *__restrict__ k_block_cache_ptr = const scalar_t* __restrict__ k_block_cache_ptr =
k_cache + physical_block_idx * kv_block_stride + k_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride; kv_head_idx * kv_head_stride;
float *__restrict__ head_block_logits = float* __restrict__ head_block_logits =
logits + block_idx * BLOCK_SIZE; logits + block_idx * BLOCK_SIZE;
reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call( reduceQKBlockKernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, x>::call(
@ -506,16 +514,16 @@ struct paged_attention_v2_impl {
std::pair<float, float> max_and_sum; std::pair<float, float> max_and_sum;
if (alibi_slopes) { if (alibi_slopes) {
max_and_sum = reduceSoftmaxAlibi( max_and_sum = reduceSoftmaxAlibi(
logits, context_token_num, block_num * BLOCK_SIZE, logits, token_num, block_num * BLOCK_SIZE,
alibi_slopes[head_idx], start_token_idx, context_len); alibi_slopes[head_idx], start_token_idx, seq_len);
} else { } else {
max_and_sum = reduceSoftmax(logits, context_token_num, max_and_sum =
block_num * BLOCK_SIZE); reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE);
} }
auto &&[max_logit, exp_sum] = max_and_sum; auto&& [max_logit, exp_sum] = max_and_sum;
scalar_t *__restrict__ output_buffer = nullptr; scalar_t* __restrict__ output_buffer = nullptr;
if (!no_reduce) { if (!no_reduce) {
auto idx = seq_idx * num_heads * max_num_partitions + auto idx = seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx; head_idx * max_num_partitions + partition_idx;
@ -537,13 +545,13 @@ struct paged_attention_v2_impl {
for (int head_part_idx = 0; head_part_idx < head_partition_num; for (int head_part_idx = 0; head_part_idx < head_partition_num;
++head_part_idx) { ++head_part_idx) {
vec_op::FP32Vec16 accums[head_elem_num_per_partition]; vec_op::FP32Vec16 accums[head_elem_num_per_partition];
scalar_t *__restrict__ out_ptr = scalar_t* __restrict__ out_ptr =
output_buffer + head_part_idx * head_elem_num_per_partition; output_buffer + head_part_idx * head_elem_num_per_partition;
for (int block_idx = 0; block_idx < block_num; ++block_idx) { for (int block_idx = 0; block_idx < block_num; ++block_idx) {
const int64_t physical_block_idx = seq_block_table[block_idx]; const int64_t physical_block_idx = seq_block_table[block_idx];
const float *__restrict__ prob_vec_ptr = const float* __restrict__ prob_vec_ptr =
logits + block_idx * BLOCK_SIZE; logits + block_idx * BLOCK_SIZE;
const scalar_t *__restrict__ v_block_cache_ptr = const scalar_t* __restrict__ v_block_cache_ptr =
v_cache + physical_block_idx * kv_block_stride + v_cache + physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride + kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
@ -554,7 +562,7 @@ struct paged_attention_v2_impl {
if (block_idx != block_num - 1) { if (block_idx != block_num - 1) {
const int64_t next_physical_block_idx = const int64_t next_physical_block_idx =
seq_block_table[block_idx + 1]; seq_block_table[block_idx + 1];
const scalar_t *__restrict__ next_v_block_cache_ptr = const scalar_t* __restrict__ next_v_block_cache_ptr =
v_cache + next_physical_block_idx * kv_block_stride + v_cache + next_physical_block_idx * kv_block_stride +
kv_head_idx * kv_head_stride + kv_head_idx * kv_head_stride +
BLOCK_SIZE * head_part_idx * head_elem_num_per_partition; BLOCK_SIZE * head_part_idx * head_elem_num_per_partition;
@ -582,12 +590,11 @@ struct paged_attention_v2_impl {
#pragma omp parallel for collapse(2) schedule(static, 1) #pragma omp parallel for collapse(2) schedule(static, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
const int context_len = context_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int partition_num = const int partition_num =
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1) if (partition_num == 1) continue;
continue;
reducePartitonSoftmax( reducePartitonSoftmax(
max_logits + seq_idx * num_heads * max_num_partitions + max_logits + seq_idx * num_heads * max_num_partitions +
@ -602,30 +609,29 @@ struct paged_attention_v2_impl {
using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type; using v_load_vec_type = typename KernelVecType<scalar_t>::v_load_vec_type;
static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE); static_assert(v_load_vec_type::get_elem_num() == BLOCK_SIZE);
constexpr int head_elem_num_per_group = constexpr int head_elem_num_per_group =
16; // Note: didn't align with the cacheline size, due to some HEAD_SIZE 16; // Note: didn't align with the cacheline size, due to some
// didn't align with 64 bytes // HEAD_SIZE didn't align with 64 bytes
static_assert(HEAD_SIZE % head_elem_num_per_group == 0); static_assert(HEAD_SIZE % head_elem_num_per_group == 0);
constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group; constexpr int head_group_num = HEAD_SIZE / head_elem_num_per_group;
const float *__restrict__ rescale_factors = exp_sums; const float* __restrict__ rescale_factors = exp_sums;
#pragma omp parallel for collapse(3) schedule(static, 1) #pragma omp parallel for collapse(3) schedule(static, 1)
for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) {
for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) {
for (int group_idx = 0; group_idx < head_group_num; ++group_idx) { for (int group_idx = 0; group_idx < head_group_num; ++group_idx) {
const int context_len = context_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
const int partition_num = const int partition_num =
(context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE;
if (partition_num == 1) if (partition_num == 1) continue;
continue;
const float *__restrict__ seq_head_rescale_factors = const float* __restrict__ seq_head_rescale_factors =
rescale_factors + seq_idx * num_heads * max_num_partitions + rescale_factors + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions; head_idx * max_num_partitions;
const scalar_t *__restrict__ seq_head_tmp_out = const scalar_t* __restrict__ seq_head_tmp_out =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + head_idx * max_num_partitions * HEAD_SIZE +
group_idx * head_elem_num_per_group; group_idx * head_elem_num_per_group;
scalar_t *__restrict__ seq_head_output = scalar_t* __restrict__ seq_head_output =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE +
group_idx * head_elem_num_per_group; group_idx * head_elem_num_per_group;
@ -644,21 +650,21 @@ struct paged_attention_v2_impl {
} }
}; };
#define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ #define LAUNCH_V2_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \
paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \ paged_attention_v2_impl<T, HEAD_SIZE, BLOCK_SIZE, PARTITION_SIZE>::call( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \
key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
kv_block_stride, kv_head_stride, num_seqs, num_heads, \ kv_block_stride, kv_head_stride, num_seqs, num_heads, \
max_num_partitions); max_num_partitions);
template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512> template <typename T, int BLOCK_SIZE, int PARTITION_SIZE = 512>
void paged_attention_v2_impl_launcher( void paged_attention_v2_impl_launcher(
torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor &value_cache, int num_kv_heads, float scale, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size, torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
int max_context_len, const c10::optional<torch::Tensor> &alibi_slopes) { int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes) {
int num_seqs = query.size(0); int num_seqs = query.size(0);
int num_heads = query.size(1); int num_heads = query.size(1);
int head_size = query.size(2); int head_size = query.size(2);
@ -669,72 +675,78 @@ void paged_attention_v2_impl_launcher(
int max_num_partitions = exp_sums.size(-1); int max_num_partitions = exp_sums.size(-1);
// NOTE: alibi_slopes is optional. // NOTE: alibi_slopes is optional.
const float *alibi_slopes_ptr = const float* alibi_slopes_ptr =
alibi_slopes alibi_slopes
? reinterpret_cast<const float *>(alibi_slopes.value().data_ptr()) ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr; : nullptr;
T *out_ptr = reinterpret_cast<T *>(out.data_ptr()); T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float *exp_sums_ptr = reinterpret_cast<float *>(exp_sums.data_ptr()); float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float *max_logits_ptr = reinterpret_cast<float *>(max_logits.data_ptr()); float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T *tmp_out_ptr = reinterpret_cast<T *>(tmp_out.data_ptr()); T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T *query_ptr = reinterpret_cast<T *>(query.data_ptr()); T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T *key_cache_ptr = reinterpret_cast<T *>(key_cache.data_ptr()); T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T *value_cache_ptr = reinterpret_cast<T *>(value_cache.data_ptr()); T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int *block_tables_ptr = block_tables.data_ptr<int>(); int* block_tables_ptr = block_tables.data_ptr<int>();
int *context_lens_ptr = context_lens.data_ptr<int>(); int* seq_lens_ptr = seq_lens.data_ptr<int>();
switch (head_size) { switch (head_size) {
case 64: case 64:
LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 64, BLOCK_SIZE);
break; break;
case 80: case 80:
LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 80, BLOCK_SIZE);
break; break;
case 96: case 96:
LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 96, BLOCK_SIZE);
break; break;
case 112: case 112:
LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 112, BLOCK_SIZE);
break; break;
case 128: case 128:
LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 128, BLOCK_SIZE);
break; break;
case 256: case 192:
LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE); LAUNCH_V2_ATTENTION_KERNEL(T, 192, BLOCK_SIZE);
break; break;
default: case 256:
TORCH_CHECK(false, "Unsupported head size: ", head_size); LAUNCH_V2_ATTENTION_KERNEL(T, 256, BLOCK_SIZE);
break; break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
} }
} }
#define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ #define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \ paged_attention_v2_impl_launcher<T, BLOCK_SIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, block_size, \ num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, \
max_context_len, alibi_slopes); alibi_slopes);
#define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ #define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \ switch (block_size) { \
case 16: \ case 16: \
CALL_V2_KERNEL_LAUNCHER(T, 16); \ CALL_V2_KERNEL_LAUNCHER(T, 16); \
break; \ break; \
default: \ default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \ break; \
} }
} // namespace } // namespace
void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, void paged_attention_v2(
torch::Tensor &max_logits, torch::Tensor &tmp_out, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor &value_cache, int num_kv_heads, torch::Tensor& value_cache, int num_kv_heads, float scale,
float scale, torch::Tensor &block_tables, torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
torch::Tensor &context_lens, int block_size, int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
int max_context_len, const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
const c10::optional<torch::Tensor> &alibi_slopes, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const std::string &kv_cache_dtype) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
TORCH_CHECK(kv_scale == 1.0f);
TORCH_CHECK(blocksparse_vert_stride <= 1,
"CPU backend does not support blocksparse attention yet.");
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
[&] { [&] {
CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)

View File

@ -5,25 +5,26 @@
namespace { namespace {
template <typename scalar_t> template <typename scalar_t>
void copy_blocks_cpu_impl( void copy_blocks_cpu_impl(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor> &key_caches, std::vector<torch::Tensor>& value_caches,
std::vector<torch::Tensor> &value_caches, const torch::Tensor& mapping_pairs,
const std::vector<std::pair<int64_t, int64_t>> mapping_pairs, const int element_num_per_block,
const int element_num_per_block, const int layer_num) { const int layer_num) {
const size_t pair_num = mapping_pairs.size(); const size_t pair_num = mapping_pairs.size(0);
const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; const size_t block_bytes = sizeof(scalar_t) * element_num_per_block;
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int layer = 0; layer < layer_num; ++layer) { for (int layer = 0; layer < layer_num; ++layer) {
for (size_t pair = 0; pair < pair_num; ++pair) { for (size_t pair = 0; pair < pair_num; ++pair) {
int64_t source_offset = element_num_per_block * mapping_pairs[pair].first; int64_t source_offset =
element_num_per_block * mapping_pairs[pair][0].item<int64_t>();
int64_t target_offset = int64_t target_offset =
element_num_per_block * mapping_pairs[pair].second; element_num_per_block * mapping_pairs[pair][1].item<int64_t>();
scalar_t *key_cache_ptr = key_caches[layer].data_ptr<scalar_t>(); scalar_t* key_cache_ptr = key_caches[layer].data_ptr<scalar_t>();
scalar_t *source_ptr = key_cache_ptr + source_offset; scalar_t* source_ptr = key_cache_ptr + source_offset;
scalar_t *target_ptr = key_cache_ptr + target_offset; scalar_t* target_ptr = key_cache_ptr + target_offset;
std::memcpy(target_ptr, source_ptr, block_bytes); std::memcpy(target_ptr, source_ptr, block_bytes);
scalar_t *value_cache_ptr = value_caches[layer].data_ptr<scalar_t>(); scalar_t* value_cache_ptr = value_caches[layer].data_ptr<scalar_t>();
source_ptr = value_cache_ptr + source_offset; source_ptr = value_cache_ptr + source_offset;
target_ptr = value_cache_ptr + target_offset; target_ptr = value_cache_ptr + target_offset;
std::memcpy(target_ptr, source_ptr, block_bytes); std::memcpy(target_ptr, source_ptr, block_bytes);
@ -33,9 +34,9 @@ void copy_blocks_cpu_impl(
template <typename scalar_t> template <typename scalar_t>
void reshape_and_cache_cpu_impl( void reshape_and_cache_cpu_impl(
const scalar_t *__restrict__ key, const scalar_t *__restrict__ value, const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
scalar_t *__restrict__ key_cache, scalar_t *__restrict__ value_cache, scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const int64_t *__restrict__ slot_mapping, const int num_tokens, const int64_t* __restrict__ slot_mapping, const int num_tokens,
const int key_stride, const int value_stride, const int num_heads, const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int block_size, const int x) { const int head_size, const int block_size, const int x) {
const int block_elem_num = num_heads * head_size * block_size; const int block_elem_num = num_heads * head_size * block_size;
@ -48,14 +49,14 @@ void reshape_and_cache_cpu_impl(
int src_key_head_idx = token_idx * key_stride + head_idx * head_size; int src_key_head_idx = token_idx * key_stride + head_idx * head_size;
int src_value_head_idx = int src_value_head_idx =
token_idx * value_stride + head_idx * head_size; token_idx * value_stride + head_idx * head_size;
const scalar_t *src_key_head_ptr = key + src_key_head_idx; const scalar_t* src_key_head_ptr = key + src_key_head_idx;
const scalar_t *src_value_head_ptr = value + src_value_head_idx; const scalar_t* src_value_head_ptr = value + src_value_head_idx;
const int64_t block_index = slot_idx / block_size; const int64_t block_index = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size; const int64_t block_offset = slot_idx % block_size;
scalar_t *target_key_head_ptr = key_cache + scalar_t* target_key_head_ptr = key_cache +
block_elem_num * block_index + block_elem_num * block_index +
head_idx * block_size * head_size; head_idx * block_size * head_size;
scalar_t *target_value_head_ptr = value_cache + scalar_t* target_value_head_ptr = value_cache +
block_elem_num * block_index + block_elem_num * block_index +
head_idx * block_size * head_size; head_idx * block_size * head_size;
@ -79,39 +80,33 @@ void reshape_and_cache_cpu_impl(
} }
} }
} }
}; // namespace }; // namespace
void copy_blocks(std::vector<torch::Tensor> &key_caches, void copy_blocks(std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor> &value_caches, std::vector<torch::Tensor>& value_caches,
const std::map<int64_t, std::vector<int64_t>> &block_mapping) { const torch::Tensor& block_mapping) {
int num_layers = key_caches.size(); unsigned num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size()); TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) { if (num_layers == 0) {
return; return;
} }
std::vector<std::pair<int64_t, int64_t>> mapping_pairs;
mapping_pairs.reserve(block_mapping.size());
for (const auto &pair : block_mapping) {
for (const auto &dst : pair.second) {
mapping_pairs.emplace_back(pair.first, dst);
}
}
const int element_num_per_block = key_caches[0][0].numel(); const int element_num_per_block = key_caches[0][0].numel();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] { key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] {
CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl) CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl)
copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, mapping_pairs, copy_blocks_cpu_impl<scalar_t>(key_caches, value_caches, block_mapping,
element_num_per_block, num_layers); element_num_per_block, num_layers);
CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl) CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl)
}); });
} }
void reshape_and_cache(torch::Tensor &key, torch::Tensor &value, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor &key_cache, torch::Tensor &value_cache, torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor &slot_mapping, torch::Tensor& slot_mapping,
const std::string &kv_cache_dtype) { const std::string& kv_cache_dtype, float kv_scale) {
TORCH_CHECK(kv_scale == 1.0f);
int num_tokens = key.size(0); int num_tokens = key.size(0);
int num_heads = key.size(1); int num_heads = key.size(1);
int head_size = key.size(2); int head_size = key.size(2);
@ -133,7 +128,7 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
}); });
} }
void swap_blocks(torch::Tensor &src, torch::Tensor &dst, void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
const std::map<int64_t, int64_t> &block_mapping) { const torch::Tensor& block_mapping) {
TORCH_CHECK(false, "swap_blocks is unsupported on CPU.") TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
} }

View File

@ -2,10 +2,10 @@
namespace { namespace {
template <typename scalar_t> template <typename scalar_t>
void rms_norm_impl(scalar_t *__restrict__ out, void rms_norm_impl(scalar_t* __restrict__ out,
const scalar_t *__restrict__ input, const scalar_t* __restrict__ input,
const scalar_t *__restrict__ weight, const float epsilon, const scalar_t* __restrict__ weight, const float epsilon,
const int num_tokens, const int hidden_size) { const int num_tokens, const int hidden_size) {
using scalar_vec_t = vec_op::vec_t<scalar_t>; using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
@ -41,11 +41,11 @@ void rms_norm_impl(scalar_t *__restrict__ out,
} }
template <typename scalar_t> template <typename scalar_t>
void fused_add_rms_norm_impl(scalar_t *__restrict__ input, void fused_add_rms_norm_impl(scalar_t* __restrict__ input,
scalar_t *__restrict__ residual, scalar_t* __restrict__ residual,
const scalar_t *__restrict__ weight, const scalar_t* __restrict__ weight,
const float epsilon, const int num_tokens, const float epsilon, const int num_tokens,
const int hidden_size) { const int hidden_size) {
using scalar_vec_t = vec_op::vec_t<scalar_t>; using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0); TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);
@ -85,24 +85,24 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
} }
} }
} }
} // namespace } // namespace
void rms_norm(torch::Tensor &out, torch::Tensor &input, void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
torch::Tensor &weight, float epsilon) { float epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] { VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] {
CPU_KERNEL_GUARD_IN(rms_norm_impl) CPU_KERNEL_GUARD_IN(rms_norm_impl)
rms_norm_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), rms_norm_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), epsilon, num_tokens, weight.data_ptr<scalar_t>(), epsilon, num_tokens,
hidden_size); hidden_size);
CPU_KERNEL_GUARD_OUT(rms_norm_impl) CPU_KERNEL_GUARD_OUT(rms_norm_impl)
}); });
} }
void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual, void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor &weight, float epsilon) { torch::Tensor& weight, float epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;

View File

@ -4,107 +4,107 @@
namespace { namespace {
template <typename scalar_t> template <typename scalar_t>
void rotary_embedding_impl( void rotary_embedding_impl(
const int64_t const int64_t* __restrict__ positions, // [batch_size, seq_len] or
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens] // [num_tokens]
scalar_t scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or /// head_size] or [num_tokens, num_heads,
/// [num_tokens, num_heads, head_size] /// head_size]
scalar_t scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or // head_size] or [num_tokens, num_kv_heads,
// [num_tokens, num_kv_heads, head_size] // head_size]
const scalar_t const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size, const int num_heads, const int num_kv_heads, const int head_size,
const int num_tokens) { const int num_tokens) {
using scalar_vec_t = vec_op::vec_t<scalar_t>; using scalar_vec_t = vec_op::vec_t<scalar_t>;
constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num(); constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
constexpr int ELEM_SIZE = sizeof(scalar_t);
const int embed_dim = rot_dim / 2; const int embed_dim = rot_dim / 2;
TORCH_CHECK(embed_dim % VEC_ELEM_NUM == 0); bool flag = (embed_dim % VEC_ELEM_NUM == 0);
const int loop_upper = flag ? embed_dim : embed_dim - VEC_ELEM_NUM;
auto compute_loop = [&](const int64_t token_head, const scalar_t* cache_ptr,
scalar_t* qk) {
int j = 0;
for (; j < loop_upper; j += VEC_ELEM_NUM) {
const int rot_offset = j;
const int x_index = rot_offset;
const int y_index = embed_dim + rot_offset;
const int64_t out_x = token_head + x_index;
const int64_t out_y = token_head + y_index;
const scalar_vec_t cos(cache_ptr + x_index);
const scalar_vec_t sin(cache_ptr + y_index);
const scalar_vec_t q_x(qk + out_x);
const scalar_vec_t q_y(qk + out_y);
vec_op::FP32Vec8 fp32_cos(cos);
vec_op::FP32Vec8 fp32_sin(sin);
vec_op::FP32Vec8 fp32_q_x(q_x);
vec_op::FP32Vec8 fp32_q_y(q_y);
auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
scalar_vec_t(out1).save(qk + out_x);
auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
scalar_vec_t(out2).save(qk + out_y);
}
if (!flag) {
for (; j < embed_dim; ++j) {
const int x_index = j;
const int y_index = embed_dim + j;
const int64_t out_x = token_head + x_index;
const int64_t out_y = token_head + y_index;
const float fp32_cos = cache_ptr[x_index];
const float fp32_sin = cache_ptr[y_index];
const float fp32_q_x = qk[out_x];
const float fp32_q_y = qk[out_y];
qk[out_x] = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
qk[out_y] = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
}
}
};
#pragma omp parallel for #pragma omp parallel for
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
for (int i = 0; i < num_heads; ++i) { for (int i = 0; i < num_heads; ++i) {
const int head_idx = i; const int head_idx = i;
const int64_t token_head = const int64_t token_head =
token_idx * query_stride + head_idx * head_size; token_idx * query_stride + head_idx * head_size;
for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { compute_loop(token_head, cache_ptr, query);
const int rot_offset = j;
const int x_index = rot_offset;
const int y_index = embed_dim + rot_offset;
const int64_t out_x = token_head + x_index;
const int64_t out_y = token_head + y_index;
const scalar_vec_t cos(cache_ptr + x_index);
const scalar_vec_t sin(cache_ptr + y_index);
const scalar_vec_t q_x(query + out_x);
const scalar_vec_t q_y(query + out_y);
vec_op::FP32Vec8 fp32_cos(cos);
vec_op::FP32Vec8 fp32_sin(sin);
vec_op::FP32Vec8 fp32_q_x(q_x);
vec_op::FP32Vec8 fp32_q_y(q_y);
auto out1 = fp32_q_x * fp32_cos - fp32_q_y * fp32_sin;
scalar_vec_t(out1).save(query + out_x);
auto out2 = fp32_q_y * fp32_cos + fp32_q_x * fp32_sin;
scalar_vec_t(out2).save(query + out_y);
}
} }
for (int i = 0; i < num_kv_heads; ++i) { for (int i = 0; i < num_kv_heads; ++i) {
const int head_idx = i; const int head_idx = i;
const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int64_t token_head = token_idx * key_stride + head_idx * head_size;
for (int j = 0; j < embed_dim; j += VEC_ELEM_NUM) { compute_loop(token_head, cache_ptr, key);
const int rot_offset = j;
const int x_index = rot_offset;
const int y_index = embed_dim + rot_offset;
const int64_t out_x = token_head + x_index;
const int64_t out_y = token_head + y_index;
const scalar_vec_t cos(cache_ptr + x_index);
const scalar_vec_t sin(cache_ptr + y_index);
const scalar_vec_t k_x(key + out_x);
const scalar_vec_t k_y(key + out_y);
vec_op::FP32Vec8 fp32_cos(cos);
vec_op::FP32Vec8 fp32_sin(sin);
vec_op::FP32Vec8 fp32_k_x(k_x);
vec_op::FP32Vec8 fp32_k_y(k_y);
auto out1 = fp32_k_x * fp32_cos - fp32_k_y * fp32_sin;
scalar_vec_t(out1).save(key + out_x);
auto out2 = fp32_k_y * fp32_cos + fp32_k_x * fp32_sin;
scalar_vec_t(out2).save(key + out_y);
}
} }
} }
} }
template <typename scalar_t> template <typename scalar_t>
void rotary_embedding_gptj_impl( void rotary_embedding_gptj_impl(
const int64_t const int64_t* __restrict__ positions, // [batch_size, seq_len] or
*__restrict__ positions, // [batch_size, seq_len] or [num_tokens] // [num_tokens]
scalar_t scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
*__restrict__ query, /// [batch_size, seq_len, num_heads, head_size] or /// head_size] or [num_tokens, num_heads,
/// [num_tokens, num_heads, head_size] /// head_size]
scalar_t scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
*__restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or // head_size] or [num_tokens, num_kv_heads,
// [num_tokens, num_kv_heads, head_size] // head_size]
const scalar_t const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
*__restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride, const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size, const int num_heads, const int num_kv_heads, const int head_size,
const int num_tokens) { const int num_tokens) {
@ -114,13 +114,13 @@ void rotary_embedding_gptj_impl(
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int i = 0; i < num_heads; ++i) { for (int i = 0; i < num_heads; ++i) {
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const scalar_t *cos_cache_ptr = cache_ptr; const scalar_t* cos_cache_ptr = cache_ptr;
const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; const scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
const int head_idx = i; const int head_idx = i;
const int64_t token_head = const int64_t token_head =
token_idx * query_stride + head_idx * head_size; token_idx * query_stride + head_idx * head_size;
scalar_t *head_query = token_head + query; scalar_t* head_query = token_head + query;
for (int j = 0; j < embed_dim; j += 1) { for (int j = 0; j < embed_dim; j += 1) {
const int rot_offset = j; const int rot_offset = j;
const int x_index = 2 * rot_offset; const int x_index = 2 * rot_offset;
@ -142,12 +142,12 @@ void rotary_embedding_gptj_impl(
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int i = 0; i < num_kv_heads; ++i) { for (int i = 0; i < num_kv_heads; ++i) {
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
const scalar_t *cache_ptr = cos_sin_cache + pos * rot_dim; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const scalar_t *cos_cache_ptr = cache_ptr; const scalar_t* cos_cache_ptr = cache_ptr;
const scalar_t *sin_cache_ptr = cache_ptr + embed_dim; const scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
const int head_idx = i; const int head_idx = i;
const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int64_t token_head = token_idx * key_stride + head_idx * head_size;
scalar_t *head_key = key + token_head; scalar_t* head_key = key + token_head;
for (int j = 0; j < embed_dim; j += 1) { for (int j = 0; j < embed_dim; j += 1) {
const int rot_offset = j; const int rot_offset = j;
const int x_index = 2 * rot_offset; const int x_index = 2 * rot_offset;
@ -165,11 +165,11 @@ void rotary_embedding_gptj_impl(
} }
} }
} }
}; // namespace }; // namespace
void rotary_embedding(torch::Tensor &positions, torch::Tensor &query, void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor &key, int head_size, torch::Tensor& key, int head_size,
torch::Tensor &cos_sin_cache, bool is_neox) { torch::Tensor& cos_sin_cache, bool is_neox) {
int num_tokens = query.numel() / query.size(-1); int num_tokens = query.numel() / query.size(-1);
int rot_dim = cos_sin_cache.size(1); int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size; int num_heads = query.size(-1) / head_size;

View File

@ -8,66 +8,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); pybind11::module ops = m.def_submodule("ops", "vLLM custom operators");
// Attention ops // Attention ops
ops.def( ops.def("paged_attention_v1", &paged_attention_v1,
"paged_attention_v1", "Compute the attention between an input query and the cached "
&paged_attention_v1, "keys/values using PagedAttention.");
"Compute the attention between an input query and the cached keys/values using PagedAttention."); ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");
ops.def(
"paged_attention_v2",
&paged_attention_v2,
"PagedAttention V2.");
// Activation ops // Activation ops
ops.def( ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
"silu_and_mul", ops.def("gelu_and_mul", &gelu_and_mul,
&silu_and_mul, "Activation function used in GeGLU with `none` approximation.");
"Activation function used in SwiGLU."); ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
ops.def( "Activation function used in GeGLU with `tanh` approximation.");
"gelu_and_mul", ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
&gelu_and_mul, ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");
"Activation function used in GeGLU with `none` approximation.");
ops.def(
"gelu_tanh_and_mul",
&gelu_tanh_and_mul,
"Activation function used in GeGLU with `tanh` approximation.");
ops.def(
"gelu_new",
&gelu_new,
"GELU implementation used in GPT-2.");
ops.def(
"gelu_fast",
&gelu_fast,
"Approximate GELU implementation.");
// Layernorm // Layernorm
ops.def( ops.def("rms_norm", &rms_norm,
"rms_norm", "Apply Root Mean Square (RMS) Normalization to the input tensor.");
&rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
ops.def( ops.def("fused_add_rms_norm", &fused_add_rms_norm,
"fused_add_rms_norm", "In-place fused Add and RMS Normalization");
&fused_add_rms_norm,
"In-place fused Add and RMS Normalization");
// Rotary embedding // Rotary embedding
ops.def( ops.def("rotary_embedding", &rotary_embedding,
"rotary_embedding", "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
// Cache ops // Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def( cache_ops.def("swap_blocks", &swap_blocks,
"swap_blocks", "Swap in (out) the cache blocks from src to dst");
&swap_blocks, cache_ops.def("copy_blocks", &copy_blocks,
"Swap in (out) the cache blocks from src to dst"); "Copy the cache blocks from src to dst");
cache_ops.def( cache_ops.def("reshape_and_cache", &reshape_and_cache,
"copy_blocks", "Reshape the key and value tensors and cache them");
&copy_blocks,
"Copy the cache blocks from src to dst");
cache_ops.def(
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them");
} }

View File

@ -1,7 +1,7 @@
#pragma once #pragma once
#ifdef USE_ROCM #ifdef USE_ROCM
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
@ -17,9 +17,14 @@
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
__shfl_xor_sync(uint32_t(-1), var, lane_mask)
#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
__shfl_xor_sync(uint32_t(-1), var, lane_mask, width)
#else #else
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
__shfl_xor(var, lane_mask, width)
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
@ -28,6 +33,13 @@
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif #endif
#ifndef USE_ROCM
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
__shfl_down_sync(uint32_t(-1), var, lane_delta)
#else
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
#endif
#ifndef USE_ROCM #ifndef USE_ROCM
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
@ -35,4 +47,3 @@
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif #endif

View File

@ -2,9 +2,6 @@
#include <torch/extension.h> #include <torch/extension.h>
int get_device_attribute( int get_device_attribute(int attribute, int device_id);
int attribute,
int device_id);
int get_max_shared_memory_per_block_device_attribute( int get_max_shared_memory_per_block_device_attribute(int device_id);
int device_id);

View File

@ -2,34 +2,28 @@
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h> #include <hip/hip_runtime_api.h>
#endif #endif
int get_device_attribute( int get_device_attribute(int attribute, int device_id) {
int attribute, int device, value;
int device_id) if (device_id < 0) {
{ cudaGetDevice(&device);
int device, value; } else {
if (device_id < 0) { device = device_id;
cudaGetDevice(&device); }
} cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute),
else { device);
device = device_id; return value;
}
cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
return value;
} }
int get_max_shared_memory_per_block_device_attribute(int device_id) {
int get_max_shared_memory_per_block_device_attribute( int attribute;
int device_id) // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
{ // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
int attribute;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
#ifdef USE_ROCM #ifdef USE_ROCM
attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
#else #else
attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin; attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
#endif #endif
return get_device_attribute(attribute, device_id); return get_device_attribute(attribute, device_id);
} }

View File

@ -7,11 +7,11 @@
// fake pointer type // fake pointer type
using fptr_t = uint64_t; using fptr_t = uint64_t;
static_assert(sizeof(void *) == sizeof(fptr_t)); static_assert(sizeof(void*) == sizeof(fptr_t));
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string> &handles, const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets, int rank, const std::vector<int64_t>& offsets, int rank,
bool full_nvlink) { bool full_nvlink) {
int world_size = offsets.size(); int world_size = offsets.size();
if (world_size > 8) if (world_size > 8)
@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
} }
return (fptr_t) new vllm::CustomAllreduce( return (fptr_t) new vllm::CustomAllreduce(
reinterpret_cast<vllm::Signal *>(meta.data_ptr()), rank_data.data_ptr(), reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
} }
@ -49,13 +49,13 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
* 5. A[None].expand(2, -1, -1, -1): Not OK * 5. A[None].expand(2, -1, -1, -1): Not OK
* 6. A[:, 1:, 1:]: Not OK * 6. A[:, 1:, 1:]: Not OK
*/ */
bool _is_weak_contiguous(torch::Tensor &t) { bool _is_weak_contiguous(torch::Tensor& t) {
return t.is_contiguous() || return t.is_contiguous() ||
(t.storage().nbytes() - t.storage_offset() * t.element_size() == (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
t.numel() * t.element_size()); t.numel() * t.element_size());
} }
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
bool full_nvlink) { bool full_nvlink) {
auto inp_size = inp.numel() * inp.element_size(); auto inp_size = inp.numel() * inp.element_size();
// custom allreduce requires input byte size to be multiples of 16 // custom allreduce requires input byte size to be multiples of 16
@ -67,28 +67,27 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
return false; return false;
} }
void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
cudaStream_t stream) { cudaStream_t stream) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
TORCH_CHECK(_is_weak_contiguous(out)); TORCH_CHECK(_is_weak_contiguous(out));
switch (out.scalar_type()) { switch (out.scalar_type()) {
case at::ScalarType::Float: { case at::ScalarType::Float: {
fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()), fa->allreduce<float>(stream, reinterpret_cast<float*>(inp.data_ptr()),
reinterpret_cast<float *>(out.data_ptr()), reinterpret_cast<float*>(out.data_ptr()),
out.numel()); out.numel());
break; break;
} }
case at::ScalarType::Half: { case at::ScalarType::Half: {
fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()), fa->allreduce<half>(stream, reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half *>(out.data_ptr()), reinterpret_cast<half*>(out.data_ptr()), out.numel());
out.numel());
break; break;
} }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: { case at::ScalarType::BFloat16: {
fa->allreduce<nv_bfloat16>( fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()), stream, reinterpret_cast<nv_bfloat16*>(inp.data_ptr()),
reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel()); reinterpret_cast<nv_bfloat16*>(out.data_ptr()), out.numel());
break; break;
} }
#endif #endif
@ -98,7 +97,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
} }
} }
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream(); auto stream = c10::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
@ -106,8 +105,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
_all_reduce(_fa, inp, out, stream); _all_reduce(_fa, inp, out, stream);
} }
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer, void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor &out) { torch::Tensor& out) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream(); auto stream = c10::cuda::getCurrentCUDAStream().stream();
@ -122,27 +121,27 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
} }
void dispose(fptr_t _fa) { void dispose(fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
delete fa; delete fa;
} }
int meta_size() { return sizeof(vllm::Signal); } int meta_size() { return sizeof(vllm::Signal); }
void register_buffer(fptr_t _fa, torch::Tensor &t, void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string> &handles, const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets) { const std::vector<int64_t>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
fa->register_buffer(handles, offsets, t.data_ptr()); fa->register_buffer(handles, offsets, t.data_ptr());
} }
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta( std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa) { fptr_t _fa) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
return fa->get_graph_buffer_ipc_meta(); return fa->get_graph_buffer_ipc_meta();
} }
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles, void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>> &offsets) { const std::vector<std::vector<int64_t>>& offsets) {
auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
fa->register_graph_buffers(handles, offsets); fa->register_graph_buffers(handles, offsets);
} }

View File

@ -31,9 +31,9 @@ struct Signal {
alignas(128) uint32_t end[kMaxBlocks][8]; alignas(128) uint32_t end[kMaxBlocks][8];
}; };
struct __align__(16) RankData { const void *__restrict__ ptrs[8]; }; struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
struct __align__(16) RankSignals { volatile Signal *signals[8]; }; struct __align__(16) RankSignals { volatile Signal* signals[8]; };
// like std::array, but aligned // like std::array, but aligned
template <typename T, int sz> template <typename T, int sz>
@ -68,11 +68,11 @@ DINLINE half downcast_s(float val) {
// scalar add functions // scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and // for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly // bfloat is disabled so we call the intrinsics directly
DINLINE half &assign_add(half &a, half b) { DINLINE half& assign_add(half& a, half b) {
a = __hadd(a, b); a = __hadd(a, b);
return a; return a;
} }
DINLINE float &assign_add(float &a, float b) { return a += b; } DINLINE float& assign_add(float& a, float b) { return a += b; }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
@ -80,14 +80,14 @@ template <>
DINLINE nv_bfloat16 downcast_s(float val) { DINLINE nv_bfloat16 downcast_s(float val) {
return __float2bfloat16(val); return __float2bfloat16(val);
} }
DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) { DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) {
a = __hadd(a, b); a = __hadd(a, b);
return a; return a;
} }
#endif #endif
template <typename T, int N> template <typename T, int N>
DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) { DINLINE array_t<T, N>& packed_assign_add(array_t<T, N>& a, array_t<T, N> b) {
#pragma unroll #pragma unroll
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
assign_add(a.data[i], b.data[i]); assign_add(a.data[i], b.data[i]);
@ -128,7 +128,7 @@ DINLINE O downcast(array_t<float, O::size> val) {
// prior memory accesses. Note: volatile writes will not be reordered against // prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes. // other volatile writes.
template <int ngpus> template <int ngpus>
DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg, DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
int rank) { int rank) {
if (threadIdx.x < ngpus) { if (threadIdx.x < ngpus) {
// reset flag for next time // reset flag for next time
@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
// Latency = 1 p2p write // Latency = 1 p2p write
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
// wait until we got true from all ranks // wait until we got true from all ranks
while (!self_sg->start[blockIdx.x][threadIdx.x]) while (!self_sg->start[blockIdx.x][threadIdx.x]);
;
} }
__syncthreads(); __syncthreads();
} }
@ -147,13 +146,13 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
// barrier in the all reduce kernel. If it's the final synchronization barrier, // barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses. // we don't need to make any visibility guarantees for prior memory accesses.
template <int ngpus, bool final_sync = false> template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg, DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
int rank) { int rank) {
__syncthreads(); __syncthreads();
// eliminate the case that prior writes are not visible after signals become // eliminate the case that prior writes are not visible after signals become
// visible. Note that I did not managed to make this happen through a lot of // visible. Note that I did not managed to make this happen through a lot of
// testing. Might be the case that hardware provides stronger guarantee than // testing. Might be the case that hardware provides stronger guarantee than
// the memory model. // the memory model.
if constexpr (!final_sync) __threadfence_system(); if constexpr (!final_sync) __threadfence_system();
if (threadIdx.x < ngpus) { if (threadIdx.x < ngpus) {
// reset flag for next time // reset flag for next time
@ -162,14 +161,13 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
// Latency = 1 p2p write // Latency = 1 p2p write
sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
// wait until we got true from all ranks // wait until we got true from all ranks
while (!self_sg->end[blockIdx.x][threadIdx.x]) while (!self_sg->end[blockIdx.x][threadIdx.x]);
;
} }
if constexpr (!final_sync) __syncthreads(); if constexpr (!final_sync) __syncthreads();
} }
template <typename P, int ngpus, typename A> template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P *ptrs[], int idx) { DINLINE P packed_reduce(const P* ptrs[], int idx) {
A tmp = upcast(ptrs[0][idx]); A tmp = upcast(ptrs[0][idx]);
#pragma unroll #pragma unroll
for (int i = 1; i < ngpus; i++) { for (int i = 1; i < ngpus; i++) {
@ -180,8 +178,8 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) {
template <typename T, int ngpus> template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) __global__ void __launch_bounds__(512, 1)
cross_device_reduce_1stage(RankData *_dp, RankSignals sg, cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
volatile Signal *self_sg, T *__restrict__ result, volatile Signal* self_sg, T* __restrict__ result,
int rank, int size) { int rank, int size) {
using P = typename packed_t<T>::P; using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A; using A = typename packed_t<T>::A;
@ -192,21 +190,20 @@ __global__ void __launch_bounds__(512, 1)
// do the actual reduction // do the actual reduction
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
((P *)result)[idx] = ((P*)result)[idx] = packed_reduce<P, ngpus, A>((const P**)&dp.ptrs[0], idx);
packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
} }
end_sync<ngpus, true>(sg, self_sg, rank); end_sync<ngpus, true>(sg, self_sg, rank);
} }
template <typename P> template <typename P>
DINLINE P *get_tmp_buf(volatile Signal *sg) { DINLINE P* get_tmp_buf(volatile Signal* sg) {
return (P *)(((Signal *)sg) + 1); return (P*)(((Signal*)sg) + 1);
} }
template <typename T, int ngpus> template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1) __global__ void __launch_bounds__(512, 1)
cross_device_reduce_2stage(RankData *_dp, RankSignals sg, cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
volatile Signal *self_sg, T *__restrict__ result, volatile Signal* self_sg, T* __restrict__ result,
int rank, int size) { int rank, int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x; int stride = gridDim.x * blockDim.x;
@ -216,12 +213,12 @@ __global__ void __launch_bounds__(512, 1)
int start = rank * part; int start = rank * part;
int end = rank == ngpus - 1 ? size : start + part; int end = rank == ngpus - 1 ? size : start + part;
int largest_part = part + size % ngpus; int largest_part = part + size % ngpus;
const P *ptrs[ngpus]; const P* ptrs[ngpus];
P *tmps[ngpus]; P* tmps[ngpus];
#pragma unroll #pragma unroll
for (int i = 0; i < ngpus; i++) { for (int i = 0; i < ngpus; i++) {
int target = (rank + i) % ngpus; int target = (rank + i) % ngpus;
ptrs[i] = (const P *)_dp->ptrs[target]; ptrs[i] = (const P*)_dp->ptrs[target];
tmps[i] = get_tmp_buf<P>(sg.signals[target]); tmps[i] = get_tmp_buf<P>(sg.signals[target]);
} }
auto tmp_out = tmps[0]; auto tmp_out = tmps[0];
@ -243,7 +240,7 @@ __global__ void __launch_bounds__(512, 1)
int gather_from_rank = ((rank + i) % ngpus); int gather_from_rank = ((rank + i) % ngpus);
if (gather_from_rank == ngpus - 1 || idx < part) { if (gather_from_rank == ngpus - 1 || idx < part) {
int dst_idx = gather_from_rank * part + idx; int dst_idx = gather_from_rank * part + idx;
((P *)result)[dst_idx] = tmps[i][idx]; ((P*)result)[dst_idx] = tmps[i][idx];
} }
} }
} }
@ -261,14 +258,14 @@ class CustomAllreduce {
// below are device pointers // below are device pointers
RankSignals sg_; RankSignals sg_;
std::unordered_map<void *, RankData *> buffers_; std::unordered_map<void*, RankData*> buffers_;
Signal *self_sg_; Signal* self_sg_;
// stores the registered device pointers from all ranks // stores the registered device pointers from all ranks
RankData *d_rank_data_base_, *d_rank_data_end_; RankData *d_rank_data_base_, *d_rank_data_end_;
std::vector<void *> graph_unreg_buffers_; std::vector<void*> graph_unreg_buffers_;
// a map from IPC handles to opened IPC pointers // a map from IPC handles to opened IPC pointers
std::map<IPC_KEY, char *> ipc_handles_; std::map<IPC_KEY, char*> ipc_handles_;
/** /**
* meta is a pointer to device metadata and temporary buffer for allreduce. * meta is a pointer to device metadata and temporary buffer for allreduce.
@ -279,22 +276,22 @@ class CustomAllreduce {
* note: this class does not own any device memory. Any required buffers * note: this class does not own any device memory. Any required buffers
* are passed in from the constructor * are passed in from the constructor
*/ */
CustomAllreduce(Signal *meta, void *rank_data, size_t rank_data_sz, CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz,
const cudaIpcMemHandle_t *handles, const cudaIpcMemHandle_t* handles,
const std::vector<int64_t> &offsets, int rank, const std::vector<int64_t>& offsets, int rank,
bool full_nvlink = true) bool full_nvlink = true)
: rank_(rank), : rank_(rank),
world_size_(offsets.size()), world_size_(offsets.size()),
full_nvlink_(full_nvlink), full_nvlink_(full_nvlink),
self_sg_(meta), self_sg_(meta),
d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)), d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
for (int i = 0; i < world_size_; i++) { for (int i = 0; i < world_size_; i++) {
Signal *rank_sg; Signal* rank_sg;
if (i != rank_) { if (i != rank_) {
char *handle = open_ipc_handle(&handles[i]); char* handle = open_ipc_handle(&handles[i]);
handle += offsets[i]; handle += offsets[i];
rank_sg = (Signal *)handle; rank_sg = (Signal*)handle;
} else { } else {
rank_sg = self_sg_; rank_sg = self_sg_;
} }
@ -302,13 +299,13 @@ class CustomAllreduce {
} }
} }
char *open_ipc_handle(const void *ipc_handle) { char* open_ipc_handle(const void* ipc_handle) {
auto [it, new_handle] = auto [it, new_handle] =
ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr}); ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr});
if (new_handle) { if (new_handle) {
char *ipc_ptr; char* ipc_ptr;
CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr, CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr,
*((const cudaIpcMemHandle_t *)ipc_handle), *((const cudaIpcMemHandle_t*)ipc_handle),
cudaIpcMemLazyEnablePeerAccess)); cudaIpcMemLazyEnablePeerAccess));
it->second = ipc_ptr; it->second = ipc_ptr;
} }
@ -323,7 +320,7 @@ class CustomAllreduce {
std::vector<int64_t> offsets(num_buffers); std::vector<int64_t> offsets(num_buffers);
for (int i = 0; i < num_buffers; i++) { for (int i = 0; i < num_buffers; i++) {
auto ptr = graph_unreg_buffers_[i]; auto ptr = graph_unreg_buffers_[i];
void *base_ptr; void* base_ptr;
// note: must share the base address of each allocation, or we get wrong // note: must share the base address of each allocation, or we get wrong
// address // address
if (cuPointerGetAttribute(&base_ptr, if (cuPointerGetAttribute(&base_ptr,
@ -331,8 +328,8 @@ class CustomAllreduce {
(CUdeviceptr)ptr) != CUDA_SUCCESS) (CUdeviceptr)ptr) != CUDA_SUCCESS)
throw std::runtime_error("failed to get pointer attr"); throw std::runtime_error("failed to get pointer attr");
CUDACHECK(cudaIpcGetMemHandle( CUDACHECK(cudaIpcGetMemHandle(
(cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr)); (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr));
offsets[i] = ((char *)ptr) - ((char *)base_ptr); offsets[i] = ((char*)ptr) - ((char*)base_ptr);
} }
return std::make_pair(handles, offsets); return std::make_pair(handles, offsets);
} }
@ -344,13 +341,13 @@ class CustomAllreduce {
std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
} }
void register_buffer(const std::vector<std::string> &handles, void register_buffer(const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets, void *self) { const std::vector<int64_t>& offsets, void* self) {
check_rank_data_capacity(); check_rank_data_capacity();
RankData data; RankData data;
for (int i = 0; i < world_size_; i++) { for (int i = 0; i < world_size_; i++) {
if (i != rank_) { if (i != rank_) {
char *handle = open_ipc_handle(handles[i].data()); char* handle = open_ipc_handle(handles[i].data());
handle += offsets[i]; handle += offsets[i];
data.ptrs[i] = handle; data.ptrs[i] = handle;
} else { } else {
@ -371,17 +368,17 @@ class CustomAllreduce {
// got a different address. IPC handles have internal reference counting // got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small. // mechanism so overhead should be small.
void register_graph_buffers( void register_graph_buffers(
const std::vector<std::string> &handles, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>> &offsets) { const std::vector<std::vector<int64_t>>& offsets) {
auto num_buffers = graph_unreg_buffers_.size(); auto num_buffers = graph_unreg_buffers_.size();
check_rank_data_capacity(num_buffers); check_rank_data_capacity(num_buffers);
std::vector<RankData> rank_data(num_buffers); std::vector<RankData> rank_data(num_buffers);
for (int i = 0; i < num_buffers; i++) { for (int i = 0; i < num_buffers; i++) {
auto self_ptr = graph_unreg_buffers_[i]; auto self_ptr = graph_unreg_buffers_[i];
auto &rd = rank_data[i]; auto& rd = rank_data[i];
for (int j = 0; j < world_size_; j++) { for (int j = 0; j < world_size_; j++) {
if (j != rank_) { if (j != rank_) {
char *handle = char* handle =
open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]); open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
handle += offsets[j][i]; handle += offsets[j][i];
rd.ptrs[j] = handle; rd.ptrs[j] = handle;
@ -405,7 +402,7 @@ class CustomAllreduce {
* will cause contention on NVLink bus. * will cause contention on NVLink bus.
*/ */
template <typename T> template <typename T>
void allreduce(cudaStream_t stream, T *input, T *output, int size, void allreduce(cudaStream_t stream, T* input, T* output, int size,
int threads = 512, int block_limit = 36) { int threads = 512, int block_limit = 36) {
auto d = packed_t<T>::P::size; auto d = packed_t<T>::P::size;
if (size % d != 0) if (size % d != 0)
@ -418,7 +415,7 @@ class CustomAllreduce {
std::to_string(kMaxBlocks) + ". Got " + std::to_string(kMaxBlocks) + ". Got " +
std::to_string(block_limit)); std::to_string(block_limit));
RankData *ptrs; RankData* ptrs;
cudaStreamCaptureStatus status; cudaStreamCaptureStatus status;
CUDACHECK(cudaStreamIsCapturing(stream, &status)); CUDACHECK(cudaStreamIsCapturing(stream, &status));
if (status == cudaStreamCaptureStatusActive) { if (status == cudaStreamCaptureStatusActive) {

View File

@ -48,7 +48,7 @@ __global__ void dummy_kernel() {
} }
template <typename T> template <typename T>
__global__ void set_data(T *data, int size, int myRank) { __global__ void set_data(T* data, int size, int myRank) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
data[idx] = myRank * 0.11f; data[idx] = myRank * 0.11f;
@ -56,8 +56,8 @@ __global__ void set_data(T *data, int size, int myRank) {
} }
template <typename T> template <typename T>
__global__ void convert_data(const T *data1, const T *data2, double *fdata1, __global__ void convert_data(const T* data1, const T* data2, double* fdata1,
double *fdata2, int size) { double* fdata2, int size) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
fdata1[idx] = data1[idx]; fdata1[idx] = data1[idx];
@ -65,7 +65,7 @@ __global__ void convert_data(const T *data1, const T *data2, double *fdata1,
} }
} }
__global__ void init_rand(curandState_t *state, int size, int nRanks) { __global__ void init_rand(curandState_t* state, int size, int nRanks) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
for (int i = 0; i < nRanks; i++) { for (int i = 0; i < nRanks; i++) {
@ -75,7 +75,7 @@ __global__ void init_rand(curandState_t *state, int size, int nRanks) {
} }
template <typename T> template <typename T>
__global__ void gen_data(curandState_t *state, T *data, double *ground_truth, __global__ void gen_data(curandState_t* state, T* data, double* ground_truth,
int myRank, int nRanks, int size) { int myRank, int nRanks, int size) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
idx += gridDim.x * blockDim.x) { idx += gridDim.x * blockDim.x) {
@ -91,9 +91,9 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
} }
template <typename T> template <typename T>
void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
int data_size, bool performance_test) { int data_size, bool performance_test) {
T *result; T* result;
cudaStream_t stream; cudaStream_t stream;
CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
CUDACHECK(cudaMalloc(&result, data_size * sizeof(T))); CUDACHECK(cudaMalloc(&result, data_size * sizeof(T)));
@ -101,8 +101,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
cudaIpcMemHandle_t self_data_handle; cudaIpcMemHandle_t self_data_handle;
cudaIpcMemHandle_t data_handles[8]; cudaIpcMemHandle_t data_handles[8];
vllm::Signal *buffer; vllm::Signal* buffer;
T *self_data_copy; T* self_data_copy;
/** /**
* Allocate IPC buffer * Allocate IPC buffer
* *
@ -125,22 +125,22 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t), MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t),
MPI_BYTE, MPI_COMM_WORLD)); MPI_BYTE, MPI_COMM_WORLD));
void *rank_data; void* rank_data;
size_t rank_data_sz = 16 * 1024 * 1024; size_t rank_data_sz = 16 * 1024 * 1024;
CUDACHECK(cudaMalloc(&rank_data, rank_data_sz)); CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
std::vector<int64_t> offsets(nRanks, 0); std::vector<int64_t> offsets(nRanks, 0);
vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles, vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
offsets, myRank); offsets, myRank);
auto *self_data = auto* self_data =
reinterpret_cast<T *>(reinterpret_cast<char *>(buffer) + reinterpret_cast<T*>(reinterpret_cast<char*>(buffer) +
sizeof(vllm::Signal) + data_size * sizeof(T)); sizeof(vllm::Signal) + data_size * sizeof(T));
// hack buffer registration // hack buffer registration
{ {
std::vector<std::string> handles; std::vector<std::string> handles;
handles.reserve(nRanks); handles.reserve(nRanks);
for (int i = 0; i < nRanks; i++) { for (int i = 0; i < nRanks; i++) {
char *begin = (char *)&data_handles[i]; char* begin = (char*)&data_handles[i];
char *end = (char *)&data_handles[i + 1]; char* end = (char*)&data_handles[i + 1];
handles.emplace_back(begin, end); handles.emplace_back(begin, end);
} }
std::vector<int64_t> offsets(nRanks, std::vector<int64_t> offsets(nRanks,
@ -148,9 +148,9 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
fa.register_buffer(handles, offsets, self_data); fa.register_buffer(handles, offsets, self_data);
} }
double *ground_truth; double* ground_truth;
CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double))); CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double)));
curandState_t *states; curandState_t* states;
CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size)); CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks); init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank, gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
@ -287,7 +287,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
CUDACHECK(cudaStreamDestroy(stream)); CUDACHECK(cudaStreamDestroy(stream));
} }
int main(int argc, char **argv) { int main(int argc, char** argv) {
int nRanks, myRank; int nRanks, myRank;
MPICHECK(MPI_Init(&argc, &argv)); MPICHECK(MPI_Init(&argc, &argv));
MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank)); MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
@ -296,7 +296,7 @@ int main(int argc, char **argv) {
ncclUniqueId id; ncclUniqueId id;
ncclComm_t comm; ncclComm_t comm;
if (myRank == 0) ncclGetUniqueId(&id); if (myRank == 0) ncclGetUniqueId(&id);
MPICHECK(MPI_Bcast(static_cast<void *>(&id), sizeof(id), MPI_BYTE, 0, MPICHECK(MPI_Bcast(static_cast<void*>(&id), sizeof(id), MPI_BYTE, 0,
MPI_COMM_WORLD)); MPI_COMM_WORLD));
NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank)); NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));

View File

@ -6,32 +6,30 @@
#include <torch/extension.h> #include <torch/extension.h>
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ #define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

View File

@ -11,26 +11,24 @@
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
using __nv_bfloat16 = __hip_bfloat16; using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162; using __nv_bfloat162 = __hip_bfloat162;
#endif #endif
namespace vllm { namespace vllm {
// TODO(woosuk): Further optimize this kernel. // TODO(woosuk): Further optimize this kernel.
template<typename scalar_t> template <typename scalar_t>
__global__ void rms_norm_kernel( __global__ void rms_norm_kernel(
scalar_t* __restrict__ out, // [..., hidden_size] scalar_t* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size] const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float epsilon, const int num_tokens, const int hidden_size) {
const int num_tokens,
const int hidden_size) {
__shared__ float s_variance; __shared__ float s_variance;
float variance = 0.0f; float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float) input[blockIdx.x * hidden_size + idx]; const float x = (float)input[blockIdx.x * hidden_size + idx];
variance += x * x; variance += x * x;
} }
variance = blockReduceSum<float>(variance); variance = blockReduceSum<float>(variance);
@ -40,12 +38,12 @@ __global__ void rms_norm_kernel(
__syncthreads(); __syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float) input[blockIdx.x * hidden_size + idx]; float x = (float)input[blockIdx.x * hidden_size + idx];
out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; out[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
} }
} }
/* Converter structs for the conversion from torch types to HIP/CUDA types, /* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion to be implemented for now because the relevant type conversion
@ -54,49 +52,68 @@ __global__ void rms_norm_kernel(
Each struct should have the member static constexpr bool `exists`: Each struct should have the member static constexpr bool `exists`:
If false, the optimized kernel is not used for the corresponding torch type. If false, the optimized kernel is not used for the corresponding torch type.
If true, the struct should be fully defined as shown in the examples below. If true, the struct should be fully defined as shown in the examples below.
*/ */
template<typename torch_type> template <typename torch_type>
struct _typeConvert { static constexpr bool exists = false; }; struct _typeConvert {
static constexpr bool exists = false;
};
template<> #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>
struct _typeConvert<c10::Half> { struct _typeConvert<c10::Half> {
static constexpr bool exists = true; static constexpr bool exists = true;
using hip_type = __half; using hip_type = __half;
using packed_hip_type = __half2; using packed_hip_type = __half2;
__device__ static inline float convert(hip_type x) { return __half2float(x); } __device__ static inline float convert(hip_type x) { return __half2float(x); }
__device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); } __device__ static inline float2 convert(packed_hip_type x) {
__device__ static inline hip_type convert(float x) { return __float2half_rn(x); } return __half22float2(x);
__device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); } }
__device__ static inline hip_type convert(float x) {
return __float2half_rn(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22half2_rn(x);
}
}; };
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support // CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely // TODO: Add in ROCm support once public headers handle bf16 maturely
template<> template <>
struct _typeConvert<c10::BFloat16> { struct _typeConvert<c10::BFloat16> {
static constexpr bool exists = true; static constexpr bool exists = true;
using hip_type = __nv_bfloat16; using hip_type = __nv_bfloat16;
using packed_hip_type = __nv_bfloat162; using packed_hip_type = __nv_bfloat162;
__device__ static inline float convert(hip_type x) { return __bfloat162float(x); } __device__ static inline float convert(hip_type x) {
__device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); } return __bfloat162float(x);
__device__ static inline hip_type convert(float x) { return __float2bfloat16(x); } }
__device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); } __device__ static inline float2 convert(packed_hip_type x) {
return __bfloat1622float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2bfloat16(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22bfloat162_rn(x);
}
}; };
#endif #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops /* Vector POD struct to generate vectorized and packed FP16/BF16 ops
for appropriate specializations of fused_add_rms_norm_kernel. for appropriate specializations of fused_add_rms_norm_kernel.
Only functions that are necessary in that kernel are implemented. Only functions that are necessary in that kernel are implemented.
Alignment to 16 bytes is required to use 128-bit global memory ops. Alignment to 16 bytes is required to use 128-bit global memory ops.
*/ */
template<typename scalar_t, int width> template <typename scalar_t, int width>
struct alignas(16) _f16Vec { struct alignas(16) _f16Vec {
/* Not theoretically necessary that width is a power of 2 but should /* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */ almost always be the case for optimization purposes */
static_assert(width > 0 && (width & (width - 1)) == 0, static_assert(width > 0 && (width & (width - 1)) == 0,
"Width is not a positive power of 2!"); "Width is not a positive power of 2!");
using Converter = _typeConvert<scalar_t>; using Converter = _typeConvert<scalar_t>;
@ -106,51 +123,49 @@ struct alignas(16) _f16Vec {
__device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) { __device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) { if constexpr (width % 2 == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < width; i += 2) { for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i+1]}; T2 temp{data[i], data[i + 1]};
temp += T2{other.data[i], other.data[i+1]}; temp += T2{other.data[i], other.data[i + 1]};
data[i] = temp.x; data[i] = temp.x;
data[i+1] = temp.y; data[i + 1] = temp.y;
} }
} else { } else {
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) for (int i = 0; i < width; ++i) data[i] += other.data[i];
data[i] += other.data[i];
} }
return *this; return *this;
} }
__device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) { __device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) { if constexpr (width % 2 == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < width; i += 2) { for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i+1]}; T2 temp{data[i], data[i + 1]};
temp *= T2{other.data[i], other.data[i+1]}; temp *= T2{other.data[i], other.data[i + 1]};
data[i] = temp.x; data[i] = temp.x;
data[i+1] = temp.y; data[i + 1] = temp.y;
} }
} else { } else {
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) for (int i = 0; i < width; ++i) data[i] *= other.data[i];
data[i] *= other.data[i];
} }
return *this; return *this;
} }
__device__ _f16Vec& operator*=(const float scale) { __device__ _f16Vec& operator*=(const float scale) {
if constexpr (width % 2 == 0) { if constexpr (width % 2 == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < width; i += 2) { for (int i = 0; i < width; i += 2) {
float2 temp_f = Converter::convert(T2{data[i], data[i+1]}); float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
temp_f.x *= scale; temp_f.x *= scale;
temp_f.y *= scale; temp_f.y *= scale;
T2 temp = Converter::convert(temp_f); T2 temp = Converter::convert(temp_f);
data[i] = temp.x; data[i] = temp.x;
data[i+1] = temp.y; data[i + 1] = temp.y;
} }
} else { } else {
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) { for (int i = 0; i < width; ++i) {
float temp = Converter::convert(data[i]) * scale; float temp = Converter::convert(data[i]) * scale;
data[i] = Converter::convert(temp); data[i] = Converter::convert(temp);
@ -162,13 +177,13 @@ struct alignas(16) _f16Vec {
__device__ float sum_squares() const { __device__ float sum_squares() const {
float result = 0.0f; float result = 0.0f;
if constexpr (width % 2 == 0) { if constexpr (width % 2 == 0) {
#pragma unroll #pragma unroll
for (int i = 0; i < width; i += 2) { for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i+1]}); float2 z = Converter::convert(T2{data[i], data[i + 1]});
result += z.x * z.x + z.y * z.y; result += z.x * z.x + z.y * z.y;
} }
} else { } else {
#pragma unroll #pragma unroll
for (int i = 0; i < width; ++i) { for (int i = 0; i < width; ++i) {
float x = Converter::convert(data[i]); float x = Converter::convert(data[i]);
result += x * x; result += x * x;
@ -182,15 +197,13 @@ struct alignas(16) _f16Vec {
Additional optimizations we can make in this case are Additional optimizations we can make in this case are
packed and vectorized operations, which help with the packed and vectorized operations, which help with the
memory latency bottleneck. */ memory latency bottleneck. */
template<typename scalar_t, int width> template <typename scalar_t, int width>
__global__ std::enable_if_t< __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
(width > 0) && _typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel( fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size] scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size] scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float epsilon, const int num_tokens, const int hidden_size) {
const int num_tokens,
const int hidden_size) {
// Sanity checks on our vector struct and type-punned pointer arithmetic // Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>); static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width); static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
@ -201,9 +214,12 @@ __global__ std::enable_if_t<
/* These and the argument pointers are all declared `restrict` as they are /* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */ in this kernel as that would be undefined behavior */
auto* __restrict__ input_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(input); auto* __restrict__ input_v =
auto* __restrict__ residual_v = reinterpret_cast<_f16Vec<scalar_t, width>*>(residual); reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
auto* __restrict__ weight_v = reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight); auto* __restrict__ residual_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
auto* __restrict__ weight_v =
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx; int id = blockIdx.x * vec_hidden_size + idx;
@ -213,10 +229,11 @@ __global__ std::enable_if_t<
residual_v[id] = temp; residual_v[id] = temp;
} }
/* Keep the following if-else block in sync with the /* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */ calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) { if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance); variance = blockReduceSum<float, 1024>(variance);
} else variance = blockReduceSum<float, 256>(variance); } else
variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon); s_variance = rsqrtf(variance / hidden_size + epsilon);
} }
@ -231,52 +248,50 @@ __global__ std::enable_if_t<
} }
} }
/* Generic fused_add_rms_norm_kernel /* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations. The width field is not used here but necessary for other specializations.
*/ */
template<typename scalar_t, int width> template <typename scalar_t, int width>
__global__ std::enable_if_t< __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
(width == 0) || !_typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel( fused_add_rms_norm_kernel(
scalar_t* __restrict__ input, // [..., hidden_size] scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size] scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size] const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon, const float epsilon, const int num_tokens, const int hidden_size) {
const int num_tokens,
const int hidden_size) {
__shared__ float s_variance; __shared__ float s_variance;
float variance = 0.0f; float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
scalar_t z = input[blockIdx.x * hidden_size + idx]; scalar_t z = input[blockIdx.x * hidden_size + idx];
z += residual[blockIdx.x * hidden_size + idx]; z += residual[blockIdx.x * hidden_size + idx];
float x = (float) z; float x = (float)z;
variance += x * x; variance += x * x;
residual[blockIdx.x * hidden_size + idx] = z; residual[blockIdx.x * hidden_size + idx] = z;
} }
/* Keep the following if-else block in sync with the /* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */ calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) { if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance); variance = blockReduceSum<float, 1024>(variance);
} else variance = blockReduceSum<float, 256>(variance); } else
variance = blockReduceSum<float, 256>(variance);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon); s_variance = rsqrtf(variance / hidden_size + epsilon);
} }
__syncthreads(); __syncthreads();
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float) residual[blockIdx.x * hidden_size + idx]; float x = (float)residual[blockIdx.x * hidden_size + idx];
input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx]; input[blockIdx.x * hidden_size + idx] =
((scalar_t)(x * s_variance)) * weight[idx];
} }
} }
} // namespace vllm } // namespace vllm
void rms_norm( void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& out, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size] torch::Tensor& weight, // [hidden_size]
torch::Tensor& weight, // [hidden_size] float epsilon) {
float epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
@ -284,40 +299,27 @@ void rms_norm(
dim3 block(std::min(hidden_size, 1024)); dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
input.scalar_type(), vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
"rms_norm_kernel", out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
[&] { weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>( });
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);
});
} }
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ #define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \ input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
"fused_add_rms_norm_kernel", \ vllm::fused_add_rms_norm_kernel<scalar_t, width> \
[&] { \ <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(), \
vllm::fused_add_rms_norm_kernel \ residual.data_ptr<scalar_t>(), \
<scalar_t, width><<<grid, block, 0, stream>>>( \ weight.data_ptr<scalar_t>(), epsilon, \
input.data_ptr<scalar_t>(), \ num_tokens, hidden_size); \
residual.data_ptr<scalar_t>(), \ });
weight.data_ptr<scalar_t>(), \
epsilon, \
num_tokens, \
hidden_size); \
});
void fused_add_rms_norm( void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size] torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size] torch::Tensor& weight, // [hidden_size]
torch::Tensor& weight, // [hidden_size] float epsilon) {
float epsilon) {
int hidden_size = input.size(-1); int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size; int num_tokens = input.numel() / hidden_size;
@ -340,8 +342,8 @@ void fused_add_rms_norm(
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr()); auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr()); auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr()); auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 \ bool ptrs_are_aligned =
&& wt_ptr % 16 == 0; inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) { if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_FUSED_ADD_RMS_NORM(8); LAUNCH_FUSED_ADD_RMS_NORM(8);
} else { } else {

View File

@ -3,5 +3,6 @@
#include <torch/extension.h> #include <torch/extension.h>
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs."); m.def("topk_softmax", &topk_softmax,
"Apply topk softmax to the gating outputs.");
} }

View File

@ -2,8 +2,6 @@
#include <torch/extension.h> #include <torch/extension.h>
void topk_softmax( void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& topk_weights, torch::Tensor& token_expert_indices,
torch::Tensor& topk_indices, torch::Tensor& gating_output);
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output);

View File

@ -19,15 +19,22 @@
#include <torch/extension.h> #include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include "../cuda_compat.h"
#include <cub/cub.cuh> #ifndef USE_ROCM
#include <cub/util_type.cuh> #include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
namespace vllm { namespace vllm {
namespace moe { namespace moe {
static constexpr int WARP_SIZE = 32;
/// Aligned array type /// Aligned array type
template < template <
typename T, typename T,
@ -265,7 +272,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
#pragma unroll #pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{ {
thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW)); thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW));
} }
// From this point, thread max in all the threads have the max within the row. // From this point, thread max in all the threads have the max within the row.
@ -282,7 +289,7 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
#pragma unroll #pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{ {
row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW); row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW);
} }
// From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables
@ -332,8 +339,8 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
#pragma unroll #pragma unroll
for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
{ {
float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW); float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW);
int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW); int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW);
// We want lower indices to "win" in every thread so we break ties this way // We want lower indices to "win" in every thread so we break ties this way
if (other_max > max_val || (other_max == max_val && other_expert < expert)) if (other_max > max_val || (other_max == max_val && other_expert < expert))
@ -383,7 +390,7 @@ struct TopkConstants
{ {
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
static constexpr int THREADS_PER_ROW = EXPERTS / VPT; static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
@ -396,7 +403,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
{ {
static constexpr std::size_t MAX_BYTES_PER_LDG = 16; static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>; using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
static constexpr int VPT = Constants::VPT; static constexpr int VPT = Constants::VPT;
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;

View File

@ -7,119 +7,128 @@
#include "cuda_compat.h" #include "cuda_compat.h"
#include "dispatch_utils.h" #include "dispatch_utils.h"
#define CEILDIV(x,y) (((x) + (y) - 1) / (y)) #define CEILDIV(x, y) (((x) + (y) - 1) / (y))
namespace vllm { namespace vllm {
namespace { namespace {
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) { __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
// don't worry about overflow because num_experts is relatively small int32_t col) {
return row * total_col + col; // don't worry about overflow because num_experts is relatively small
} return row * total_col + col;
} }
} // namespace
template <typename scalar_t> template <typename scalar_t>
__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
int32_t *sorted_token_ids, int32_t* sorted_token_ids,
int32_t *expert_ids, int32_t* expert_ids,
int32_t *total_tokens_post_pad, int32_t* total_tokens_post_pad,
int32_t num_experts, int32_t num_experts,
int32_t block_size, int32_t block_size, size_t numel) {
size_t numel) { const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); const size_t start_idx = threadIdx.x * tokens_per_thread;
const size_t start_idx = threadIdx.x * tokens_per_thread;
extern __shared__ int32_t shared_mem[]; extern __shared__ int32_t shared_mem[];
int32_t* tokens_cnts = shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) int32_t* tokens_cnts =
int32_t* cumsum = shared_mem + (num_experts + 1) * num_experts; // 1d tensor with shape (num_experts + 1) shared_mem; // 2d tensor with shape (num_experts + 1, num_experts)
int32_t* cumsum =
shared_mem + (num_experts + 1) *
num_experts; // 1d tensor with shape (num_experts + 1)
for (int i = 0; i < num_experts; ++i) { for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
}
/**
* In the first step we compute token_cnts[thread_index + 1][expert_index],
* which counts how many tokens in the token shard of thread_index are
* assigned to expert expert_index.
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
}
__syncthreads();
// For each expert we accumulate the token counts from the different threads.
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
}
__syncthreads();
// We accumulate the token counts of all experts in thread 0.
if (threadIdx.x == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
cumsum[i] = cumsum[i - 1] +
CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
block_size) *
block_size;
} }
*total_tokens_post_pad = cumsum[num_experts];
}
/** __syncthreads();
* In the first step we compute token_cnts[thread_index + 1][expert_index],
* which counts how many tokens in the token shard of thread_index are assigned
* to expert expert_index.
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
}
__syncthreads(); /**
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
*/
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
}
// For each expert we accumulate the token counts from the different threads. /**
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; * Each thread processes a token shard, calculating the index of each token
for (int i = 1; i <= blockDim.x; ++i) { * after sorting by expert number. Given the example topk_ids =
tokens_cnts[index(num_experts, i, threadIdx.x)] += tokens_cnts[index(num_experts, i-1, threadIdx.x)]; * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
} * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
* padding value(preset in python).
__syncthreads(); */
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
// We accumulate the token counts of all experts in thread 0. int32_t expert_id = topk_ids[i];
if (threadIdx.x == 0) { /** The cumsum[expert_id] stores the starting index of the tokens that the
cumsum[0] = 0; * expert with expert_id needs to process, and
for (int i = 1; i <= num_experts; ++i) { * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], block_size) * block_size; * processed by the expert with expert_id within the current thread's token
} * shard.
*total_tokens_post_pad = cumsum[num_experts]; */
} int32_t rank_post_pad =
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
__syncthreads(); cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i;
/** ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
* For each expert, each thread processes the tokens of the corresponding blocks }
* and stores the corresponding expert_id for each block.
*/
for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
}
/**
* Each thread processes a token shard, calculating the index of each token after
* sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
* block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
* where * represents a padding value(preset in python).
*/
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int32_t expert_id = topk_ids[i];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
* stores the indices of the tokens processed by the expert with expert_id within
* the current thread's token shard.
*/
int32_t rank_post_pad = tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i;
++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
}
}
} }
} // namespace vllm
void moe_align_block_size( void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
torch::Tensor topk_ids, int block_size, torch::Tensor sorted_token_ids,
int num_experts, torch::Tensor experts_ids,
int block_size, torch::Tensor num_tokens_post_pad) {
torch::Tensor sorted_token_ids, const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
torch::Tensor experts_ids, VLLM_DISPATCH_INTEGRAL_TYPES(
torch::Tensor num_tokens_post_pad) { topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
VLLM_DISPATCH_INTEGRAL_TYPES( // tensors
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { const int32_t shared_mem =
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors ((num_experts + 1) * num_experts + (num_experts + 1)) *
const int32_t shared_mem = ((num_experts + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); sizeof(int32_t);
// set dynamic shared mem // set dynamic shared mem
auto kernel = vllm::moe_align_block_size_kernel<scalar_t>; auto kernel = vllm::moe_align_block_size_kernel<scalar_t>;
AT_CUDA_CHECK( AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize((void *)kernel, shared_mem)); (void*)kernel, shared_mem));
kernel<<<1, num_experts, shared_mem, stream>>>( kernel<<<1, num_experts, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(), topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
sorted_token_ids.data_ptr<int32_t>(), experts_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
num_tokens_post_pad.data_ptr<int32_t>(),
num_experts,
block_size,
topk_ids.numel()); topk_ids.numel());
}); });
} }

View File

@ -3,157 +3,139 @@
#include <torch/extension.h> #include <torch/extension.h>
void paged_attention_v1( void paged_attention_v1(
torch::Tensor& out, torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& query, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& key_cache, torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
torch::Tensor& value_cache, int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
int num_kv_heads, const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
float scale, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
torch::Tensor& block_tables, const int blocksparse_block_size, const int blocksparse_head_sliding_step);
torch::Tensor& context_lens,
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype);
void paged_attention_v2( void paged_attention_v2(
torch::Tensor& out, torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& exp_sums, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& max_logits, torch::Tensor& value_cache, int num_kv_heads, float scale,
torch::Tensor& tmp_out, torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
torch::Tensor& query, int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
torch::Tensor& key_cache, const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
torch::Tensor& value_cache, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
int num_kv_heads, const int blocksparse_block_size, const int blocksparse_head_sliding_step);
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype);
void rms_norm( void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
torch::Tensor& out, float epsilon);
torch::Tensor& input,
torch::Tensor& weight,
float epsilon);
void fused_add_rms_norm( void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& input, torch::Tensor& weight, float epsilon);
torch::Tensor& residual,
torch::Tensor& weight,
float epsilon);
void rotary_embedding( void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& positions, torch::Tensor& key, int head_size,
torch::Tensor& query, torch::Tensor& cos_sin_cache, bool is_neox);
torch::Tensor& key,
int head_size,
torch::Tensor& cos_sin_cache,
bool is_neox);
void batched_rotary_embedding( void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& positions, torch::Tensor& key, int head_size,
torch::Tensor& query, torch::Tensor& cos_sin_cache, bool is_neox,
torch::Tensor& key, int rot_dim,
int head_size, torch::Tensor& cos_sin_cache_offsets);
torch::Tensor& cos_sin_cache,
bool is_neox,
int rot_dim,
torch::Tensor& cos_sin_cache_offsets);
void silu_and_mul( void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& out,
torch::Tensor& input);
void gelu_and_mul( void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& out,
torch::Tensor& input);
void gelu_tanh_and_mul( void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& out,
torch::Tensor& input);
void gelu_new( void gelu_new(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& out,
torch::Tensor& input);
void gelu_fast( void gelu_fast(torch::Tensor& out, torch::Tensor& input);
torch::Tensor& out,
torch::Tensor& input);
#ifndef USE_ROCM #ifndef USE_ROCM
torch::Tensor awq_gemm( torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
torch::Tensor _in_feats, const torch::Tensor& codebooks,
torch::Tensor _kernel, const torch::Tensor& scales,
torch::Tensor _scaling_factors, const torch::Tensor& codebook_partition_sizes,
torch::Tensor _zeros, const std::optional<torch::Tensor>& bias);
int split_k_iters);
torch::Tensor awq_dequantize( torch::Tensor aqlm_dequant(const torch::Tensor& codes,
torch::Tensor _kernel, const torch::Tensor& codebooks,
torch::Tensor _scaling_factors, const torch::Tensor& codebook_partition_sizes);
torch::Tensor _zeros,
int split_k_iters, torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
int thx, torch::Tensor _scaling_factors, torch::Tensor _zeros,
int thy); int split_k_iters);
torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros, int split_k_iters, int thx,
int thy);
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t size_m, int64_t size_n, int64_t size_k);
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales,
torch::Tensor& workspace, int64_t num_bits,
int64_t size_m, int64_t size_n,
int64_t size_k);
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& g_idx,
torch::Tensor& perm, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full);
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits);
int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
torch::Tensor marlin_gemm(
torch::Tensor& a,
torch::Tensor& b_q_weight,
torch::Tensor& b_scales,
torch::Tensor& workspace,
int64_t size_m,
int64_t size_n,
int64_t size_k);
#endif #endif
void squeezellm_gemm( void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor vec, torch::Tensor const& scale);
torch::Tensor mat,
torch::Tensor mul,
torch::Tensor lookup_table);
torch::Tensor gptq_gemm( void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor a, torch::Tensor lookup_table);
torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales,
torch::Tensor b_g_idx,
bool use_exllama,
int bit);
void gptq_shuffle( torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor q_weight, torch::Tensor b_gptq_qzeros,
torch::Tensor q_perm, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
int bit); bool use_exllama, int bit);
void moe_align_block_size( void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit);
torch::Tensor topk_ids,
int num_experts, void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
int block_size, torch::Tensor& scale);
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids, void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor num_tokens_post_pad); torch::Tensor& scale);
void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
int block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);
#ifndef USE_ROCM #ifndef USE_ROCM
using fptr_t = uint64_t; using fptr_t = uint64_t;
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string> &handles, const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets, int rank, const std::vector<int64_t>& offsets, int rank,
bool full_nvlink);
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
bool full_nvlink); bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out); bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer, bool full_nvlink);
torch::Tensor &out); void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor& out);
void dispose(fptr_t _fa); void dispose(fptr_t _fa);
int meta_size(); int meta_size();
void register_buffer(fptr_t _fa, torch::Tensor &t, void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string> &handles, const std::vector<std::string>& handles,
const std::vector<int64_t> &offsets); const std::vector<int64_t>& offsets);
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa); std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles, fptr_t _fa);
const std::vector<std::vector<int64_t>> &offsets); void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
const std::vector<std::vector<int64_t>>& offsets);
#endif #endif

View File

@ -7,14 +7,10 @@
namespace vllm { namespace vllm {
template<typename scalar_t, bool IS_NEOX> template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding( inline __device__ void apply_token_rotary_embedding(
scalar_t* __restrict__ arr, scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
const scalar_t* __restrict__ cos_ptr, const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) {
const scalar_t* __restrict__ sin_ptr,
int rot_offset,
int embed_dim)
{
int x_index, y_index; int x_index, y_index;
scalar_t cos, sin; scalar_t cos, sin;
if (IS_NEOX) { if (IS_NEOX) {
@ -37,19 +33,17 @@ inline __device__ void apply_token_rotary_embedding(
arr[y_index] = y * cos + x * sin; arr[y_index] = y * cos + x * sin;
} }
template<typename scalar_t, bool IS_NEOX> template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding( inline __device__ void apply_rotary_embedding(
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] // head_size] or [num_tokens, num_heads,
const scalar_t* cache_ptr, // head_size]
const int head_size, scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
const int num_heads, // head_size] or [num_tokens, num_kv_heads,
const int num_kv_heads, // head_size]
const int rot_dim, const scalar_t* cache_ptr, const int head_size, const int num_heads,
const int token_idx, const int num_kv_heads, const int rot_dim, const int token_idx,
const int64_t query_stride, const int64_t query_stride, const int64_t key_stride) {
const int64_t key_stride)
{
const int embed_dim = rot_dim / 2; const int embed_dim = rot_dim / 2;
const scalar_t* cos_ptr = cache_ptr; const scalar_t* cos_ptr = cache_ptr;
const scalar_t* sin_ptr = cache_ptr + embed_dim; const scalar_t* sin_ptr = cache_ptr + embed_dim;
@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding(
const int head_idx = i / embed_dim; const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * query_stride + head_idx * head_size; const int64_t token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % embed_dim; const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr, apply_token_rotary_embedding<scalar_t, IS_NEOX>(
sin_ptr, rot_offset, embed_dim); query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
} }
const int nk = num_kv_heads * embed_dim; const int nk = num_kv_heads * embed_dim;
@ -68,62 +62,74 @@ inline __device__ void apply_rotary_embedding(
const int head_idx = i / embed_dim; const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim; const int rot_offset = i % embed_dim;
apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr, apply_token_rotary_embedding<scalar_t, IS_NEOX>(
sin_ptr, rot_offset, embed_dim); key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
} }
} }
template<typename scalar_t, bool IS_NEOX> template <typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel( __global__ void rotary_embedding_kernel(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] const int64_t* __restrict__ positions, // [batch_size, seq_len] or
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] // [num_tokens]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] // head_size] or [num_tokens, num_heads,
const int rot_dim, // head_size]
const int64_t query_stride, scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
const int64_t key_stride, // head_size] or [num_tokens, num_kv_heads,
const int num_heads, // head_size]
const int num_kv_heads, const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
const int head_size) { // 2]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size) {
// Each thread block is responsible for one token. // Each thread block is responsible for one token.
const int token_idx = blockIdx.x; const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); apply_rotary_embedding<scalar_t, IS_NEOX>(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride);
} }
template<typename scalar_t, bool IS_NEOX> template <typename scalar_t, bool IS_NEOX>
__global__ void batched_rotary_embedding_kernel( __global__ void batched_rotary_embedding_kernel(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] const int64_t* __restrict__ positions, // [batch_size, seq_len] or
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] // [num_tokens]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] // head_size] or [num_tokens, num_heads,
const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] or [num_tokens] // head_size]
const int rot_dim, scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
const int64_t query_stride, // head_size] or [num_tokens, num_kv_heads,
const int64_t key_stride, // head_size]
const int num_heads, const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
const int num_kv_heads, // 2]
const int head_size) { const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len]
// or [num_tokens]
const int rot_dim, const int64_t query_stride, const int64_t key_stride,
const int num_heads, const int num_kv_heads, const int head_size) {
// Each thread block is responsible for one token. // Each thread block is responsible for one token.
const int token_idx = blockIdx.x; const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx]; int64_t pos = positions[token_idx];
int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx]; int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; const scalar_t* cache_ptr =
cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); apply_rotary_embedding<scalar_t, IS_NEOX>(
query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
token_idx, query_stride, key_stride);
} }
} // namespace vllm } // namespace vllm
void rotary_embedding( void rotary_embedding(
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] // [num_tokens, num_heads * head_size]
int head_size, torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
torch::Tensor& cos_sin_cache, // [max_position, rot_dim] // [num_tokens, num_kv_heads * head_size]
bool is_neox) { int head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) {
int64_t num_tokens = query.numel() / query.size(-1); int64_t num_tokens = query.numel() / query.size(-1);
int rot_dim = cos_sin_cache.size(1); int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size; int num_heads = query.size(-1) / head_size;
@ -135,36 +141,21 @@ void rotary_embedding(
dim3 block(std::min(num_heads * rot_dim / 2, 512)); dim3 block(std::min(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
query.scalar_type(), if (is_neox) {
"rotary_embedding", vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
[&] { positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
if (is_neox) { key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>( query_stride, key_stride, num_heads, num_kv_heads, head_size);
positions.data_ptr<int64_t>(), } else {
query.data_ptr<scalar_t>(), vllm::rotary_embedding_kernel<scalar_t, false>
key.data_ptr<scalar_t>(), <<<grid, block, 0, stream>>>(
cos_sin_cache.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
rot_dim, key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
query_stride, rot_dim, query_stride, key_stride, num_heads, num_kv_heads,
key_stride, head_size);
num_heads, }
num_kv_heads, });
head_size);
} else {
vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
}
});
} }
/* /*
@ -172,14 +163,15 @@ Batched version of rotary embedding, pack multiple LoRAs together
and process in batched manner. and process in batched manner.
*/ */
void batched_rotary_embedding( void batched_rotary_embedding(
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] // [num_tokens, num_heads * head_size]
int head_size, torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
torch::Tensor& cos_sin_cache, // [max_position, rot_dim] // [num_tokens, num_kv_heads * head_size]
bool is_neox, int head_size,
int rot_dim, torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
torch::Tensor& cos_sin_cache_offsets // [num_tokens] bool is_neox, int rot_dim,
torch::Tensor& cos_sin_cache_offsets // [num_tokens]
) { ) {
int64_t num_tokens = cos_sin_cache_offsets.size(0); int64_t num_tokens = cos_sin_cache_offsets.size(0);
int num_heads = query.size(-1) / head_size; int num_heads = query.size(-1) / head_size;
@ -191,36 +183,21 @@ void batched_rotary_embedding(
dim3 block(std::min(num_heads * rot_dim / 2, 512)); dim3 block(std::min(num_heads * rot_dim / 2, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] {
query.scalar_type(), if (is_neox) {
"rotary_embedding", vllm::batched_rotary_embedding_kernel<scalar_t, true>
[&] { <<<grid, block, 0, stream>>>(
if (is_neox) { positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
vllm::batched_rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>( key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
positions.data_ptr<int64_t>(), cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
query.data_ptr<scalar_t>(), key_stride, num_heads, num_kv_heads, head_size);
key.data_ptr<scalar_t>(), } else {
cos_sin_cache.data_ptr<scalar_t>(), vllm::batched_rotary_embedding_kernel<scalar_t, false>
cos_sin_cache_offsets.data_ptr<int64_t>(), <<<grid, block, 0, stream>>>(
rot_dim, positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
query_stride, key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
key_stride, cos_sin_cache_offsets.data_ptr<int64_t>(), rot_dim, query_stride,
num_heads, key_stride, num_heads, num_kv_heads, head_size);
num_kv_heads, }
head_size); });
} else {
vllm::batched_rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
cos_sin_cache_offsets.data_ptr<int64_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
}
});
} }

View File

@ -2,3 +2,4 @@
#include "bgmv_impl.cuh" #include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)

View File

@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half)

View File

@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16)

View File

@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half)

View File

@ -2,3 +2,4 @@
#include "bgmv_impl.cuh" #include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, float, nv_bfloat16)

View File

@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)

View File

@ -14,6 +14,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 128) \ f(in_T, out_T, W_T, narrow, 128) \
f(in_T, out_T, W_T, narrow, 256) \ f(in_T, out_T, W_T, narrow, 256) \
f(in_T, out_T, W_T, narrow, 512) \ f(in_T, out_T, W_T, narrow, 512) \
f(in_T, out_T, W_T, narrow, 640) \
f(in_T, out_T, W_T, narrow, 768) \ f(in_T, out_T, W_T, narrow, 768) \
f(in_T, out_T, W_T, narrow, 1024) \ f(in_T, out_T, W_T, narrow, 1024) \
f(in_T, out_T, W_T, narrow, 1152) \ f(in_T, out_T, W_T, narrow, 1152) \
@ -27,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 2752) \ f(in_T, out_T, W_T, narrow, 2752) \
f(in_T, out_T, W_T, narrow, 2816) \ f(in_T, out_T, W_T, narrow, 2816) \
f(in_T, out_T, W_T, narrow, 3072) \ f(in_T, out_T, W_T, narrow, 3072) \
f(in_T, out_T, W_T, narrow, 3328) \
f(in_T, out_T, W_T, narrow, 3456) \ f(in_T, out_T, W_T, narrow, 3456) \
f(in_T, out_T, W_T, narrow, 3584) \ f(in_T, out_T, W_T, narrow, 3584) \
f(in_T, out_T, W_T, narrow, 4096) \ f(in_T, out_T, W_T, narrow, 4096) \
@ -35,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 5504) \ f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5632) \ f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 6144) \ f(in_T, out_T, W_T, narrow, 6144) \
f(in_T, out_T, W_T, narrow, 6400) \
f(in_T, out_T, W_T, narrow, 6848) \ f(in_T, out_T, W_T, narrow, 6848) \
f(in_T, out_T, W_T, narrow, 6912) \ f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 7168) \ f(in_T, out_T, W_T, narrow, 7168) \
@ -46,11 +49,13 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 13696) \ f(in_T, out_T, W_T, narrow, 13696) \
f(in_T, out_T, W_T, narrow, 13824) \ f(in_T, out_T, W_T, narrow, 13824) \
f(in_T, out_T, W_T, narrow, 14336) \ f(in_T, out_T, W_T, narrow, 14336) \
f(in_T, out_T, W_T, narrow, 15360) \
f(in_T, out_T, W_T, narrow, 16384) \ f(in_T, out_T, W_T, narrow, 16384) \
f(in_T, out_T, W_T, narrow, 20480) \ f(in_T, out_T, W_T, narrow, 20480) \
f(in_T, out_T, W_T, narrow, 22016) \ f(in_T, out_T, W_T, narrow, 22016) \
f(in_T, out_T, W_T, narrow, 24576) \ f(in_T, out_T, W_T, narrow, 24576) \
f(in_T, out_T, W_T, narrow, 27392) \ f(in_T, out_T, W_T, narrow, 27392) \
f(in_T, out_T, W_T, narrow, 27648) \
f(in_T, out_T, W_T, narrow, 28672) \ f(in_T, out_T, W_T, narrow, 28672) \
f(in_T, out_T, W_T, narrow, 32000) \ f(in_T, out_T, W_T, narrow, 32000) \
f(in_T, out_T, W_T, narrow, 32256) \ f(in_T, out_T, W_T, narrow, 32256) \
@ -58,9 +63,91 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 32768) \ f(in_T, out_T, W_T, narrow, 32768) \
f(in_T, out_T, W_T, narrow, 33024) \ f(in_T, out_T, W_T, narrow, 33024) \
f(in_T, out_T, W_T, narrow, 36864) \ f(in_T, out_T, W_T, narrow, 36864) \
f(in_T, out_T, W_T, narrow, 43264) \
f(in_T, out_T, W_T, narrow, 49152) \ f(in_T, out_T, W_T, narrow, 49152) \
f(in_T, out_T, W_T, narrow, 64000) \
f(in_T, out_T, W_T, narrow, 64256) \
f(in_T, out_T, W_T, narrow, 64512) \
f(in_T, out_T, W_T, narrow, 102400) \
f(in_T, out_T, W_T, narrow, 102656) \
f(in_T, out_T, W_T, narrow, 102912) \
f(in_T, out_T, W_T, narrow, 128000) \
f(in_T, out_T, W_T, narrow, 128256) \
f(in_T, out_T, W_T, narrow, 128512) \
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// and vllm/tests/lora/test_punica.py
// Used for defining kernels going from the variety of
// dim in to the narrow dim out
// Using it for the fully sharded column
// parallel LoRA A which splits the rank dim
#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \
f(in_T, out_T, W_T, 128, narrow) \
f(in_T, out_T, W_T, 256, narrow) \
f(in_T, out_T, W_T, 512, narrow) \
f(in_T, out_T, W_T, 640, narrow) \
f(in_T, out_T, W_T, 768, narrow) \
f(in_T, out_T, W_T, 1024, narrow) \
f(in_T, out_T, W_T, 1152, narrow) \
f(in_T, out_T, W_T, 1280, narrow) \
f(in_T, out_T, W_T, 1536, narrow) \
f(in_T, out_T, W_T, 1728, narrow) \
f(in_T, out_T, W_T, 1792, narrow) \
f(in_T, out_T, W_T, 2048, narrow) \
f(in_T, out_T, W_T, 2304, narrow) \
f(in_T, out_T, W_T, 2560, narrow) \
f(in_T, out_T, W_T, 2752, narrow) \
f(in_T, out_T, W_T, 2816, narrow) \
f(in_T, out_T, W_T, 3072, narrow) \
f(in_T, out_T, W_T, 3328, narrow) \
f(in_T, out_T, W_T, 3456, narrow) \
f(in_T, out_T, W_T, 3584, narrow) \
f(in_T, out_T, W_T, 4096, narrow) \
f(in_T, out_T, W_T, 4608, narrow) \
f(in_T, out_T, W_T, 5120, narrow) \
f(in_T, out_T, W_T, 5504, narrow) \
f(in_T, out_T, W_T, 5632, narrow) \
f(in_T, out_T, W_T, 6144, narrow) \
f(in_T, out_T, W_T, 6400, narrow) \
f(in_T, out_T, W_T, 6848, narrow) \
f(in_T, out_T, W_T, 6912, narrow) \
f(in_T, out_T, W_T, 7168, narrow) \
f(in_T, out_T, W_T, 8192, narrow) \
f(in_T, out_T, W_T, 9216, narrow) \
f(in_T, out_T, W_T, 10240, narrow) \
f(in_T, out_T, W_T, 11008, narrow) \
f(in_T, out_T, W_T, 12288, narrow) \
f(in_T, out_T, W_T, 13696, narrow) \
f(in_T, out_T, W_T, 13824, narrow) \
f(in_T, out_T, W_T, 14336, narrow) \
f(in_T, out_T, W_T, 15360, narrow) \
f(in_T, out_T, W_T, 16384, narrow) \
f(in_T, out_T, W_T, 20480, narrow) \
f(in_T, out_T, W_T, 22016, narrow) \
f(in_T, out_T, W_T, 24576, narrow) \
f(in_T, out_T, W_T, 27392, narrow) \
f(in_T, out_T, W_T, 27648, narrow) \
f(in_T, out_T, W_T, 28672, narrow) \
f(in_T, out_T, W_T, 32000, narrow) \
f(in_T, out_T, W_T, 32256, narrow) \
f(in_T, out_T, W_T, 32512, narrow) \
f(in_T, out_T, W_T, 32768, narrow) \
f(in_T, out_T, W_T, 33024, narrow) \
f(in_T, out_T, W_T, 36864, narrow) \
f(in_T, out_T, W_T, 43264, narrow) \
f(in_T, out_T, W_T, 49152, narrow) \
f(in_T, out_T, W_T, 64000, narrow) \
f(in_T, out_T, W_T, 64256, narrow) \
f(in_T, out_T, W_T, 64512, narrow) \
f(in_T, out_T, W_T, 102400, narrow) \
f(in_T, out_T, W_T, 102656, narrow) \
f(in_T, out_T, W_T, 102912, narrow) \
f(in_T, out_T, W_T, 128000, narrow) \
f(in_T, out_T, W_T, 128256, narrow) \
f(in_T, out_T, W_T, 128512, narrow) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA // Keep above in sync with vllm/lora/layers::SamplerWithLoRA
// Keep this in sync with vllm/config::LoRAConfig // Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
@ -68,4 +155,14 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \
f(in_T, out_T, W_T, 8, 64) \
f(in_T, out_T, W_T, 16, 64) \
f(in_T, out_T, W_T, 32, 64) \
f(in_T, out_T, W_T, 64, 64)
// clang-format on // clang-format on

View File

@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)

View File

@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)

View File

@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)

View File

@ -2,3 +2,4 @@
#include "bgmv_impl.cuh" #include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half)

View File

@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)

View File

@ -2,3 +2,4 @@
#include "bgmv_impl.cuh" #include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, float, nv_half)

View File

@ -2,3 +2,4 @@
#include "bgmv_impl.cuh" #include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_bfloat16, nv_bfloat16)

View File

@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)

View File

@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)

View File

@ -2,3 +2,4 @@
#include "bgmv_impl.cuh" #include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_half, nv_half)

View File

@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)

View File

@ -1,4 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)

View File

@ -1,8 +1,14 @@
#pragma once #pragma once
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#ifndef USE_ROCM
#include <cooperative_groups.h> #include <cooperative_groups.h>
#else
#include <hip/hip_cooperative_groups.h>
#endif
#ifndef USE_ROCM
#include <cuda/pipeline> #include <cuda/pipeline>
#endif
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <iostream> #include <iostream>
#include <stdio.h> #include <stdio.h>
@ -11,6 +17,24 @@
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
#ifdef USE_ROCM
template <size_t len>
__host__ __device__
inline void* memcpy_blocking(void *dst, const void *src) {
// Does not handle the case of long datatypes
char *d = reinterpret_cast<char *>(dst);
const char *s = reinterpret_cast<const char *>(src);
size_t i = 0;
#pragma unroll
for (i = 0; i < len; ++i) {
d[i] = s[i];
}
return dst;
}
#endif
#ifndef USE_ROCM
// nthrs = (32, 4) // nthrs = (32, 4)
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size, template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
size_t W_copy_size, int tx, int ty, int tz, typename in_T, size_t W_copy_size, int tx, int ty, int tz, typename in_T,
@ -141,6 +165,81 @@ bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
} }
} }
#else
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
const W_T *__restrict__ W,
const int64_t *__restrict__ indicies, int64_t y_offset,
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
float scale) {
size_t batch_idx = blockIdx.y;
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
if (idx < 0) {
return;
}
size_t j = blockIdx.x;
constexpr size_t tile_size = tx * ty * vec_size;
constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size;
__shared__ float y_warpwise[ty];
float y = 0;
vec_t<in_T, vec_size> x_vec;
vec_t<W_T, vec_size> w_vec;
size_t tile_idx;
#pragma unroll
for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
x_vec.load(X + (batch_idx * feat_in) +
tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size);
w_vec.load(W + (idx * feat_out + j) * feat_in +
tile_idx * tile_size +
(threadIdx.y * tx + threadIdx.x) * vec_size);
}
float sum = 0.f;
#pragma unroll
for (size_t i = 0; i < vec_size; ++i) {
sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
}
#pragma unroll
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
sum += VLLM_SHFL_DOWN_SYNC(sum, offset);
}
__syncthreads();
if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
y += sum;
}
}
if (threadIdx.x == 0) {
y_warpwise[threadIdx.y] = y;
}
__syncthreads();
float y_write = 0.f;
#pragma unroll
for (size_t i = 0; i < ty; ++i) {
y_write += y_warpwise[i];
}
// write Y;
if (threadIdx.x == 0 && threadIdx.y == 0) {
size_t y_idx = batch_idx * full_y_size + y_offset + j;
Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(y_write));
}
}
#endif
// nthrs = (2, 16, 4) // nthrs = (2, 16, 4)
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz, template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
typename in_T, typename out_T, typename W_T> typename in_T, typename out_T, typename W_T>
@ -172,7 +271,11 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
float sum = 0.f; float sum = 0.f;
#pragma unroll #pragma unroll
for (size_t i = 0; i < vec_size; ++i) { for (size_t i = 0; i < vec_size; ++i) {
#ifndef USE_ROCM
sum += float(w_vec[i]) * float(x_vec[i]) * scale; sum += float(w_vec[i]) * float(x_vec[i]) * scale;
#else
sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
#endif
} }
cg::thread_block_tile g = cg::tiled_partition<tx>(block); cg::thread_block_tile g = cg::tiled_partition<tx>(block);
@ -183,8 +286,14 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
sum = g.shfl(sum, 0); sum = g.shfl(sum, 0);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
#ifndef USE_ROCM
Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum); threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
#else
size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
threadIdx.z * ty + threadIdx.y;
Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(sum));
#endif
} }
} }
@ -199,7 +308,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
constexpr int tz = 4; constexpr int tz = 4;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
if constexpr (feat_in < feat_out) { if constexpr (feat_in <= feat_out) {
static_assert(feat_in % vec_size == 0); static_assert(feat_in % vec_size == 0);
constexpr int tx = feat_in / vec_size; constexpr int tx = feat_in / vec_size;
@ -236,6 +345,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
scale); scale);
} }
} else { } else {
#ifndef USE_ROCM
static_assert(feat_in % (vec_size * 32) == 0 || static_assert(feat_in % (vec_size * 32) == 0 ||
feat_in % (vec_size * 16) == 0 || feat_in % (vec_size * 16) == 0 ||
feat_in % (vec_size * 8) == 0); feat_in % (vec_size * 8) == 0);
@ -279,6 +389,50 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
full_y_size, num_layers, layer_idx, full_y_size, num_layers, layer_idx,
scale); scale);
} }
#else
constexpr size_t rocm_warp_size = warpSize;
#define CHECK_INPUT_TILEABLE_BY(vec_size_) \
feat_in % (rocm_warp_size * vec_size_) == 0
#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_) \
if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) { \
constexpr size_t vec_size_shrink = vec_size_; \
constexpr int tx = tx_; \
constexpr int ty = ty_; \
dim3 nblks(feat_out, batch_size); \
dim3 nthrs(tx, ty); \
bgmv_shrink_kernel<feat_in, feat_out, vec_size_shrink, \
vec_size_shrink * sizeof(in_T), \
vec_size_shrink * sizeof(W_T), \
tx, ty, tz> \
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset, \
full_y_size, num_layers, layer_idx, \
scale); \
}
static_assert(CHECK_INPUT_TILEABLE_BY(32) ||
CHECK_INPUT_TILEABLE_BY(16) ||
CHECK_INPUT_TILEABLE_BY( 8) ||
CHECK_INPUT_TILEABLE_BY( 4) ||
CHECK_INPUT_TILEABLE_BY( 2) ||
CHECK_INPUT_TILEABLE_BY( 1));
LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size, 8/vec_size)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1)
#undef CHECK_INPUT_TILEABLE_BY
#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM
#endif
} }
} }
@ -289,6 +443,9 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \ int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
int64_t num_layers, int64_t layer_idx, float scale); int64_t num_layers, int64_t layer_idx, float scale);
#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \
INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \ #define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
INST_BGMV(narrow, wide, in_T, out_T, W_T) \ INST_BGMV(narrow, wide, in_T, out_T, W_T) \
INST_BGMV(wide, narrow, in_T, out_T, W_T) INST_BGMV(wide, narrow, in_T, out_T, W_T)

View File

@ -10,6 +10,7 @@ TEMPLATE = """
#include "bgmv_impl.cuh" #include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype}) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip() # noqa: E501 """.lstrip() # noqa: E501
for input_dtype in DTYPES: for input_dtype in DTYPES:
@ -18,6 +19,26 @@ for input_dtype in DTYPES:
if weight_dtype == "fp32": if weight_dtype == "fp32":
# FP32 weights are not supported. # FP32 weights are not supported.
continue continue
if output_dtype == "fp32":
# LoRA A matrix.
if input_dtype != weight_dtype:
# NOTE(woosuk): While Punica supports the case where the
# input and weight dtypes are different, we only generate
# the kernels the same dtypes to reduce the binary size.
continue
elif input_dtype == "fp32":
# LoRA B matrix.
if output_dtype != weight_dtype:
# NOTE(woosuk): While Punica supports the case where the
# output and weight dtypes are different, we only generate
# the kernels the same dtypes to reduce the binary size.
continue
elif not (input_dtype == output_dtype == weight_dtype):
# NOTE(woosuk): While Punica supports mixed data types for
# input, output, and weight, we only generate the kernels with
# the same data types to reduce the binary size.
continue
kernel_definition = TEMPLATE.format( kernel_definition = TEMPLATE.format(
input_dtype=DTYPE_MAP[input_dtype], input_dtype=DTYPE_MAP[input_dtype],
output_dtype=DTYPE_MAP[output_dtype], output_dtype=DTYPE_MAP[output_dtype],

Some files were not shown because too many files have changed in this diff Show More