Compare commits

...

210 Commits

Author SHA1 Message Date
852ef5b4f5 Bump up the version to v0.1.5 (#944) 2023-09-07 16:15:31 -07:00
db09d4ad83 [FIX] Fix Alibi implementation in PagedAttention kernel (#945)
* [FIX] Fix Alibi implementation in PagedAttention kernel

* Fix test_attention

* Fix

---------

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Oliver-ss <yuansongwx@outlook.com>
2023-09-07 15:53:14 -07:00
c957c741d9 Enable safetensors loading for all models (#974) 2023-09-07 15:49:52 -07:00
c07ece5ca4 Make AsyncLLMEngine more robust & fix batched abort (#969)
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
Co-authored-by: Avnish Narayan <38871737+avnishn@users.noreply.github.com>
2023-09-07 13:43:45 -07:00
7a9c20c715 Bump up transformers version (#976) 2023-09-07 13:15:53 -07:00
005ba458b5 Set torch default dtype in a context manager (#971)
Signed-off-by: Antoni Baum <antoni.baum@protonmail.com>
2023-09-07 15:39:37 +09:00
320a622ec4 [BugFix] Implement RoPE for GPT-J (#941) 2023-09-06 11:54:33 +09:00
c9927c1a6a Use queue for finished requests (#957) 2023-09-05 19:27:23 -07:00
fbd80ad409 Clean up kernel unit tests (#938) 2023-09-05 16:57:38 -07:00
22379d5513 fix: typo (#948) 2023-09-04 23:22:30 -07:00
1696725879 Initialize AsyncLLMEngine bg loop correctly (#943) 2023-09-04 17:41:22 -07:00
002800f081 Align vLLM's beam search implementation with HF generate (#857) 2023-09-04 17:29:42 -07:00
e15932bb60 Only emit warning about internal tokenizer if it isn't being used (#939) 2023-09-05 00:50:55 +09:00
ce741ba3e4 Refactor AsyncLLMEngine (#880) 2023-09-03 21:43:43 -07:00
bf87484efa [BugFix] Fix NaN errors in paged attention kernel (#936) 2023-09-04 09:20:06 +09:00
8ce9c50d40 Avoid compiling kernels for double data type (#933) 2023-09-02 14:59:47 +09:00
32b6816e55 Add tests for models (#922) 2023-09-01 11:19:43 +09:00
c128d69856 Fix README.md Link (#927) 2023-08-31 17:18:34 -07:00
55b28b1eee [Docs] Minor fixes in supported models (#920)
* Minor fix in supported models

* Add another small fix for Aquila model

---------

Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
2023-08-31 16:28:39 -07:00
e11222333f fix: bug fix when penalties are negative (#913)
Co-authored-by: dongyong-lee <dongyong.lee@navercorp.com>
2023-09-01 00:37:17 +09:00
28873a2799 Improve _prune_hidden_states micro-benchmark (#707) 2023-08-31 13:28:43 +09:00
0080d8329d Add acknowledgement to a16z grant 2023-08-30 02:26:47 -07:00
0d93f15694 Accelerate LLaMA model loading (#234) 2023-08-30 01:00:13 -07:00
becd7a56f1 Enable request body OpenAPI spec for OpenAI endpoints (#865) 2023-08-29 21:54:08 -07:00
75471386de use flash-attn via xformers (#877) 2023-08-29 21:52:13 -07:00
d2b2eed67c [Fix] Fix a condition for ignored sequences (#867) 2023-08-27 23:00:56 -07:00
4b6f069b6f Add support for CodeLlama (#854) 2023-08-25 12:44:07 -07:00
791d79de32 Bump up the version to v0.1.4 (#846) 2023-08-25 12:28:00 +09:00
94d2f59895 Set replacement=True in torch.multinomial (#858) 2023-08-25 12:22:01 +09:00
75c0ca9d43 Clean up code (#844) 2023-08-23 16:44:15 -07:00
2a4ec90854 Fix for breaking changes in xformers 0.0.21 (#834) 2023-08-23 17:44:21 +09:00
85ebcda94d Fix typo of Aquila in README.md (#836) 2023-08-22 20:48:36 -07:00
d64bf1646c Implement approximate GELU kernels (#828) 2023-08-23 07:43:21 +09:00
a41c20435e Add compute capability 8.9 to default targets (#829) 2023-08-23 07:28:38 +09:00
eedac9dba0 fix: revert code to avoid no attribute problem (#827) 2023-08-22 11:55:16 -07:00
14f9c72bfd Update Supported Model List (#825) 2023-08-22 11:51:44 -07:00
ad5f2fe34c Add support for aquila (#663)
* add aquila

Signed-off-by: ftgreat <ftgreat@163.com>

* fix some bug

Signed-off-by: shunxing1234 <xw747777271@gmail.com>

* delete pdb

Signed-off-by: shunxing1234 <xw747777271@gmail.com>

* fix bugs

Signed-off-by: shunxing1234 <xw747777271@gmail.com>

* fix bugs

Signed-off-by: shunxing1234 <xw747777271@gmail.com>

* delete whitespace

Signed-off-by: shunxing1234 <xw747777271@gmail.com>

* format

* fix order

---------

Signed-off-by: ftgreat <ftgreat@163.com>
Signed-off-by: shunxing1234 <xw747777271@gmail.com>
Co-authored-by: ftgreat <ftgreat@163.com>
2023-08-22 00:13:36 -07:00
4f8584756d Fix mqa is false case in gpt_bigcode (#806) 2023-08-21 22:22:06 -07:00
65fc1c3127 Set default compute capability according to CUDA version (#773) 2023-08-21 16:05:44 -07:00
c393af6cd7 [Feature | CI] Added a github action to build wheels (#746) 2023-08-21 16:59:15 +09:00
0c04ce3234 Fix typo in sampling_params.py (#788) 2023-08-18 10:12:46 +09:00
73b3de79ea explicitly del state (#784) 2023-08-17 12:56:04 -07:00
d1744376ae Align with huggingface Top K sampling (#753) 2023-08-15 16:44:33 -07:00
805de738f6 Fix typo in tokenizer.py (#750)
conjuction -> conjunction
2023-08-14 22:26:36 -07:00
1b151ed181 Fix baichuan doc style (#748) 2023-08-13 20:57:31 -07:00
e06f504a76 Supports tokens and arrays of tokens as inputs to the OpenAI completion API (#715) 2023-08-11 12:14:34 -07:00
462ae5220a [Fix] unwanted bias in InternLM Model (#740) 2023-08-11 11:40:37 -07:00
66c54aa9c3 Check the max prompt length for the OpenAI completions API (#472) 2023-08-08 17:43:49 -07:00
735ecfff61 add internlm model (#528) 2023-08-08 16:35:06 -07:00
a57d13cc96 add QWen-7b (#685)
Co-authored-by: wq.chu <wq.chu@tianrang-inc.com>
2023-08-08 13:50:38 -07:00
79af7e96a0 [OPTIMIZATION] Optimizes the single_query_cached_kv_attention kernel (#420) 2023-08-04 10:57:29 -07:00
621980bdc0 fix: incorrect bigcode attention heads num (#676) 2023-08-04 10:35:22 -07:00
aa84c92ef6 Bump up version to 0.1.3 (#657) 2023-08-02 16:46:53 -07:00
f7389f4763 [Doc] Add Baichuan 13B to supported models (#656) 2023-08-02 16:45:12 -07:00
55fe8a81ec Refactor scheduler (#658) 2023-08-02 16:42:01 -07:00
e8ddc08ec8 [BUG FIX] upgrade fschat version to 0.2.23 (#650)
Co-authored-by: hao.yu <hao.yu@cn-c017.server.mila.quebec>
2023-08-02 14:05:59 -07:00
1b0bd0fe8a Add Falcon support (new) (#592) 2023-08-02 14:04:39 -07:00
20044cab7a Fix log message in scheduler (#652) 2023-08-02 13:35:10 -07:00
64f23c2900 fix baichuan for different position embedding for 7b and 13b models (#643) 2023-08-01 22:22:51 -07:00
d4c7755ca8 fix baichuan-7b tp (#598)
Co-authored-by: wq.chu <wq.chu@tianrang-inc.com>
2023-08-01 15:41:36 -07:00
aa39e42c5a fix doc (#622) 2023-07-31 13:11:57 -07:00
953f28cf9a fix ModuleNotFoundError (#599)
Co-authored-by: fangli <fangli@tencent.com>
2023-07-29 20:52:41 -07:00
c0d00f5be6 [Fix] fix import error of RayWorker (#604) (#605) 2023-07-27 23:37:40 -07:00
58a072be15 [Fix] Add model sequence length into model config (#575) 2023-07-25 23:46:30 -07:00
82ad323dee [Fix] Add chat completion Example and simplify dependencies (#576) 2023-07-25 23:45:48 -07:00
df5dd3c68e Add Baichuan-7B to README (#494) 2023-07-25 15:25:12 -07:00
2d867b55fa fixed tensor parallel is not defined (#564) 2023-07-25 14:16:51 -07:00
d7a1c6d614 Fix paged attention testing. (#495)
Signed-off-by: Tao Peng <jiankeng.pt@alibaba-inc.com>
2023-07-24 21:01:56 -07:00
7d5a155e4a [Fix] Fix GPTBigcoder for distributed execution (#503) 2023-07-24 18:36:33 -07:00
1dde34e0f8 GPTJConfig has no attribute rotary. (#532) 2023-07-24 11:29:30 -07:00
6fc2a38b11 Add support for LLaMA-2 (#505) 2023-07-20 11:38:27 -07:00
c487a221ee Fix bad assert in initialize_cluster if PG already exists (#526) 2023-07-19 23:17:12 -07:00
9925c17940 Ray placement group support (#397) 2023-07-19 22:49:31 -07:00
8c4b2592fb fix: enable trust-remote-code in api server & benchmark. (#509) 2023-07-19 17:06:15 -07:00
cf21a9bd5c support trust_remote_code in benchmark (#518) 2023-07-19 17:02:40 -07:00
16c3e295a8 fix(ray_utils): ignore re-init error (#465) 2023-07-19 17:01:19 -07:00
bda41c70dd hotfix attn alibi wo head mapping (#496)
Co-authored-by: oliveryuan <oliveryuan@basemind.com>
2023-07-18 11:31:48 -07:00
453bafb96f Merge pull request #498 from MoeedDar/main
Fixed old name reference for max_seq_len
2023-07-18 09:22:56 -07:00
328d231c17 Fixed old name reference for max_seq_len 2023-07-18 16:47:59 +01:00
b4b195b360 fix max seq len (#489) 2023-07-17 23:20:20 -07:00
20b0d88d16 Add support for baichuan (#365) 2023-07-17 13:50:55 -07:00
2bdea7ac11 [Fix] Fix the condition of max_seq_len (#477) 2023-07-17 00:33:48 -04:00
58df2883cb [Doc] Add doc for running vLLM on the cloud (#426)
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
2023-07-16 13:37:14 -07:00
6d7d95a70a Offload port selection to OS (#467) 2023-07-15 23:11:02 -07:00
96853af5a8 Optimize MQA Kernel (#452) 2023-07-14 20:06:40 -04:00
dbed69058c Fix the KeyError when loading bloom-based models (#441) 2023-07-13 21:58:09 -07:00
7b6ae94059 add vocab padding for LLama(Support WizardLM) (#411) 2023-07-13 23:56:22 -04:00
c6dfc3cdbe Fix handling of special tokens in decoding. (#418) 2023-07-12 11:14:56 -04:00
51be365143 fix: freeze pydantic to v1 (#429) 2023-07-12 11:10:55 -04:00
c894836108 [Model] Add support for GPT-J (#226)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2023-07-08 17:55:16 -07:00
75beba29b5 Don't try to load training_args.bin (#373) 2023-07-08 15:26:28 -07:00
ddfdf470ae Add trust_remote_code arg to get_config (#405) 2023-07-08 15:24:17 -07:00
b6fbb9a565 Sort the outputs before return (#402) 2023-07-08 14:48:18 -07:00
2179e4f4c5 avoid python list copy in sequence initialization (#401) 2023-07-08 12:42:08 -07:00
a945fcc2ae Add trust-remote-code flag to handle remote tokenizers (#364) 2023-07-07 11:04:58 -07:00
be54f8e5c4 [Fix] Change /generate response-type to json for non-streaming (#374) 2023-07-06 18:15:17 -07:00
b396cb4998 fix: only response [DONE] once when streaming response. (#378) 2023-07-06 18:08:40 -07:00
1c395b4eaa Bump up the version (#300) 2023-07-04 21:41:53 -07:00
3d64cf019e [Server] use fastchat.model.model_adapter.get_conversation_template method to get model template (#357) 2023-07-04 21:39:59 -07:00
98fe8cb542 [Server] Add option to specify chat template for chat endpoint (#345) 2023-07-03 23:01:56 -07:00
ffa6d2f9f9 [Docs] Fix typo (#346) 2023-07-03 16:51:47 -07:00
404422f42e [Model] Add support for MPT (#334) 2023-07-03 16:47:53 -07:00
7717d0838b Fix an endless loop issue when engine_step throws a RuntimeError (#339) 2023-07-03 15:22:28 -07:00
42e0c1df78 [Quality] Add CI for formatting (#343) 2023-07-03 14:50:56 -07:00
e41f06702c Add support for BLOOM (#331) 2023-07-03 13:12:35 -07:00
d6fa1be3a8 [Quality] Add code formatter and linter (#326) 2023-07-03 11:31:55 -07:00
0ffded812a [Fix] Better error message for batched prompts (#342) 2023-07-03 09:27:31 -07:00
0bd2a573a5 Allow send list of str for the Prompt on openai demo endpoint /v1/completions (#323)
* allow str or List[str] for prompt

* Update vllm/entrypoints/openai/api_server.py

Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>

---------

Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
2023-07-03 09:17:50 -07:00
49b26e2cec feat: add ChatCompletion endpoint in OpenAI demo server. (#330) 2023-07-02 22:54:33 -07:00
dafd924c1f Raise error for long prompt (#273) 2023-06-30 18:48:49 -07:00
598dc4b79a [Fix] Weight loading for GPTBigCode (#313) 2023-06-29 22:14:17 -07:00
85de093472 [Fix] Do not pin memory when in WSL (#312) 2023-06-29 15:00:21 -07:00
f72297562f Add news for the vllm+skypilot example (#314) 2023-06-29 12:32:37 -07:00
9d27b09d12 Update README.md (#306) 2023-06-29 06:52:15 -07:00
998d9d1509 [Tokenizer] Add tokenizer mode (#298) 2023-06-28 14:19:22 -07:00
425040d4c1 remove floats == 0 comparison (#285) 2023-06-28 14:11:51 -07:00
4338cc4750 [Tokenizer] Add an option to specify tokenizer (#284) 2023-06-28 09:46:58 -07:00
bdd6b4c8bc Add LLM.set_tokenizer (#283) 2023-06-28 00:28:29 -07:00
2b7d3aca2e Update setup.py (#282)
Co-authored-by: neubig <neubig@gmail.com>
2023-06-27 14:34:23 -07:00
4026a049d3 expand coverage of gpt2 model loading (#271) 2023-06-27 06:27:41 -07:00
43710e8d09 [Fix] Fix default port number in benchmark scripts (#265) 2023-06-26 13:15:35 -07:00
526df28fb2 [BugFix] Fix a bug in counting running sequences (#266) 2023-06-26 13:09:02 -07:00
2cf1a333b6 [Doc] Documentation for distributed inference (#261) 2023-06-26 11:34:23 -07:00
0b7db411b5 [Bug] Fix the OOM condition for CPU cache (#260) 2023-06-26 11:16:13 -07:00
471a7a4566 Compatible with Decapoda Research llama hf version (#251) 2023-06-26 09:23:57 -07:00
6214dd6ce9 Update README.md (#236) 2023-06-25 16:58:06 -07:00
0603379863 fix wrong using getattr to get dict value (#232) 2023-06-24 22:00:24 -07:00
665c48963b [Docs] Add GPTBigCode to supported models (#213) 2023-06-22 15:05:11 -07:00
298695b766 GPTBigCode (StarCoder, SantaCoder Support) (#209) 2023-06-23 01:49:27 +08:00
83658c8ace Bump up version to 0.1.1 (#204) 2023-06-22 15:33:32 +08:00
1d24ccb96c [Fix] Better error message when there is OOM during cache initialization (#203) 2023-06-22 15:30:06 +08:00
14f0b39cda [Bugfix] Fix a bug in RequestOutput.finished (#202) 2023-06-22 00:17:24 -07:00
2e0d314384 fix-ray (#193) 2023-06-22 00:21:41 +08:00
67d96c29fb Use slow tokenizer for open llama models (#168) 2023-06-20 14:19:47 +08:00
033f5c78f5 Remove e.g. in README (#167) 2023-06-20 14:00:28 +08:00
794e578de0 [Minor] Fix URLs (#166) 2023-06-19 22:57:14 -07:00
caddfc14c1 [Minor] Fix icons in doc (#165) 2023-06-19 20:35:38 -07:00
fc72e39de3 Change image urls (#164) 2023-06-20 11:15:15 +08:00
b7e62d3454 Fix repo & documentation URLs (#163) 2023-06-19 20:03:40 -07:00
364536acd1 [Docs] Minor fix (#162) 2023-06-19 19:58:23 -07:00
0b32a987dd Add and list supported models in README (#161) 2023-06-20 10:57:46 +08:00
570fb2e9cc [PyPI] Fix package info in setup.py (#158) 2023-06-19 18:05:01 -07:00
a255885f83 Add logo and polish readme (#156) 2023-06-19 16:31:13 +08:00
5822ede66e Add performance figures for dark mode (#160) 2023-06-18 23:46:24 -07:00
0370afa2e5 Remove benchmark_async_llm_server.py (#155) 2023-06-19 11:12:37 +08:00
7e2a913c64 [Minor] Fix CompletionOutput.__repr__ (#157) 2023-06-18 19:58:25 -07:00
3f92038b99 Add comments on swap space (#154) 2023-06-18 11:39:35 -07:00
dcda03b4cb Write README and front page of doc (#147) 2023-06-18 03:19:38 -07:00
bf5f121c02 Reduce GPU memory utilization to make sure OOM doesn't happen (#153) 2023-06-18 17:33:50 +08:00
bec7b2dc26 Add quickstart guide (#148) 2023-06-18 01:26:12 +08:00
0b98ba15c7 Change the name to vLLM (#150) 2023-06-17 03:07:40 -07:00
e5464ee484 Rename servers to engines (#152) 2023-06-17 17:25:21 +08:00
bab8f3dd0d [Minor] Fix benchmark_throughput.py (#151) 2023-06-16 21:00:52 -07:00
eedb46bf03 Rename servers and change port numbers to reduce confusion (#149) 2023-06-17 00:13:02 +08:00
311490a720 Add script for benchmarking serving throughput (#145) 2023-06-14 19:55:38 -07:00
da5ddcd544 Remove redundant code in ColumnParallelLinear (#146) 2023-06-10 21:25:11 -07:00
5020e1e80c Non-streaming simple fastapi server (#144) 2023-06-10 10:43:07 -07:00
4298374265 Add docstrings for LLMServer and related classes and examples (#142) 2023-06-07 18:25:20 +08:00
e38074b1e6 Support FP32 (#141) 2023-06-07 00:40:21 -07:00
376725ce74 [PyPI] Packaging for PyPI distribution (#140) 2023-06-05 20:03:14 -07:00
456941cfe4 [Docs] Write the Adding a New Model section (#138) 2023-06-05 20:01:26 -07:00
1a956e136b Fix various issues of async servers (#135) 2023-06-05 23:44:50 +08:00
8274ca23ac Add docstrings for LLM (#137) 2023-06-04 12:52:41 -07:00
62ec38ea41 Document supported models (#127) 2023-06-02 22:35:17 -07:00
0eda2e0953 Add .readthedocs.yaml (#136) 2023-06-02 22:27:44 -07:00
211318d44a Add throughput benchmarking script (#133) 2023-05-28 03:20:05 -07:00
337871c6fd Enable LLaMA fast tokenizer (#132) 2023-05-28 02:51:42 -07:00
56b7f0efa4 Add a doc for installation (#128) 2023-05-27 01:13:06 -07:00
d721168449 Improve setup script & Add a guard for bfloat16 kernels (#130) 2023-05-27 00:59:32 -07:00
4a151dd453 Add activation registry (#126) 2023-05-25 00:09:07 -07:00
057daef778 OpenAI Compatible Frontend (#116) 2023-05-23 21:39:50 -07:00
e86717833d Incrementally decode output tokens (#121) 2023-05-23 20:46:32 -07:00
aedba6d5ec Print warnings/errors for large swap space (#123) 2023-05-23 18:22:26 -07:00
a283ec2eec Add contributing guideline and mypy config (#122) 2023-05-23 17:58:51 -07:00
3f942acfe1 Fix latency benchmark script (#118) 2023-05-22 17:03:40 -07:00
19d2899439 Add initial sphinx docs (#120) 2023-05-22 17:02:44 -07:00
655a5e48df Introduce LLM class for offline inference (#115) 2023-05-21 17:04:18 -07:00
f746ced08d Implement stop strings and best_of (#114) 2023-05-21 11:18:00 -07:00
c3442c1f6f Refactor system architecture (#109) 2023-05-20 13:06:59 -07:00
7297fa6f7c Remove unused parts in Megatron-LM code and add copyright notice (#110) 2023-05-20 09:11:34 -06:00
b7955ef17b Fix timeout error in the FastAPI frontend (#34) 2023-05-19 14:00:46 -06:00
f756799b84 Use runtime profiling to replace manual memory analyzers (#81) 2023-05-19 11:35:44 -06:00
825d8892b5 Use pytest format for unit tests (#107) 2023-05-17 17:11:23 -07:00
b322fd1607 Add docstrings to some modules and classes (#100) 2023-05-14 22:32:38 -07:00
667ba3995c Add copyright headers to source files adapted from FT (#104) 2023-05-14 22:19:19 -07:00
707ec647bb Add copyright headers for HF models (#103) 2023-05-14 21:54:32 -07:00
89988ec8c2 Add Apache-2.0 license (#102) 2023-05-14 18:05:19 -07:00
6208d622ca Minor code cleaning for SamplingParams (#99) 2023-05-12 18:07:09 -07:00
42f1042e1c Enhance SamplingParams (#96) 2023-05-11 15:45:30 -07:00
55f8b0a5de Implement presence and frequency penalties (#95) 2023-05-10 23:39:12 -07:00
9f88db35da Support top-k sampling (#94) 2023-05-10 12:51:36 -07:00
ae356774ab Avoid sorting waiting queue & Minor code cleaning (#93) 2023-05-10 01:57:07 -07:00
e331957784 Log system stats (#90) 2023-05-10 01:06:53 -07:00
8d66a7b6d7 Rename variables and methods (#91) 2023-05-10 00:58:31 -07:00
ce26e57fd3 Update sample prompts in simple_server.py (#89) 2023-05-09 16:47:39 -07:00
85eb631839 Use slow tokenizer for LLaMA (#84) 2023-05-09 16:03:44 -07:00
add055e151 Enhance model loader (#83) 2023-05-09 15:46:42 -07:00
7c041ab578 Refactor system architecture (#82) 2023-05-09 15:30:12 -07:00
8917782af6 Add a system logger (#85) 2023-05-08 23:03:35 -07:00
7addca5935 Specify python package dependencies in requirements.txt (#78) 2023-05-07 16:30:43 -07:00
c84e924287 [Minor] Fix a dtype bug (#79) 2023-05-06 02:12:12 -07:00
c9d5b6d4a8 Replace FlashAttention with xformers (#70) 2023-05-05 02:01:08 -07:00
189ae23133 Use dtype from model config & Add Dolly V2 (#63) 2023-05-04 03:05:37 -07:00
e548c1488a Add support for GPT-2 (#60) 2023-05-04 02:59:56 -07:00
130d5fd8c7 Fix a bug in attention kernel (#68) 2023-05-04 02:56:09 -07:00
e070829ae8 Support bfloat16 data type (#54) 2023-05-03 14:09:44 -07:00
436e523bf1 Refactor attention kernels (#53) 2023-05-03 13:40:13 -07:00
27f1410d06 New weight loader without np copy (#52) 2023-05-03 15:32:04 +08:00
4858f3bb45 Add an option to launch cacheflow without ray (#51) 2023-04-30 15:42:17 +08:00
a96d63c21d Add support for GPT-NeoX (Pythia) (#50) 2023-04-28 00:32:10 -07:00
191 changed files with 16957 additions and 8125 deletions

101
.github/workflows/publish.yml vendored Normal file

@@ -0,0 +1,101 @@
# This workflow will upload a Python Package to Release asset
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions
name: Create Release
on:
push:
tags:
- v*
# Needed to create release and upload assets
permissions:
contents: write
jobs:
release:
# Retrieve tag and create release
name: Create Release
runs-on: ubuntu-latest
outputs:
upload_url: ${{ steps.create_release.outputs.upload_url }}
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Extract branch info
shell: bash
run: |
echo "release_tag=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV
- name: Create Release
id: create_release
uses: "actions/github-script@v6"
env:
RELEASE_TAG: ${{ env.release_tag }}
with:
github-token: "${{ secrets.GITHUB_TOKEN }}"
script: |
const script = require('.github/workflows/scripts/create_release.js')
await script(github, context, core)
wheel:
name: Build Wheel
runs-on: ${{ matrix.os }}
needs: release
strategy:
fail-fast: false
matrix:
os: ['ubuntu-20.04']
python-version: ['3.8', '3.9', '3.10', '3.11']
cuda-version: ['11.8'] # Github runner can't build anything older than 11.8
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Set up Linux Env
if: ${{ runner.os == 'Linux' }}
run: |
bash -x .github/workflows/scripts/env.sh
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install CUDA ${{ matrix.cuda-version }}
run: |
bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }}
- name: Install PyTorch-cu${{ matrix.cuda-version }}
run: |
bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
- name: Build wheel
shell: bash
run: |
bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
wheel_name=$(ls dist/*whl | xargs -n 1 basename)
asset_name=${wheel_name//"linux"/"manylinux1"}
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
echo "asset_name=${asset_name}" >> $GITHUB_ENV
- name: Upload Release Asset
uses: actions/upload-release-asset@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ needs.release.outputs.upload_url }}
asset_path: ./dist/${{ env.wheel_name }}
asset_name: ${{ env.asset_name }}
asset_content_type: application/*
# (Danielkinz): This last step will publish the .whl to pypi. Warning: untested
# - name: Publish package
# uses: pypa/gh-action-pypi-publish@release/v1.8
# with:
# repository-url: https://test.pypi.org/legacy/
# password: ${{ secrets.PYPI_API_TOKEN }}
# skip-existing: true
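
For reference, this workflow fires on pushes of tags matching `v*`. A minimal sketch of cutting a release under that trigger (the tag name is illustrative, taken from the version-bump commit above):

```bash
# Tag the release commit and push the tag; the push event for a v* tag
# starts the Create Release workflow, which builds and uploads the wheels.
git tag v0.1.5
git push origin v0.1.5
```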

31
.github/workflows/pylint.yml vendored Normal file

@@ -0,0 +1,31 @@
name: pylint
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
pull_request:
branches:
- main
jobs:
pylint:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pylint==2.8.2
- name: Analysing the code with pylint
run: |
pylint vllm

15
.github/workflows/scripts/build.sh vendored Normal file

@@ -0,0 +1,15 @@
#!/bin/bash
python_executable=python$1
cuda_home=/usr/local/cuda-$2
# Update paths
PATH=${cuda_home}/bin:$PATH
LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH
# Install requirements
$python_executable -m pip install wheel packaging
$python_executable -m pip install -r requirements.txt
# Build
$python_executable setup.py bdist_wheel --dist-dir=dist

20
.github/workflows/scripts/create_release.js vendored Normal file

@@ -0,0 +1,20 @@
// Uses Github's API to create the release and wait for result.
// We use a JS script since github CLI doesn't provide a way to wait for the release's creation and returns immediately.
module.exports = async (github, context, core) => {
try {
const response = await github.rest.repos.createRelease({
draft: false,
generate_release_notes: true,
name: process.env.RELEASE_TAG,
owner: context.repo.owner,
prerelease: false,
repo: context.repo.repo,
tag_name: process.env.RELEASE_TAG,
});
core.setOutput('upload_url', response.data.upload_url);
} catch (error) {
core.setFailed(error.message);
}
}

18
.github/workflows/scripts/cuda-install.sh vendored Normal file

@@ -0,0 +1,18 @@
#!/bin/bash
# Replace '.' with '-' ex: 11.8 -> 11-8
cuda_version=$(echo $1 | tr "." "-")
# Removes '-' and '.' ex: ubuntu-20.04 -> ubuntu2004
OS=$(echo $2 | tr -d ".\-")
# Installs CUDA
wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
rm cuda-keyring_1.1-1_all.deb
sudo apt -qq update
sudo apt -y install cuda-${cuda_version} cuda-nvcc-${cuda_version} cuda-libraries-dev-${cuda_version}
sudo apt clean
# Test nvcc
PATH=/usr/local/cuda-$1/bin:${PATH}
nvcc --version
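
As called from the wheel job above, this script takes the CUDA version and runner OS as positional arguments. A sketch of an equivalent standalone invocation, using the values from this workflow's build matrix:

```bash
# $1 = CUDA version, $2 = OS string (matching the GitHub runner label).
bash -x .github/workflows/scripts/cuda-install.sh 11.8 ubuntu-20.04
```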

56
.github/workflows/scripts/env.sh vendored Normal file

@@ -0,0 +1,56 @@
#!/bin/bash
# This file installs common linux environment tools
export LANG=C.UTF-8
# python_version=$1
sudo apt-get update && \
sudo apt-get install -y --no-install-recommends \
software-properties-common
sudo apt-get install -y --no-install-recommends \
build-essential \
apt-utils \
ca-certificates \
wget \
git \
vim \
libssl-dev \
curl \
unzip \
unrar \
cmake \
net-tools \
sudo \
autotools-dev \
rsync \
jq \
openssh-server \
tmux \
screen \
htop \
pdsh \
openssh-client \
lshw \
dmidecode \
util-linux \
automake \
autoconf \
libtool \
net-tools \
pciutils \
libpci-dev \
libaio-dev \
libcap2 \
libtinfo5 \
fakeroot \
devscripts \
debhelper \
nfs-common
# Remove github bloat files to free up disk space
sudo rm -rf "/usr/local/share/boost"
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
sudo rm -rf "/usr/share/dotnet"

14
.github/workflows/scripts/pytorch-install.sh vendored Normal file

@@ -0,0 +1,14 @@
#!/bin/bash
python_executable=python$1
cuda_version=$2
# Install torch
$python_executable -m pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses setuptools && conda clean -ya
$python_executable -m pip install torch -f https://download.pytorch.org/whl/cu${cuda_version//./}/torch_stable.html
# Print version information
$python_executable --version
$python_executable -c "import torch; print('PyTorch:', torch.__version__)"
$python_executable -c "import torch; print('CUDA:', torch.version.cuda)"
$python_executable -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"

31
.github/workflows/yapf.yml vendored Normal file

@@ -0,0 +1,31 @@
name: yapf
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
pull_request:
branches:
- main
jobs:
yapf:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install yapf==0.32.0
pip install toml==0.10.2
- name: Running yapf
run: |
yapf --diff --recursive vllm --exclude 'vllm/model_executor/parallel_utils/**'

181
.gitignore vendored

@@ -1,10 +1,175 @@
**/*.pyc
**/__pycache__/
*.egg-info/
*.eggs/
*.so
build/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# VSCode
.vscode/
# DS Store
.DS_Store
# Results
*.csv
# Python pickle files
*.pkl
*.png
**/log.txt
# Sphinx documentation
_build/

434
.pylintrc Normal file

@@ -0,0 +1,434 @@
# This Pylint rcfile contains a best-effort configuration to uphold the
# best-practices and style described in the Google Python style guide:
# https://google.github.io/styleguide/pyguide.html
#
# Its canonical open-source location is:
# https://google.github.io/styleguide/pylintrc
[MASTER]
# Files or directories to be skipped. They should be base names, not paths.
ignore=docs,parallel_utils
# Files or directories matching the regex patterns are skipped. The regex
# matches against base names, not paths.
ignore-patterns=
# Pickle collected data for later comparisons.
persistent=no
# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=
# Use multiple processes to speed up Pylint.
jobs=4
# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no
[MESSAGES CONTROL]
# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
confidence=
# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
#enable=
# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once).You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes
# --disable=W"
disable=abstract-method,
apply-builtin,
arguments-differ,
attribute-defined-outside-init,
backtick,
bad-option-value,
basestring-builtin,
buffer-builtin,
c-extension-no-member,
consider-using-enumerate,
cmp-builtin,
cmp-method,
coerce-builtin,
coerce-method,
delslice-method,
div-method,
duplicate-code,
eq-without-hash,
execfile-builtin,
file-builtin,
filter-builtin-not-iterating,
fixme,
getslice-method,
global-statement,
hex-method,
idiv-method,
implicit-str-concat-in-sequence,
import-error,
import-self,
import-star-module-level,
inconsistent-return-statements,
input-builtin,
intern-builtin,
invalid-str-codec,
locally-disabled,
logging-fstring-interpolation, # added by vLLM
logging-not-lazy, # added by vLLM
long-builtin,
long-suffix,
map-builtin-not-iterating,
misplaced-comparison-constant,
missing-class-docstring, # TODO (vLLM): enable
missing-function-docstring,
missing-module-docstring, # TODO (vLLM): enable
metaclass-assignment,
next-method-called,
next-method-defined,
no-absolute-import,
no-else-break,
no-else-continue,
no-else-raise,
no-else-return,
no-init, # added
no-member,
no-name-in-module,
no-self-use,
nonzero-method,
oct-method,
old-division,
old-ne-operator,
old-octal-literal,
old-raise-syntax,
parameter-unpacking,
print-statement,
raising-string,
range-builtin-not-iterating,
raw_input-builtin,
rdiv-method,
reduce-builtin,
relative-import,
reload-builtin,
round-builtin,
setslice-method,
signature-differs,
standarderror-builtin,
suppressed-message,
sys-max-int,
too-few-public-methods,
too-many-ancestors,
too-many-arguments,
too-many-boolean-expressions,
too-many-branches,
too-many-instance-attributes,
too-many-locals,
too-many-nested-blocks,
too-many-public-methods,
too-many-return-statements,
too-many-statements,
trailing-newlines,
unichr-builtin,
unicode-builtin,
unnecessary-pass,
unpacking-in-except,
unspecified-encoding,
useless-else-on-loop,
useless-object-inheritance,
useless-suppression,
using-cmp-argument,
wrong-import-order,
xrange-builtin,
zip-builtin-not-iterating,
[REPORTS]
# Set the output format. Available formats are text, parseable, colorized, msvs
# (visual studio) and html. You can also give a reporter class, eg
# mypackage.mymodule.MyReporterClass.
output-format=text
# Tells whether to display a full report or only the messages
reports=no
# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details
#msg-template=
[BASIC]
# Good variable names which should always be accepted, separated by a comma
good-names=main,_
# Bad variable names which should always be refused, separated by a comma
bad-names=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Include a hint for the correct naming format with invalid-name
include-naming-hint=no
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl
# Regular expression matching correct function names
function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$
# Regular expression matching correct variable names
variable-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct constant names
const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Regular expression matching correct attribute names
attr-rgx=^_{0,2}[a-z][a-z0-9_]*$
# Regular expression matching correct argument names
argument-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct class attribute names
class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$
# Regular expression matching correct inline iteration names
inlinevar-rgx=^[a-z][a-z0-9_]*$
# Regular expression matching correct class names
class-rgx=^_?[A-Z][a-zA-Z0-9]*$
# Regular expression matching correct module names
module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$
# Regular expression matching correct method names
method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=10
[TYPECHECK]
# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes
# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=
# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local
# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=80
# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
# lines made too long by directives to pytype.
# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=(?x)(
^\s*(\#\ )?<?https?://\S+>?$|
^\s*(from\s+\S+\s+)?import\s+.+$)
# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=yes
# Maximum number of lines in a module
max-module-lines=99999
# String used as indentation unit. The internal Google style guide mandates 2
# spaces. Google's externally-published style guide says 4, consistent with
# PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google
# projects (like TensorFlow).
indent-string='  '
# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4
# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=
[MISCELLANEOUS]
# List of note tags to take in consideration, separated by a comma.
notes=TODO
[STRING]
# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=yes
[VARIABLES]
# Tells whether we should check for unused import in __init__ files.
init-import=no
# A regular expression matching the name of dummy variables (i.e. expectedly
# not used).
dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_)
# List of additional names supposed to be defined in builtins. Remember that
# you should avoid to define new builtins when possible.
additional-builtins=
# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,_cb
# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools
[LOGGING]
# Logging modules to check that the string format arguments are in logging
# function parameter format
logging-modules=logging,absl.logging,tensorflow.io.logging
[SIMILARITIES]
# Minimum lines number of a similarity.
min-similarity-lines=4
# Ignore comments when computing similarities.
ignore-comments=yes
# Ignore docstrings when computing similarities.
ignore-docstrings=yes
# Ignore imports when computing similarities.
ignore-imports=no
[SPELLING]
# Spelling dictionary name. Available dictionaries: none. To make it working
# install python-enchant package.
spelling-dict=
# List of comma separated words that should not be checked.
spelling-ignore-words=
# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=
# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no
[IMPORTS]
# Deprecated modules which should not be used, separated by a comma
deprecated-modules=regsub,
TERMIOS,
Bastion,
rexec,
sets
# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled)
import-graph=
# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled)
ext-import-graph=
# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled)
int-import-graph=
# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=
# Force import order to recognize a module as part of a third party library.
known-third-party=enchant, absl
# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no
[CLASSES]
# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
__new__,
setUp
# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
_fields,
_replace,
_source,
_make
# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls,
class_
# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=mcs
[EXCEPTIONS]
# Exceptions that will emit a warning when being caught. Defaults to
# "Exception"
overgeneral-exceptions=StandardError,
Exception,
BaseException

21
.readthedocs.yaml Normal file

@@ -0,0 +1,21 @@
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
version: 2
build:
os: ubuntu-22.04
tools:
python: "3.8"
sphinx:
configuration: docs/source/conf.py
# If using Sphinx, optionally build your docs in additional formats such as PDF
formats:
- pdf
# Optionally declare the Python requirements required to build your docs
python:
install:
- requirements: docs/requirements-docs.txt

77
CONTRIBUTING.md Normal file

@@ -0,0 +1,77 @@
# Contributing to vLLM
Thank you for your interest in contributing to vLLM!
Our community is open to everyone and welcomes all kinds of contributions, no matter how small or large.
There are several ways you can contribute to the project:
- Identify and report any issues or bugs.
- Request or add a new model.
- Suggest or implement new features.
However, remember that contributions aren't just about code.
We believe in the power of community support; thus, answering queries, assisting others, and enhancing the documentation are highly regarded and beneficial contributions.
Finally, one of the most impactful ways to support us is by raising awareness about vLLM.
Talk about it in your blog posts, highlighting how it's driving your incredible projects.
Express your support on Twitter if vLLM aids you, or simply offer your appreciation by starring our repository.
## Setup for development
### Build from source
```bash
pip install -r requirements.txt
pip install -e . # This may take several minutes.
```
### Testing
```bash
pip install -r requirements-dev.txt
# Static type checking
mypy
# Unit tests
pytest tests/
```
**Note:** Currently, the repository does not pass the mypy tests.
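
While iterating, it can help to run only the tests you are touching; a sketch using standard pytest selection (the `tests/kernels` path is an assumption based on the kernel-test commits in this changeset):

```bash
# Run one test subdirectory, assuming the kernel tests live under tests/kernels.
pytest tests/kernels
# Or select tests by keyword expression.
pytest tests/ -k "attention"
```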
## Contributing Guidelines
### Issue Reporting
If you encounter a bug or have a feature request, please check our issues page first to see if someone else has already reported it.
If not, please file a new issue, providing as much relevant information as possible.
### Coding Style Guide
In general, we adhere to the [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and the [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).
We include a formatting script [`format.sh`](./format.sh) to format the code.
### Pull Requests
When submitting a pull request:
1. Make sure your code has been rebased on top of the latest commit on the main branch.
2. Ensure code is properly formatted by running [`format.sh`](./format.sh).
3. Include a detailed description of the changes in the pull request.
Explain why you made the changes you did.
If your pull request fixes an open issue, please include a reference to it in the description.
### Code Reviews
All submissions, including submissions by project members, require a code review.
To make the review process as smooth as possible, please:
1. Keep your changes as concise as possible.
If your pull request involves multiple unrelated changes, consider splitting it into separate pull requests.
2. Respond to all comments within a reasonable time frame.
If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.
### Thank You
Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM.
Your contributions make vLLM a great tool for everyone!

201
LICENSE Normal file

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

4
MANIFEST.in Normal file

@@ -0,0 +1,4 @@
include LICENSE
include requirements.txt
recursive-include csrc *

154
README.md

@@ -1,72 +1,106 @@
# CacheFlow
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/logos/vllm-logo-text-dark.png">
<img alt="vLLM" src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/logos/vllm-logo-text-light.png" width=55%>
</picture>
</p>
## Installation
<h3 align="center">
Easy, fast, and cheap LLM serving for everyone
</h3>
<p align="center">
| <a href="https://vllm.readthedocs.io/en/latest/"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://github.com/vllm-project/vllm/discussions"><b>Discussions</b></a> |
</p>
---
*Latest News* 🔥
- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).
---
vLLM is a fast and easy-to-use library for LLM inference and serving.
vLLM is fast with:
- State-of-the-art serving throughput
- Efficient management of attention key and value memory with **PagedAttention**
- Continuous batching of incoming requests
- Optimized CUDA kernels
vLLM is flexible and easy to use with:
- Seamless integration with popular HuggingFace models
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
- Tensor parallelism support for distributed inference
- Streaming outputs
- OpenAI-compatible API server
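As a quick illustration of the decoding algorithms listed above, here is a minimal sketch using the `LLM` and `SamplingParams` APIs that appear in the benchmark scripts later in this changeset; the model name and parameter values are illustrative only:
```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # illustrative model choice

# Parallel sampling: draw four independent samples per prompt.
parallel = SamplingParams(n=4, temperature=1.0, top_p=1.0, max_tokens=64)
# Beam search: keep four beams; beam search runs with temperature 0.
beam = SamplingParams(n=4, use_beam_search=True, temperature=0.0, max_tokens=64)

for params in (parallel, beam):
    outputs = llm.generate(["The future of AI is"], params)
    for completion in outputs[0].outputs:
        print(completion.text)
```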
vLLM seamlessly supports many Hugging Face models, including the following architectures:
- Aquila (`BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
- Baichuan (`baichuan-inc/Baichuan-7B`, `baichuan-inc/Baichuan-13B-Chat`, etc.)
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
```bash
pip install psutil numpy ray torch
pip install git+https://github.com/huggingface/transformers # Required for LLaMA.
pip install sentencepiece # Required for LlamaTokenizer.
pip install ninja # To parallelize the compilation of flash-attn.
pip install flash-attn # This may take up to 10 mins.
pip install -e .
pip install vllm
```
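As a quick smoke test of the installation, the following sketch runs a single offline generation (the prompt and model are arbitrary choices, not part of this changeset):
```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # any model from the supported list above
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=32))
print(outputs[0].outputs[0].text)  # the generated continuation
```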
## Test simple server
## Getting Started
```bash
ray start --head
python simple_server.py
```
Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started.
- [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html)
- [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html)
- [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html)
The detailed arguments for `simple_server.py` can be found by:
```bash
python simple_server.py --help
```
## Performance
## FastAPI server
vLLM outperforms HuggingFace Transformers (HF) by up to 24x and Text Generation Inference (TGI) by up to 3.5x, in terms of throughput.
For details, check out our [blog post](https://vllm.ai).
Install the following additional dependencies:
```bash
pip install fastapi uvicorn
```
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a10g_n1_dark.png">
<img src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a10g_n1_light.png" width="45%">
</picture>
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a100_n1_dark.png">
<img src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a100_n1_light.png" width="45%">
</picture>
<br>
<em> Serving throughput when each request asks for 1 output completion. </em>
</p>
To start the server:
```bash
ray start --head
python -m cacheflow.http_frontend.fastapi_frontend
```
<p align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a10g_n3_dark.png">
<img src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a10g_n3_light.png" width="45%">
</picture>
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a100_n3_dark.png">
<img src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/figures/perf_a100_n3_light.png" width="45%">
</picture> <br>
<em> Serving throughput when each request asks for 3 output completions. </em>
</p>
To test the server:
```bash
python -m cacheflow.http_frontend.test_cli_client
```
## Contributing
## Gradio web server
Install the following additional dependencies:
```bash
pip install gradio
```
Start the server:
```bash
python -m cacheflow.http_frontend.fastapi_frontend
# At another terminal
python -m cacheflow.http_frontend.gradio_webserver
```
## Load LLaMA weights
Since the LLaMA weights are not fully public, we cannot download them directly from Hugging Face. You therefore need to follow the process below to load the LLaMA weights.
1. Convert the LLaMA weights to Hugging Face format with [this script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py).
```bash
python src/transformers/models/llama/convert_llama_weights_to_hf.py \
--input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path/llama-7b
```
Please make sure that `llama` is included in the output directory name.
2. For all the commands above, specify the model with `--model /output/path/llama-7b` to load the model. For example:
```bash
python simple_server.py --model /output/path/llama-7b
python -m cacheflow.http_frontend.fastapi_frontend --model /output/path/llama-7b
```
We welcome and value any contributions and collaborations.
Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.


@@ -1,165 +0,0 @@
import functools
import random
import time
from typing import List
from flash_attn.flash_attn_interface import _flash_attn_forward
import torch
from cacheflow import attention_ops
def benchmark(name, f, num_warmup = 10, num_iters = 100):
for _ in range(num_warmup):
f()
torch.cuda.synchronize()
start = time.time()
for _ in range(num_iters):
f()
torch.cuda.synchronize()
end = time.time()
print(f'{name}: {(end - start) / num_iters * 1000:.3f} ms')
@torch.inference_mode()
def benchmark_multi_query_cached_kv_attention(
query_lens: List[int],
context_lens: List[int],
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
) -> None:
print(f'query_lens: {query_lens}, context_lens: {context_lens}, '
f'num_heads: {num_heads}, head_size: {head_size}, block_size: '
f'{block_size}, num_blocks: {num_blocks}, dtype: {dtype}')
# Create query tensor.
num_queries = len(query_lens)
cu_query_lens = [0]
for query_len in query_lens:
cu_query_lens.append(cu_query_lens[-1] + query_len)
num_total_tokens = cu_query_lens[-1]
qkv = torch.randn(
num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
query, _, _ = qkv.unbind(dim=1)
# Create key and value cache.
x = 16 // torch.tensor([], dtype=dtype).element_size()
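# x is the number of cache elements that fit in 16 bytes (8 for fp16),
# so keys are stored as [num_heads, head_size // x, block_size, x] to
# allow the attention kernel to issue vectorized 16-byte loads.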
key_block_shape = (num_heads, head_size // x, block_size, x)
key_cache = torch.randn(
size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
value_block_shape = (num_heads, head_size, block_size)
value_cache = torch.randn(
size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
# Create block tables.
max_context_len = max(context_lens)
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
block_tables = []
for _ in range(num_queries):
block_table = [
random.randint(0, num_blocks - 1)
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
# Create input and output data structures.
cu_query_lens = torch.tensor(cu_query_lens, dtype=torch.int, device='cuda')
context_len_tensor = torch.tensor(context_lens, dtype=torch.int, device='cuda')
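# Scale attention scores by 1/sqrt(head_size), as in standard scaled
# dot-product attention.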
scale = float(1.0 / (head_size ** 0.5))
output = torch.empty(
num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda')
# Run our implementation.
def run_ours():
attention_ops.multi_query_cached_kv_attention(
cu_query_lens,
output,
query,
key_cache,
value_cache,
scale,
block_tables,
context_len_tensor,
block_size,
max_context_len,
)
benchmark('Ours', run_ours)
# Upper bound: Flash attention.
# Because Flash attention cannot read from our own cache,
# we make the key and value tensors contiguous.
num_kv_tokens = sum(context_lens)
cu_context_lens = [0]
for context_len in context_lens:
cu_context_lens.append(cu_context_lens[-1] + context_len)
cu_context_lens = torch.tensor(cu_context_lens, dtype=torch.int, device='cuda')
qkv = torch.randn(
num_kv_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
_, key, value = qkv.unbind(dim=1)
ref_output = torch.empty_like(output)
# Run Flash attention.
def run_flash_attn():
_flash_attn_forward(
query,
key,
value,
ref_output,
cu_query_lens,
cu_context_lens,
max(query_lens),
max_context_len,
dropout_p=0.0,
softmax_scale=scale,
causal=True,
return_softmax=False,
)
benchmark('Flash attention', run_flash_attn)
if __name__ == '__main__':
BLOCK_SIZE = 8
NUM_BLOCKS = 1024
DTYPE = torch.half
# LLaMA-13B and OPT-13B
NUM_HEADS = 40
HEAD_SIZE = 128
run_benchmark = functools.partial(
benchmark_multi_query_cached_kv_attention,
num_heads=NUM_HEADS,
head_size=HEAD_SIZE,
block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS,
dtype=DTYPE,
)
run_benchmark(
query_lens=[64] * 1,
context_lens=[64] * 1,
)
run_benchmark(
query_lens=[128] * 1,
context_lens=[128] * 1,
)
run_benchmark(
query_lens=[64] * 8,
context_lens=[64] * 8,
)
run_benchmark(
query_lens=[128] * 8,
context_lens=[128] * 8,
)
run_benchmark(
query_lens=[64, 32, 16],
context_lens=[128, 256, 64],
)
run_benchmark(
query_lens=[1024],
context_lens=[1024],
)


@@ -1,81 +0,0 @@
import functools
import random
import time
import torch
from cacheflow import cache_ops
def benchmark(name, f, size: int, num_warmup = 10, num_iters = 100):
for _ in range(num_warmup):
f()
torch.cuda.synchronize()
start = time.time()
for _ in range(num_iters):
f()
torch.cuda.synchronize()
end = time.time()
avg_time = (end - start) / num_iters
print(f'[Latency] {name}: {avg_time * 1000:.3f} ms')
print(f'[Throughput] {name}: {size / avg_time / 2 ** 30:.3f} GB/s')
@torch.inference_mode()
def test_gather_cached_kv(
num_tokens: int,
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
) -> None:
print(f'num_tokens: {num_tokens}, num_heads: {num_heads}, '
f'head_size: {head_size}, block_size: {block_size}, '
f'num_blocks: {num_blocks}, dtype: {dtype}')
num_slots = block_size * num_blocks
slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
qkv = torch.randn(
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
_, key, value = qkv.unbind(dim=1)
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda')
value_cache_shape = (num_blocks, num_heads, head_size, block_size)
value_cache = torch.randn(
size=value_cache_shape, dtype=dtype, device='cuda')
# Run the gather_cached_kv kernel.
def run():
cache_ops.gather_cached_kv(key, value, key_cache, value_cache, slot_mapping)
benchmark('gather_cached_kv', run,
size=num_tokens * num_heads * head_size * 2 * qkv.element_size())
if __name__ == '__main__':
BLOCK_SIZE = 8
NUM_BLOCKS = 1024
DTYPE = torch.half
# LLaMA-13B and OPT-13B
NUM_HEADS = 40
HEAD_SIZE = 128
run_benchmark = functools.partial(
test_gather_cached_kv,
num_heads=NUM_HEADS,
head_size=HEAD_SIZE,
block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS,
dtype=DTYPE,
)
for i in range(6, 12):
run_benchmark(num_tokens=2 ** i)


@@ -1,105 +0,0 @@
import argparse
import time
from typing import List
from tqdm import tqdm
import numpy as np
import torch
from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.master.server import (Server, add_server_arguments,
initialize_ray_cluster)
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory
def main(args: argparse.Namespace):
# TODO(zhuohan): Support pipeline parallelism.
assert args.pipeline_parallel_size == 1, (
'Pipeline parallelism is not supported yet.')
(num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices) = (
initialize_ray_cluster(
address='local',
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))
# Create a server.
server = Server(
model=args.model,
model_path=args.model_path,
use_dummy_weights=args.use_dummy_weights,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
block_size=args.block_size,
dtype=args.dtype,
seed=args.seed,
swap_space=args.swap_space,
max_num_batched_tokens=args.max_num_batched_tokens,
max_num_sequences=args.max_num_sequences,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
)
# Create a frontend.
frontend = SimpleFrontend(
model_name=args.model,
block_size=args.block_size,
)
sampling_params_dict = {
'n': args.n,
'temperature': 0.0 if args.use_beam_search else 1.0,
'top_p': 1.0,
'use_beam_search': args.use_beam_search,
'stop_token_ids': set(),
'max_num_steps': args.output_len,
}
sampling_params = SamplingParams.from_dict(sampling_params_dict)
print(sampling_params)
input_token_ids = [0] * args.input_len
def profile_step(profile=False):
if profile:
torch.cuda.cudart().cudaProfilerStart()
for _ in range(args.batch_size):
frontend._add_query(input_token_ids, sampling_params)
server.add_sequence_groups(frontend.get_inputs())
start_time = time.time()
while True:
server.step()
if not server.has_unfinished_requests():
break
end_time = time.time()
latency = end_time - start_time
if profile:
torch.cuda.cudart().cudaProfilerStop()
return latency
print("Warm up step")
profile_step()
# Benchmark.
latencies = []
for _ in tqdm(range(3), desc="Profile step"):
latencies.append(profile_step())
print(f'Avg latency: {np.mean(latencies)} seconds')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='CacheFlow simple server.')
parser = add_server_arguments(parser)
parser.add_argument('--input-len', type=int, default=32)
parser.add_argument('--output-len', type=int, default=128)
parser.add_argument('--batch-size', type=int, default=8)
parser.add_argument('--n', type=int, default=1)
parser.add_argument('--use-beam-search', action='store_true')
args = parser.parse_args()
args.max_num_batched_tokens = max(
args.max_num_batched_tokens, args.batch_size * args.input_len)
print(args)
main(args)


@@ -1,290 +0,0 @@
import argparse
import logging
import os
import pickle
import time
from typing import List
from tqdm import tqdm
from transformers import AutoConfig
from benchmark.trace import generate_text_completion_requests
from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.master.server import (Server, add_server_arguments,
initialize_ray_cluster)
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory
logger = logging.getLogger(__name__)
def main(args: argparse.Namespace):
assert args.pipeline_parallel_size == 1, (
'Pipeline parallelism is not supported yet.')
(num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices) = (
initialize_ray_cluster(
address='local',
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))
# Create a server.
server = Server(
model=args.model,
model_path=args.model_path,
use_dummy_weights=args.use_dummy_weights,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
block_size=args.block_size,
dtype=args.dtype,
seed=args.seed,
swap_space=args.swap_space,
max_num_batched_tokens=args.max_num_batched_tokens,
max_num_sequences=args.max_num_sequences,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
collect_stats=True,
do_memory_analysis=args.do_memory_analysis,
)
# Create a frontend.
frontend = SimpleFrontend(
model_name=args.model,
block_size=args.block_size,
)
# Generate requests.
requests = generate_text_completion_requests(
args.dataset,
args.request_rate,
args.duration,
args.seed,
args.n1,
args.n2,
args.n3,
args.n4,
args.n6,
args.n2_beam,
args.n4_beam,
args.n6_beam,
args.n8_beam,
)
# Warm up.
logger.info('Warming up.')
num_warmup_requests = 8
warmup_input_len = 8
warmup_output_len = 32
warmup_sampling_params = SamplingParams(
n=1,
temperature=1.0,
top_p=0.99,
max_num_steps=warmup_output_len,
use_beam_search=False,
stop_token_ids=set(),
num_logprobs=0,
context_window_size=None,
)
for _ in range(num_warmup_requests):
frontend._add_query([0] * warmup_input_len, warmup_sampling_params)
server.add_sequence_groups(frontend.get_inputs())
while True:
server.step()
if not server.has_unfinished_requests():
break
# Start benchmarking.
logger.info('Start benchmarking.')
# Initialize tqdm.
pbar = tqdm(total=len(requests), desc='Finished requests')
finished = []
server.scheduler.reset_stats()
start_time = time.time()
while True:
now = time.time()
if args.timeout is not None and now - start_time > args.timeout:
logger.info('Timeout. Stop benchmarking.')
break
while requests:
if requests[0][0] <= now - start_time:
request_time, input_tokens, sampling_params = requests.pop(0)
frontend._add_query(
input_tokens, sampling_params, arrival_time=start_time + request_time)
else:
break
server.add_sequence_groups(frontend.get_inputs())
updated_seq_groups = server.step()
now = time.time()
for seq_group in updated_seq_groups:
if not seq_group.is_finished():
continue
arrival_time = seq_group.arrival_time
finish_time = now
for seq in seq_group.get_seqs():
seq_len = seq.get_len()
output_len = seq_len - seq.prompt_len
finished.append({
'group_id': seq_group.group_id,
'seq_id': seq.seq_id,
'arrival_time': arrival_time,
'finish_time': finish_time,
'prompt_len': seq.prompt_len,
'output_len': output_len,
})
pbar.update(1)
if not (requests or server.has_unfinished_requests()):
break
pbar.close()
logger.info('Finish benchmarking. Saving stats.')
server.scheduler.save_stats(args.output_dir)
with open(os.path.join(args.output_dir, 'sequences.pkl'), 'wb') as f:
pickle.dump(finished, f)
logger.info('Done.')
def get_model_name(model: str) -> str:
OPT_MODELS = [
'opt-125m',
'opt-350m',
'opt-1.3b',
'opt-2.7b',
'opt-6.7b',
'opt-13b',
'opt-30b',
'opt-66b',
'opt-175b',
]
for opt_model in OPT_MODELS:
if opt_model in model:
return opt_model
config = AutoConfig.from_pretrained(model)
assert config.model_type == 'llama'
hidden_size = config.hidden_size
if hidden_size == 4096:
return 'llama-7b'
elif hidden_size == 5120:
return 'llama-13b'
elif hidden_size == 6656:
return 'llama-30b'
elif hidden_size == 8192:
return 'llama-65b'
else:
raise ValueError(f'Unknown model: {model}')
def get_dataset_name(dataset: str) -> str:
if 'sharegpt' in dataset.lower():
return 'sharegpt'
elif 'alpaca' in dataset.lower():
return 'alpaca'
else:
raise ValueError(f'Unknown dataset: {dataset}')
def get_sampling_dir_name(
n1: float,
n2: float,
n3: float,
n4: float,
n6: float,
n2_beam: float,
n4_beam: float,
n6_beam: float,
n8_beam: float,
) -> str:
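# e.g., n1=0.5 and n2_beam=0.5 yield 'n1-0.5-n2-beam-0.5'; a ratio of
# exactly 1.0 resets the name to just that tag (e.g., 'n2-beam').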
method = ''
if n1 > 0.0:
method = 'n1' if n1 == 1.0 else method + f'n1-{n1}-'
if n2 > 0.0:
method = 'n2' if n2 == 1.0 else method + f'n2-{n2}-'
if n3 > 0.0:
method = 'n3' if n3 == 1.0 else method + f'n3-{n3}-'
if n4 > 0.0:
method = 'n4' if n4 == 1.0 else method + f'n4-{n4}-'
if n6 > 0.0:
method = 'n6' if n6 == 1.0 else method + f'n6-{n6}-'
if n2_beam > 0.0:
method = 'n2-beam' if n2_beam == 1.0 else method + f'n2-beam-{n2_beam}-'
if n4_beam > 0.0:
method = 'n4-beam' if n4_beam == 1.0 else method + f'n4-beam-{n4_beam}-'
if n6_beam > 0.0:
method = 'n6-beam' if n6_beam == 1.0 else method + f'n6-beam-{n6_beam}-'
if n8_beam > 0.0:
method = 'n8-beam' if n8_beam == 1.0 else method + f'n8-beam-{n8_beam}-'
return method[:-1] if method.endswith('-') else method
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='CacheFlow simple server.')
parser = add_server_arguments(parser)
parser.add_argument('--output-dir', type=str, help='path to output directory', default=None)
parser.add_argument('--dataset', type=str, help='path to dataset', required=True)
parser.add_argument('--request-rate', type=float, help='reqs/sec', required=True)
parser.add_argument('--duration', type=int, help='duration in seconds', required=True)
parser.add_argument('--do-memory-analysis', action='store_true',
help='do memory analysis (This will lower the throughput. Use this only for analysis.)')
parser.add_argument('--timeout', type=int, help='timeout in seconds', default=None)
parser.add_argument('--n1', type=float, help='ratio of requests with n=1', default=0.0)
parser.add_argument('--n2', type=float, help='ratio of requests with n=2', default=0.0)
parser.add_argument('--n3', type=float, help='ratio of requests with n=3', default=0.0)
parser.add_argument('--n4', type=float, help='ratio of requests with n=4', default=0.0)
parser.add_argument('--n6', type=float, help='ratio of requests with n=6', default=0.0)
parser.add_argument('--n2-beam', type=float, help='ratio of requests with n=2 & beam search', default=0.0)
parser.add_argument('--n4-beam', type=float, help='ratio of requests with n=4 & beam search', default=0.0)
parser.add_argument('--n6-beam', type=float, help='ratio of requests with n=6 & beam search', default=0.0)
parser.add_argument('--n8-beam', type=float, help='ratio of requests with n=8 & beam search', default=0.0)
args = parser.parse_args()
if args.n1 + args.n2 + args.n3 + args.n4 + args.n6 + args.n2_beam + args.n4_beam + args.n6_beam + args.n8_beam != 1.0:
raise ValueError('The ratios of requests must sum to 1.')
model_name = get_model_name(args.model)
dataset_name = get_dataset_name(args.dataset)
if 'opt' in model_name:
if 'opt' not in args.dataset.lower():
raise ValueError('OPT models can only be used with OPT datasets.')
elif 'llama' in model_name:
if 'llama' not in args.dataset.lower():
raise ValueError('Llama models can only be used with Llama datasets.')
dataset_name = 'sharegpt' if 'sharegpt' in args.dataset else 'alpaca'
sample_dir = get_sampling_dir_name(
args.n1, args.n2, args.n3, args.n4, args.n6, args.n2_beam, args.n4_beam, args.n6_beam, args.n8_beam)
if args.output_dir is None:
args.output_dir = os.path.join(
'../exp',
dataset_name,
f'{model_name}-tp{args.tensor_parallel_size}',
sample_dir,
'cacheflow',
f'block{args.block_size}',
f'req-rate-{args.request_rate}',
f'seed{args.seed}',
f'duration-{args.duration}',
)
os.makedirs(args.output_dir, exist_ok=True)
# Set up logging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
handlers=[
logging.StreamHandler(),
logging.FileHandler(os.path.join(args.output_dir, 'log.txt')),
],
)
logger.info(args)
main(args)


@@ -1,116 +0,0 @@
import pickle
import random
from typing import List, Tuple
import numpy as np
from cacheflow.sampling_params import SamplingParams
def generate_text_completion_requests(
dataset: str,
request_rate: float,
duration: int,
seed: int,
n1: float = 0.0,
n2: float = 0.0,
n3: float = 0.0,
n4: float = 0.0,
n6: float = 0.0,
n2_beam: float = 0.0,
n4_beam: float = 0.0,
n6_beam: float = 0.0,
n8_beam: float = 0.0,
max_seq_len: int = 2048,
time_quantum: int = 10,
) -> List[Tuple[float, List[int], SamplingParams]]:
random.seed(seed)
np.random.seed(seed)
# Generate timestamps for requests using Poisson distribution.
lam = request_rate * (time_quantum / 1000)
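# e.g., request_rate=2.0 req/s with time_quantum=10 ms gives an expected
# lam=0.02 arrivals per quantum; one Poisson draw is taken per quantum.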
quantums_per_sec = 1000 / time_quantum
arrival_times = np.random.poisson(
lam=lam, size=int(duration * quantums_per_sec))
timestamps = []
for i, n in enumerate(arrival_times):
timestamps += [i * (time_quantum / 1000)] * n
# Load and shuffle the dataset.
num_requests = len(timestamps)
with open(dataset, 'rb') as f:
data = pickle.load(f)
filtered = []
for pair in data:
input_tokens, output_tokens = pair
input_len = len(input_tokens)
output_len = len(output_tokens)
# Filter out too long sequences.
if input_len + output_len < max_seq_len:
# Output tokens are not needed for the benchmark.
filtered.append((input_tokens, output_len))
data = []
while len(data) < num_requests:
data += filtered
data = data[:num_requests]
# Shuffle the data.
assert len(data) == len(timestamps)
random.shuffle(data)
random_sampling_params_dict = {
'temperature': 1.0,
'top_p': 1.0,
'use_beam_search': False,
'stop_token_ids': set(),
'num_logprobs': 0,
'context_window_size': None,
}
beam_search_params_dict = {
'temperature': 0.0,
'top_p': 1.0,
'use_beam_search': True,
'stop_token_ids': set(),
'num_logprobs': 0,
'context_window_size': None,
}
# Generate requests based on the sampling parameter ratio.
requests = []
assert n1 + n2 + n3 + n4 + n6 + n2_beam + n4_beam + n6_beam + n8_beam == 1.0
cum_sum = 0
for timestamp, pair in zip(timestamps, data):
input_tokens, output_len = pair
if cum_sum < n1 * num_requests:
sampling_params = SamplingParams(
n=1, max_num_steps=output_len, **random_sampling_params_dict)
elif cum_sum < (n1 + n2) * num_requests:
sampling_params = SamplingParams(
n=2, max_num_steps=output_len, **random_sampling_params_dict)
elif cum_sum < (n1 + n2 + n3) * num_requests:
sampling_params = SamplingParams(
n=3, max_num_steps=output_len, **random_sampling_params_dict)
elif cum_sum < (n1 + n2 + n3 + n4) * num_requests:
sampling_params = SamplingParams(
n=4, max_num_steps=output_len, **random_sampling_params_dict)
elif cum_sum < (n1 + n2 + n3 + n4 + n6) * num_requests:
sampling_params = SamplingParams(
n=6, max_num_steps=output_len, **random_sampling_params_dict)
elif cum_sum < (n1 + n2 + n3 + n4 + n6 + n2_beam) * num_requests:
sampling_params = SamplingParams(
n=2, max_num_steps=output_len, **beam_search_params_dict)
elif cum_sum < (n1 + n2 + n3 + n4 + n6 + n2_beam + n4_beam) * num_requests:
sampling_params = SamplingParams(
n=4, max_num_steps=output_len, **beam_search_params_dict)
elif cum_sum < (n1 + n2 + n3 + n4 + n6 + n2_beam + n4_beam + n6_beam) * num_requests:
sampling_params = SamplingParams(
n=6, max_num_steps=output_len, **beam_search_params_dict)
elif cum_sum < (n1 + n2 + n3 + n4 + n6 + n2_beam + n4_beam + n6_beam + n8_beam) * num_requests:
sampling_params = SamplingParams(
n=8, max_num_steps=output_len, **beam_search_params_dict)
else:
raise ValueError('Invalid request ratio.')
cum_sum += 1
requests.append((timestamp, input_tokens, sampling_params))
return requests

8
benchmarks/README.md Normal file

@@ -0,0 +1,8 @@
# Benchmarking vLLM
## Downloading the ShareGPT dataset
You can download the dataset by running:
```bash
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```
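Before running the benchmarks, you can sanity-check the download with a short sketch that mirrors the two-turn filter used by the benchmark scripts in this changeset (the file name comes from the `wget` command above):
```python
import json

with open("ShareGPT_V3_unfiltered_cleaned_split.json") as f:
    dataset = json.load(f)

# Keep conversations with at least two turns, as the benchmark scripts do.
dataset = [d for d in dataset if len(d["conversations"]) >= 2]
print(f"{len(dataset)} usable conversations")
print(dataset[0]["conversations"][0]["value"][:80])  # first human turn
```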


@@ -0,0 +1,81 @@
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import time
import numpy as np
import torch
from tqdm import tqdm
from vllm import LLM, SamplingParams
def main(args: argparse.Namespace):
print(args)
# Process all the requests in a single batch if possible.
# NOTE(woosuk): If the request cannot be processed in a single batch,
# the engine will automatically process the request in multiple batches.
llm = LLM(
model=args.model,
tokenizer=args.tokenizer,
tensor_parallel_size=args.tensor_parallel_size,
max_num_seqs=args.batch_size,
max_num_batched_tokens=args.batch_size * args.input_len,
trust_remote_code=args.trust_remote_code,
)
sampling_params = SamplingParams(
n=args.n,
temperature=0.0 if args.use_beam_search else 1.0,
top_p=1.0,
use_beam_search=args.use_beam_search,
ignore_eos=True,
max_tokens=args.output_len,
)
print(sampling_params)
dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
def run_to_completion(profile: bool = False):
if profile:
torch.cuda.cudart().cudaProfilerStart()
start_time = time.time()
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.time()
latency = end_time - start_time
if profile:
torch.cuda.cudart().cudaProfilerStop()
return latency
print("Warming up...")
run_to_completion(profile=False)
# Benchmark.
latencies = []
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
latencies.append(run_to_completion(profile=False))
print(f'Avg latency: {np.mean(latencies)} seconds')
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Benchmark the latency of processing a single batch of '
'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m')
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--input-len', type=int, default=32)
parser.add_argument('--output-len', type=int, default=128)
parser.add_argument('--batch-size', type=int, default=8)
parser.add_argument('--n', type=int, default=1,
help='Number of generated sequences per prompt.')
parser.add_argument('--use-beam-search', action='store_true')
parser.add_argument('--num-iters', type=int, default=3,
help='Number of iterations to run.')
parser.add_argument('--trust-remote-code', action='store_true',
help='trust remote code from huggingface')
args = parser.parse_args()
main(args)


@@ -0,0 +1,233 @@
"""Benchmark online serving throughput.
On the server side, run one of the following commands:
(vLLM backend)
python -m vllm.entrypoints.api_server \
--model <your_model> --swap-space 16 \
--disable-log-requests
(TGI backend)
./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
On the client side, run:
python benchmarks/benchmark_serving.py \
--backend <backend> \
--tokenizer <your_model> --dataset <target_dataset> \
--request-rate <request_rate>
"""
import argparse
import asyncio
import json
import random
import time
from typing import AsyncGenerator, List, Tuple
import aiohttp
import numpy as np
from transformers import PreTrainedTokenizerBase
from vllm.transformers_utils.tokenizer import get_tokenizer
# (prompt len, output len, latency)
REQUEST_LATENCY: List[Tuple[int, int, float]] = []
def sample_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int]]:
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [
data for data in dataset
if len(data["conversations"]) >= 2
]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]
# Tokenize the prompts and completions.
prompts = [prompt for prompt, _ in dataset]
prompt_token_ids = tokenizer(prompts).input_ids
completions = [completion for _, completion in dataset]
completion_token_ids = tokenizer(completions).input_ids
tokenized_dataset = []
for i in range(len(dataset)):
output_len = len(completion_token_ids[i])
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
# Filter out too long sequences.
filtered_dataset: List[Tuple[str, int, int]] = []
for prompt, prompt_token_ids, output_len in tokenized_dataset:
prompt_len = len(prompt_token_ids)
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
# This is because TGI causes errors when the input or output length
# is too short.
continue
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))
# Sample the requests.
sampled_requests = random.sample(filtered_dataset, num_requests)
return sampled_requests
async def get_request(
input_requests: List[Tuple[str, int, int]],
request_rate: float,
) -> AsyncGenerator[Tuple[str, int, int], None]:
input_requests = iter(input_requests)
for request in input_requests:
yield request
if request_rate == float("inf"):
# If the request rate is infinity, then we don't need to wait.
continue
# Sample the request interval from the exponential distribution.
interval = np.random.exponential(1.0 / request_rate)
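# The mean inter-arrival time is 1 / request_rate seconds; exponential
# inter-arrival times make the overall arrival process Poisson.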
# The next request will be sent after the interval.
await asyncio.sleep(interval)
async def send_request(
backend: str,
api_url: str,
prompt: str,
prompt_len: int,
output_len: int,
best_of: int,
use_beam_search: bool,
) -> None:
request_start_time = time.time()
headers = {"User-Agent": "Benchmark Client"}
if backend == "vllm":
pload = {
"prompt": prompt,
"n": 1,
"best_of": best_of,
"use_beam_search": use_beam_search,
"temperature": 0.0 if use_beam_search else 1.0,
"top_p": 1.0,
"max_tokens": output_len,
"ignore_eos": True,
"stream": False,
}
elif backend == "tgi":
assert not use_beam_search
params = {
"best_of": best_of,
"max_new_tokens": output_len,
"do_sample": True,
}
pload = {
"inputs": prompt,
"parameters": params,
}
else:
raise ValueError(f"Unknown backend: {backend}")
timeout = aiohttp.ClientTimeout(total=3 * 3600)
async with aiohttp.ClientSession(timeout=timeout) as session:
while True:
async with session.post(api_url, headers=headers, json=pload) as response:
chunks = []
async for chunk, _ in response.content.iter_chunks():
chunks.append(chunk)
output = b"".join(chunks).decode("utf-8")
output = json.loads(output)
# Re-send the request if it failed.
if "error" not in output:
break
request_end_time = time.time()
request_latency = request_end_time - request_start_time
REQUEST_LATENCY.append((prompt_len, output_len, request_latency))
async def benchmark(
backend: str,
api_url: str,
input_requests: List[Tuple[str, int, int]],
best_of: int,
use_beam_search: bool,
request_rate: float,
) -> None:
tasks: List[asyncio.Task] = []
async for request in get_request(input_requests, request_rate):
prompt, prompt_len, output_len = request
task = asyncio.create_task(send_request(backend, api_url, prompt,
prompt_len, output_len,
best_of, use_beam_search))
tasks.append(task)
await asyncio.gather(*tasks)
def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)
np.random.seed(args.seed)
api_url = f"http://{args.host}:{args.port}/generate"
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
benchmark_start_time = time.time()
asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
args.use_beam_search, args.request_rate))
benchmark_end_time = time.time()
benchmark_time = benchmark_end_time - benchmark_start_time
print(f"Total time: {benchmark_time:.2f} s")
print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
# Compute the latency statistics.
avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
print(f"Average latency: {avg_latency:.2f} s")
avg_per_token_latency = np.mean([
latency / (prompt_len + output_len)
for prompt_len, output_len, latency in REQUEST_LATENCY
])
print(f"Average latency per token: {avg_per_token_latency:.2f} s")
avg_per_output_token_latency = np.mean([
latency / output_len
for _, output_len, latency in REQUEST_LATENCY
])
print("Average latency per output token: "
f"{avg_per_output_token_latency:.2f} s")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Benchmark the online serving throughput.")
parser.add_argument("--backend", type=str, default="vllm",
choices=["vllm", "tgi"])
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--dataset", type=str, required=True,
help="Path to the dataset.")
parser.add_argument("--tokenizer", type=str, required=True,
help="Name or path of the tokenizer.")
parser.add_argument("--best-of", type=int, default=1,
help="Generates `best_of` sequences per prompt and "
"returns the best one.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--num-prompts", type=int, default=1000,
help="Number of prompts to process.")
parser.add_argument("--request-rate", type=float, default=float("inf"),
help="Number of requests per second. If this is inf, "
"then all the requests are sent at time 0. "
"Otherwise, we use Poisson process to synthesize "
"the request arrival times.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument('--trust-remote-code', action='store_true',
help='trust remote code from huggingface')
args = parser.parse_args()
main(args)


@@ -0,0 +1,221 @@
"""Benchmark offline inference throughput."""
import argparse
import json
import random
import time
from typing import List, Tuple
import torch
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
from tqdm import tqdm
from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
def sample_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int]]:
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [
data for data in dataset
if len(data["conversations"]) >= 2
]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]
# Tokenize the prompts and completions.
prompts = [prompt for prompt, _ in dataset]
prompt_token_ids = tokenizer(prompts).input_ids
completions = [completion for _, completion in dataset]
completion_token_ids = tokenizer(completions).input_ids
tokenized_dataset = []
for i in range(len(dataset)):
output_len = len(completion_token_ids[i])
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
# Filter out too long sequences.
filtered_dataset: List[Tuple[str, int, int]] = []
for prompt, prompt_token_ids, output_len in tokenized_dataset:
prompt_len = len(prompt_token_ids)
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
continue
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))
# Sample the requests.
sampled_requests = random.sample(filtered_dataset, num_requests)
return sampled_requests
def run_vllm(
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: str,
tensor_parallel_size: int,
seed: int,
n: int,
use_beam_search: bool,
trust_remote_code: bool,
) -> float:
llm = LLM(
model=model,
tokenizer=tokenizer,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
)
# Add the requests to the engine.
for prompt, _, output_len in requests:
sampling_params = SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
)
# FIXME(woosuk): Do not use internal method.
llm._add_request(
prompt=prompt,
prompt_token_ids=None,
sampling_params=sampling_params,
)
start = time.time()
# FIXME(woosuk): Do not use internal method.
llm._run_engine(use_tqdm=True)
end = time.time()
return end - start
def run_hf(
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: PreTrainedTokenizerBase,
n: int,
use_beam_search: bool,
max_batch_size: int,
trust_remote_code: bool,
) -> float:
assert not use_beam_search
llm = AutoModelForCausalLM.from_pretrained(model,
torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
if llm.config.model_type == "llama":
# To enable padding in the HF backend.
tokenizer.pad_token = tokenizer.eos_token
llm = llm.cuda()
pbar = tqdm(total=len(requests))
start = time.time()
batch: List[str] = []
max_prompt_len = 0
max_output_len = 0
for i in range(len(requests)):
prompt, prompt_len, output_len = requests[i]
# Add the prompt to the batch.
batch.append(prompt)
max_prompt_len = max(max_prompt_len, prompt_len)
max_output_len = max(max_output_len, output_len)
if len(batch) < max_batch_size and i != len(requests) - 1:
# Check if we can add more requests to the batch.
_, next_prompt_len, next_output_len = requests[i + 1]
if (max(max_prompt_len, next_prompt_len) + max(
max_output_len, next_output_len)) <= 2048:
# We can add more requests to the batch.
continue
# Generate the sequences.
input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
llm_outputs = llm.generate(
input_ids=input_ids.cuda(),
do_sample=not use_beam_search,
num_return_sequences=n,
temperature=1.0,
top_p=1.0,
use_cache=True,
max_new_tokens=max_output_len,
)
# Include the decoding time.
tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
pbar.update(len(batch))
# Clear the batch.
batch = []
max_prompt_len = 0
max_output_len = 0
end = time.time()
return end - start
def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)
# Sample the requests.
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
if args.backend == "vllm":
elapsed_time = run_vllm(
requests, args.model, args.tokenizer, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search, args.trust_remote_code)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(
requests, args.model, tokenizer, args.n, args.use_beam_search,
args.hf_max_batch_size, args.trust_remote_code)
else:
raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(
prompt_len + output_len
for _, prompt_len, output_len in requests
)
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend", type=str, choices=["vllm", "hf"],
default="vllm")
parser.add_argument("--dataset", type=str, required=True,
help="Path to the dataset.")
parser.add_argument("--model", type=str, default="facebook/opt-125m")
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n", type=int, default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--num-prompts", type=int, default=1000,
help="Number of prompts to process.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--hf-max-batch-size", type=int, default=None,
help="Maximum batch size for HF backend.")
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
args = parser.parse_args()
if args.backend == "vllm":
if args.hf_max_batch_size is not None:
raise ValueError("HF max batch size is only for HF backend.")
elif args.backend == "hf":
if args.hf_max_batch_size is None:
raise ValueError("HF max batch size is required for HF backend.")
if args.tokenizer is None:
args.tokenizer = args.model
main(args)

16
benchmarks/launch_tgi_server.sh Executable file

@@ -0,0 +1,16 @@
#!/bin/bash
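# Usage: ./launch_tgi_server.sh <model-id> <max-batch-total-tokens>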
PORT=8000
MODEL=$1
TOKENS=$2
docker run --gpus all --shm-size 1g -p $PORT:80 \
-v $PWD/data:/data \
ghcr.io/huggingface/text-generation-inference:0.8 \
--model-id $MODEL \
--sharded false \
--max-input-length 1024 \
--max-total-tokens 2048 \
--max-best-of 5 \
--max-concurrent-requests 5000 \
--max-batch-total-tokens $TOKENS


@@ -1,179 +0,0 @@
import argparse
import asyncio
import time
from typing import List, Dict
import json
import ray
from transformers import AutoTokenizer
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import uvicorn
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.master.server import (Server, add_server_arguments,
initialize_ray_cluster)
from cacheflow.worker.controller import DeviceID
from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
app = FastAPI()
class FastAPIFrontend:
def __init__(
self,
model: str,
model_path: str,
pipeline_parallel_size: int,
tensor_parallel_size: int,
block_size: int,
dtype: str,
seed: int,
swap_space: int,
max_num_batched_tokens: int,
num_nodes: int,
num_devices_per_node: int,
distributed_init_method: str,
all_stage_devices: List[List[DeviceID]],
):
self.block_size = block_size
self.tokenizer = AutoTokenizer.from_pretrained(model)
self.seq_group_counter = Counter()
self.seq_counter = Counter()
remote_server_class = ray.remote(num_cpus=0)(Server)
self.server = remote_server_class.remote(
model=model,
model_path=model_path,
use_dummy_weights=False,
pipeline_parallel_size=pipeline_parallel_size,
tensor_parallel_size=tensor_parallel_size,
block_size=block_size,
dtype=dtype,
seed=seed,
swap_space=swap_space,
max_num_batched_tokens=max_num_batched_tokens,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
)
self.running_seq_groups: Dict[int, SequenceGroup] = {}
self.sequence_group_events: Dict[int, asyncio.Event] = {}
self.is_server_running = False
async def server_step(self):
self.is_server_running = True
updated_seq_groups = await self.server.step.remote()
self.is_server_running = False
# Notify the waiting coroutines that there are new outputs ready.
for seq_group in updated_seq_groups:
group_id = seq_group.group_id
self.running_seq_groups[group_id] = seq_group
self.sequence_group_events[group_id].set()
async def generate(self, request_dict: Dict):
# Preprocess the request.
prompt = request_dict["prompt"]
sampling_params = SamplingParams.from_dict(request_dict)
sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
token_ids = self.tokenizer.encode(prompt)
seqs: List[Sequence] = []
for _ in range(sampling_params.n):
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, token_ids, block_size=self.block_size)
seqs.append(seq)
arrival_time = time.time()
group_id = next(self.seq_group_counter)
seq_group = SequenceGroup(group_id, seqs, arrival_time)
# Create an event to notify us that there is new output from the
# cacheflow server.
group_event = asyncio.Event()
self.running_seq_groups[group_id] = seq_group
self.sequence_group_events[group_id] = group_event
# Add the request into the cacheflow server's waiting queue.
await self.server.add_sequence_groups.remote([(seq_group, sampling_params)])
# The cacheflow server does not have a background loop that keeps
# processing incoming requests. Therefore, we need to keep kicking
# the server to process the requests.
while True:
# Kick the server if the server is not running.
if not self.is_server_running:
await self.server_step()
# Wait for new output. The group_event will be set in server_step
# when there is new output available for the sequence group.
# Added a timeout to prevent deadlock.
await asyncio.wait_for(group_event.wait(), timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
# Reset the event to wait for the next output.
group_event.clear()
# Decode and return new outputs
seq_group = self.running_seq_groups[group_id]
all_outputs = []
for seq in seq_group.seqs:
token_ids = seq.get_token_ids()
output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
all_outputs.append(output)
ret = {
"text": all_outputs,
"error": 0,
}
yield (json.dumps(ret) + "\0").encode("utf-8")
# Once finished, release the resources of the sequence group.
if seq_group.is_finished():
del self.running_seq_groups[group_id]
del self.sequence_group_events[group_id]
# Kick the server if the server is not running. This ensures
# that any remaining requests in the server's waiting queue
# get executed.
if not self.is_server_running:
await self.server_step()
break
@app.post("/generate")
async def generate_stream(request: Request):
request_dict = await request.json()
return StreamingResponse(frontend.generate(request_dict))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=10002)
parser = add_server_arguments(parser)
args = parser.parse_args()
# TODO(zhuohan): Support pipeline parallelism.
assert args.pipeline_parallel_size == 1, (
'Pipeline parallelism is not supported yet.')
(num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices) = (
initialize_ray_cluster(
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))
frontend = FastAPIFrontend(
model=args.model,
model_path=args.model_path,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
block_size=args.block_size,
dtype=args.dtype,
seed=args.seed,
swap_space=args.swap_space,
max_num_batched_tokens=args.max_num_batched_tokens,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")


@@ -1,43 +0,0 @@
import argparse
import json
import time
import gradio as gr
import requests
def http_bot(prompt):
headers = {"User-Agent": "Cacheflow Client"}
pload = {
"prompt": prompt,
"max_num_steps": 128,
}
response = requests.post(args.model_url, headers=headers, json=pload, stream=True)
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"][0]
yield output
def build_demo():
with gr.Blocks() as demo:
gr.Markdown(
"# Cacheflow demo\n"
)
inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")# .style(container=False)
outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model")
inputbox.submit(http_bot, [inputbox], [outputbox])
return demo
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=10003)
parser.add_argument("--model-url", type=str, default="http://localhost:10002/generate")
args = parser.parse_args()
demo = build_demo()
demo.queue(concurrency_count=100).launch(server_name=args.host, server_port=args.port)


@@ -1,23 +0,0 @@
import requests
import json
def http_request():
prompt = "Ion Stoica is a"
headers = {"User-Agent": "Test Client"}
pload = {
"prompt": prompt,
"n": 4,
"use_beam_search": True,
"temperature": 0.0,
}
response = requests.post("http://localhost:10002/generate", headers=headers, json=pload, stream=True)
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"]
yield output
for h in http_request():
print(h, flush=True)


@@ -1,529 +0,0 @@
import enum
import os
import pickle
import time
from typing import Any, Dict, List, Optional, Tuple
from cacheflow.master.block_manager import BlockSpaceManager
from cacheflow.master.policy import PolicyFactory
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.sequence import SequenceOutputs
from cacheflow.sequence import SequenceStatus
class PreemptionMode(enum.Enum):
"""Preemption modes.
1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
and swap them back in when the sequences are resumed.
2. Recomputation: Discard the blocks of the preempted sequences and
recompute them when the sequences are resumed, treating the sequences as
new prompts.
"""
SWAP = enum.auto()
RECOMPUTE = enum.auto()
class Scheduler:
def __init__(
self,
controllers: List,
block_size: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
max_num_batched_tokens: int,
max_num_sequences: int,
collect_stats: bool,
do_memory_analysis: bool = False,
) -> None:
self.controllers = controllers
self.block_size = block_size
self.num_gpu_blocks = num_gpu_blocks
self.num_cpu_blocks = num_cpu_blocks
self.max_num_batched_tokens = max_num_batched_tokens
self.max_num_sequences = max_num_sequences
self.collect_stats = collect_stats
self.do_memory_analysis = do_memory_analysis
# Instantiate the scheduling policy.
self.policy = PolicyFactory.get_policy(policy_name='fcfs')
# Create the block space manager.
self.block_manager = BlockSpaceManager(
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
)
# Sequence groups in the WAITING state.
self.waiting: List[SequenceGroup] = []
# Sequence groups in the RUNNING state.
self.running: List[SequenceGroup] = []
# Mapping: group_id -> num_steps.
self.num_steps: Dict[int, int] = {}
# Mapping: group_id -> sampling params.
self.sampling_params: Dict[int, SamplingParams] = {}
# Sequence groups in the SWAPPED state.
self.swapped: List[SequenceGroup] = []
# Performance-related statistics.
self.stats = Stats(num_gpu_blocks, num_cpu_blocks)
def add_sequence_groups(
self,
seq_groups: List[Tuple[SequenceGroup, SamplingParams]],
) -> None:
# Add sequence groups to the waiting queue.
for seq_group, sampling_params in seq_groups:
self.waiting.append(seq_group)
self.sampling_params[seq_group.group_id] = sampling_params
def _schedule(
self,
) -> Tuple[Dict[int, int], Dict[int, int], Dict[int, List[int]], List[int]]:
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {}
blocks_to_swap_out: Dict[int, int] = {}
blocks_to_copy: Dict[int, List[int]] = {}
# Fix the current time.
now = time.time()
# NOTE(woosuk): We prioritize the sequence groups in the RUNNING state
# in order to minimize the preemption overheads.
# Preemption happens only when there is no available slot to keep all
# the sequence groups in the RUNNING state.
# In this case, the policy is responsible for deciding which sequence
# groups to preempt.
self.running = self.policy.sort_by_priority(now, self.running)
# Reserve new token slots for the running sequence groups.
running: List[SequenceGroup] = []
preempted: List[SequenceGroup] = []
while self.running:
seq_group = self.running.pop(0)
while not self.block_manager.can_append(seq_group):
if self.running:
# Preempt the lowest-priority sequence groups.
victim_seq_group = self.running.pop(-1)
self._preempt(victim_seq_group, blocks_to_swap_out)
preempted.append(victim_seq_group)
else:
# No other sequence groups can be preempted.
# Preempt the current sequence group.
self._preempt(seq_group, blocks_to_swap_out)
preempted.append(seq_group)
break
else:
# Append new slots to the sequence group.
self._append(seq_group, blocks_to_copy)
running.append(seq_group)
self.running = running
# Swap in the sequence groups in the SWAPPED state if possible.
self.swapped = self.policy.sort_by_priority(now, self.swapped)
# FCFS
while self.swapped and not blocks_to_swap_out:
seq_group = self.swapped[0]
# If the sequence group has been preempted in this step, stop.
if seq_group in preempted:
break
# If the sequence group cannot be swapped in, stop.
if not self.block_manager.can_swap_in(seq_group):
break
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
if len(self.running) + num_seqs > self.max_num_sequences:
break
seq_group = self.swapped.pop(0)
self._swap_in(seq_group, blocks_to_swap_in)
self._append(seq_group, blocks_to_copy)
self.running.append(seq_group)
num_batched_tokens = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running
)
# Join waiting sequences if possible.
prompt_group_ids: List[int] = []
# NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
# prioritized over the sequence groups in the WAITING state.
# This is because we want to bound the amount of CPU memory taken by
# the swapped sequence groups.
if not self.swapped:
self.waiting = self.policy.sort_by_priority(now, self.waiting)
while self.waiting:
seq_group = self.waiting[0]
# If the sequence group has been preempted in this step, stop.
if seq_group in preempted:
break
# If the sequence group cannot be allocated, stop.
if not self.block_manager.can_allocate(seq_group):
break
# If the number of batched tokens exceeds the limit, stop.
num_prompt_tokens = seq_group.seqs[0].get_len()
if (num_batched_tokens + num_prompt_tokens
> self.max_num_batched_tokens):
break
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
if len(self.running) + num_seqs > self.max_num_sequences:
break
seq_group = self.waiting.pop(0)
self._allocate(seq_group)
self.running.append(seq_group)
num_batched_tokens += num_prompt_tokens
prompt_group_ids.append(seq_group.group_id)
if self.collect_stats:
if self.running or blocks_to_swap_in or blocks_to_swap_out:
self.stats.timestamps.append(now - self.stats.start_time)
self.stats.input_lens.append(num_batched_tokens)
self.stats.swap_out_lens.append(len(blocks_to_swap_out) * self.block_size)
self.stats.swap_in_lens.append(len(blocks_to_swap_in) * self.block_size)
self.stats.num_preemption.append(len(preempted))
self.stats.num_swapped.append(len(self.swapped))
self.stats.num_running.append(len(self.running))
self.stats.num_waiting.append(len(self.waiting))
num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
num_used_gpu_blocks = self.num_gpu_blocks - num_free_gpu_blocks
self.stats.gpu_cache_usage.append(num_used_gpu_blocks / self.num_gpu_blocks)
num_free_cpu_blocks = self.block_manager.get_num_free_cpu_blocks()
num_used_cpu_blocks = self.num_cpu_blocks - num_free_cpu_blocks
self.stats.cpu_cache_usage.append(num_used_cpu_blocks / self.num_cpu_blocks)
if self.do_memory_analysis:
block_tables = self.block_manager.block_tables
num_logical_blocks = 0
num_logical_tokens = 0
num_physical_blocks = 0
num_physical_tokens = 0
physical_block_numbers = set()
num_reserved_tokens = 0
for seq_group in self.running:
group_id = seq_group.group_id
sampling_params = self.sampling_params[group_id]
max_num_steps = sampling_params.max_num_steps
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
num_logical_blocks += len(seq.logical_token_blocks)
num_logical_tokens += seq.get_len()
seq_id = seq.seq_id
block_table = block_tables[seq_id]
for i, block in enumerate(block_table):
if block.block_number in physical_block_numbers:
continue
physical_block_numbers.add(block.block_number)
num_physical_blocks += 1
num_physical_tokens += seq.logical_token_blocks[i].num_tokens
assert num_physical_blocks == num_used_gpu_blocks
self.stats.num_logical_blocks.append(num_logical_blocks)
self.stats.num_logical_tokens.append(num_logical_tokens)
self.stats.num_physical_blocks.append(num_physical_blocks)
self.stats.num_physical_tokens.append(num_physical_tokens)
self.stats.num_reserved_tokens.append(num_reserved_tokens)
return (blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
prompt_group_ids)
def step(self) -> List[SequenceGroup]:
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
# such as self.running, self.swapped, and self.waiting.
scheduler_output = self._schedule()
blocks_to_swap_in = scheduler_output[0]
blocks_to_swap_out = scheduler_output[1]
blocks_to_copy = scheduler_output[2]
prompt_group_ids = scheduler_output[3]
# Create input data structures.
input_seq_groups: List[SequenceGroupInputs] = []
updated_seq_groups: List[SequenceGroup] = self.running.copy()
for seq_group in self.running:
group_id = seq_group.group_id
is_prompt = group_id in prompt_group_ids
input_tokens: Dict[int, List[int]] = {}
seq_logprobs: Dict[int, float] = {}
block_tables: Dict[int, List[int]] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq_id = seq.seq_id
block_tables[seq_id] = self.block_manager.get_block_table(seq)
if is_prompt:
input_tokens[seq_id] = seq.get_token_ids()
else:
input_tokens[seq_id] = [seq.get_last_token_id()]
seq_logprobs[seq_id] = seq.cumulative_logprobs
                # NOTE(woosuk): Sequences in the same group have the same
                # sequence length.
seq_len = seq.get_len()
input_seq_group = SequenceGroupInputs(
group_id=group_id,
is_prompt=is_prompt,
input_tokens=input_tokens,
context_len=seq_len,
seq_logprobs=seq_logprobs,
sampling_params=self.sampling_params[group_id],
block_tables=block_tables,
)
input_seq_groups.append(input_seq_group)
# Execute the first stage of the pipeline.
if input_seq_groups or blocks_to_swap_in or blocks_to_swap_out:
# Swap in and swap out should never happen at the same time.
assert not (blocks_to_swap_in and blocks_to_swap_out)
self.controllers[0].execute_stage(
input_seq_groups,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
return updated_seq_groups
def post_step(
self,
seq_outputs: Dict[int, SequenceOutputs],
) -> None:
# Update the running sequences and free blocks.
for seq_group in self.running:
group_id = seq_group.group_id
self.num_steps[group_id] += 1
stop_token_ids = self.sampling_params[group_id].stop_token_ids
# Process beam search results before processing the next tokens.
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue
output = seq_outputs[seq.seq_id]
if seq.seq_id != output.parent_seq_id:
# The sequence is a fork of the parent sequence (beam search).
# Free the current sequence.
self.block_manager.free(seq)
# Fork the parent sequence.
parent_seq = seq_group.find(output.parent_seq_id)
parent_seq.fork(seq)
self.block_manager.fork(parent_seq, seq)
# Process the next tokens.
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue
# Append a new token to the sequence.
output = seq_outputs[seq.seq_id]
seq.append(output.output_token, output.logprobs)
# Check if the sequence has generated a stop token.
if output.output_token in stop_token_ids:
self._free_seq(seq)
continue
# Check if the sequence has reached the maximum number of steps.
max_num_steps = self.sampling_params[group_id].max_num_steps
if self.num_steps[group_id] == max_num_steps:
self._free_seq(seq)
continue
# Update the running sequences.
running: List[SequenceGroup] = []
for seq_group in self.running:
if seq_group.is_finished():
self._free_seq_group(seq_group)
else:
running.append(seq_group)
self.running = running
def _allocate(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group)
for seq in seq_group.seqs:
seq.status = SequenceStatus.RUNNING
# FIXME(woosuk): Support interactive generation.
if seq_group.group_id not in self.num_steps:
self.num_steps[seq_group.group_id] = 0
def _append(
self,
seq_group: SequenceGroup,
blocks_to_copy: Dict[int, List[int]],
) -> None:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
ret = self.block_manager.append(seq)
if ret is not None:
src_block, dst_block = ret
if src_block in blocks_to_copy:
blocks_to_copy[src_block].append(dst_block)
else:
blocks_to_copy[src_block] = [dst_block]
def _preempt(
self,
seq_group: SequenceGroup,
blocks_to_swap_out: Dict[int, int],
preemption_mode: Optional[PreemptionMode] = None,
) -> None:
# If preemption mode is not specified, we determine the mode as follows:
# We use recomputation by default since it incurs lower overhead than
# swapping. However, when the sequence group has multiple sequences
# (e.g., beam search), recomputation is not supported. In such a case,
# we use swapping instead.
# FIXME(woosuk): This makes our scheduling policy a bit bizarre.
# As swapped sequences are prioritized over waiting sequences,
# sequence groups with multiple sequences are implicitly prioritized
# over sequence groups with a single sequence.
# TODO(woosuk): Support recomputation for sequence groups with multiple
# sequences. This may require a more sophisticated CUDA kernel.
if preemption_mode is None:
seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
if len(seqs) == 1:
preemption_mode = PreemptionMode.RECOMPUTE
else:
preemption_mode = PreemptionMode.SWAP
if preemption_mode == PreemptionMode.RECOMPUTE:
self._preempt_by_recompute(seq_group)
elif preemption_mode == PreemptionMode.SWAP:
self._preempt_by_swap(seq_group, blocks_to_swap_out)
else:
assert False, 'Invalid preemption mode.'
def _preempt_by_recompute(
self,
seq_group: SequenceGroup,
) -> None:
seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
assert len(seqs) == 1
for seq in seqs:
seq.status = SequenceStatus.WAITING
self.block_manager.free(seq)
self.waiting.append(seq_group)
def _preempt_by_swap(
self,
seq_group: SequenceGroup,
blocks_to_swap_out: Dict[int, int],
) -> None:
seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
for seq in seqs:
seq.status = SequenceStatus.SWAPPED
self._swap_out(seq_group, blocks_to_swap_out)
self.swapped.append(seq_group)
def _free_seq(self, seq: Sequence) -> None:
seq.status = SequenceStatus.FINISHED
self.block_manager.free(seq)
def _free_seq_group(self, seq_group: SequenceGroup) -> None:
group_id = seq_group.group_id
del self.num_steps[group_id]
del self.sampling_params[group_id]
def _swap_in(
self,
seq_group: SequenceGroup,
blocks_to_swap_in: Dict[int, int],
) -> None:
mapping = self.block_manager.swap_in(seq_group)
blocks_to_swap_in.update(mapping)
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
seq.status = SequenceStatus.RUNNING
def _swap_out(
self,
seq_group: SequenceGroup,
blocks_to_swap_out: Dict[int, int],
) -> None:
assert self.block_manager.can_swap_out(seq_group)
mapping = self.block_manager.swap_out(seq_group)
blocks_to_swap_out.update(mapping)
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq.status = SequenceStatus.SWAPPED
def reset_stats(self) -> None:
self.stats.reset(self.num_gpu_blocks, self.num_cpu_blocks)
def save_stats(
self,
output_dir: str,
) -> None:
assert self.collect_stats, 'Statistics collection is disabled.'
self.stats.save(output_dir)
class Stats:
def __init__(
self,
num_gpu_blocks: int,
num_cpu_blocks: int,
) -> None:
self.start_time: float = time.time()
self.num_gpu_blocks = num_gpu_blocks
self.num_cpu_blocks = num_cpu_blocks
self.timestamps: List[float] = []
self.input_lens: List[int] = []
self.swap_out_lens: List[int] = []
self.swap_in_lens: List[int] = []
self.num_preemption: List[int] = []
self.num_waiting: List[int] = []
self.num_running: List[int] = []
self.num_swapped: List[int] = []
self.gpu_cache_usage: List[float] = []
self.cpu_cache_usage: List[float] = []
self.num_logical_blocks: List[int] = []
self.num_logical_tokens: List[int] = []
self.num_physical_blocks: List[int] = []
self.num_physical_tokens: List[int] = []
self.num_reserved_tokens: List[int] = []
def reset(
self,
num_gpu_blocks: int,
num_cpu_blocks: int,
) -> None:
self.__init__(num_gpu_blocks, num_cpu_blocks)
def to_dict(self) -> Dict[str, Any]:
return {
'start_time': self.start_time,
'num_gpu_blocks': self.num_gpu_blocks,
'num_cpu_blocks': self.num_cpu_blocks,
'timestamps': self.timestamps,
'input_lens': self.input_lens,
'swap_out_lens': self.swap_out_lens,
'swap_in_lens': self.swap_in_lens,
'num_preemption': self.num_preemption,
'num_waiting': self.num_waiting,
'num_running': self.num_running,
'num_swapped': self.num_swapped,
'gpu_cache_usage': self.gpu_cache_usage,
'cpu_cache_usage': self.cpu_cache_usage,
'num_logical_blocks': self.num_logical_blocks,
'num_logical_tokens': self.num_logical_tokens,
'num_physical_blocks': self.num_physical_blocks,
'num_physical_tokens': self.num_physical_tokens,
'num_reserved_tokens': self.num_reserved_tokens,
}
def save(self, output_dir: str) -> None:
with open(os.path.join(output_dir, 'stats.pkl'), 'wb') as f:
pickle.dump(self.to_dict(), f)
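# Illustrative driver loop (a sketch, not part of the original module):
# it assumes a `frontend` supplying (SequenceGroup, SamplingParams) pairs
# and a hypothetical `execute_model()` helper that collects the
# Dict[int, SequenceOutputs] produced by the worker pipeline.
#
#   scheduler.add_sequence_groups(frontend.get_inputs())
#   while scheduler.waiting or scheduler.running or scheduler.swapped:
#       scheduler.step()                   # Schedule and launch stage 0.
#       seq_outputs = execute_model()      # Hypothetical output collection.
#       scheduler.post_step(seq_outputs)   # Append tokens, free finished groups.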


@@ -1,192 +0,0 @@
import argparse
from typing import List, Tuple
import random
import ray
from cacheflow.master.scheduler import Scheduler
from cacheflow.models import get_memory_analyzer
from cacheflow.worker.controller import Controller, DeviceID
from cacheflow.sequence import SequenceGroup
from cacheflow.sampling_params import SamplingParams
class Server:
def __init__(
self,
model: str,
model_path: str,
use_dummy_weights: bool,
pipeline_parallel_size: int,
tensor_parallel_size: int,
block_size: int,
dtype: str,
seed: int,
swap_space: int,
max_num_batched_tokens: int,
max_num_sequences: int,
num_nodes: int,
num_devices_per_node: int,
distributed_init_method: str,
all_stage_devices: List[List[DeviceID]],
gpu_memory: int,
cpu_memory: int,
collect_stats: bool = False,
do_memory_analysis: bool = False,
):
self.num_nodes = num_nodes
self.num_devices_per_node = num_devices_per_node
self.world_size = pipeline_parallel_size * tensor_parallel_size
self.memory_analyzer = get_memory_analyzer(
model_name=model,
block_size=block_size,
dtype=dtype,
gpu_memory=gpu_memory,
cpu_memory=cpu_memory,
tensor_parallel_size=tensor_parallel_size,
)
self.num_gpu_blocks = self.memory_analyzer.get_max_num_gpu_blocks(
max_num_batched_tokens=max_num_batched_tokens)
self.num_cpu_blocks = self.memory_analyzer.get_max_num_cpu_blocks(
swap_space=swap_space)
print(f'# GPU blocks: {self.num_gpu_blocks}, '
f'# CPU blocks: {self.num_cpu_blocks}')
# Create a controller for each pipeline stage.
self.controllers: List[Controller] = []
for i in range(pipeline_parallel_size):
controller = Controller(
stage_id=i,
stage_devices=all_stage_devices[i],
world_size=self.world_size,
pipeline_parallel_size=pipeline_parallel_size,
tensor_parallel_size=tensor_parallel_size,
distributed_init_method=distributed_init_method,
model_name=model,
block_size=block_size,
num_gpu_blocks=self.num_gpu_blocks,
num_cpu_blocks=self.num_cpu_blocks,
dtype=dtype,
seed=seed,
model_path=model_path,
use_dummy_weights=use_dummy_weights,
max_num_batched_tokens=max_num_batched_tokens,
)
self.controllers.append(controller)
# Create a scheduler.
self.scheduler = Scheduler(
controllers=self.controllers,
block_size=block_size,
num_gpu_blocks=self.num_gpu_blocks,
num_cpu_blocks=self.num_cpu_blocks,
max_num_batched_tokens=max_num_batched_tokens,
max_num_sequences=max_num_sequences,
collect_stats=collect_stats,
do_memory_analysis=do_memory_analysis,
)
# Connect the controllers.
for i in range(len(self.controllers) - 1):
self.controllers[i].set_next(self.controllers[i + 1])
self.controllers[-1].set_next(self.scheduler)
def add_sequence_groups(
self,
sequence_groups: List[Tuple[SequenceGroup, SamplingParams]]
):
self.scheduler.add_sequence_groups(sequence_groups)
def step(self):
return self.scheduler.step()
def has_unfinished_requests(self):
return (self.scheduler.waiting or self.scheduler.running or
self.scheduler.swapped)
def initialize_ray_cluster(
address: str = 'auto',
pipeline_parallel_size: int = 1,
tensor_parallel_size: int = 1,
) -> Tuple[int, int, str, List[List[DeviceID]]]:
# Connect to a ray cluster.
ray.init(address=address)
# Assume we have a uniform cluster that each node has the same number of
# GPUs for now.
valid_node_resources = []
num_devices_per_node = None
for node in ray.nodes():
        if (not node['Alive']) or node['Resources'].get('GPU', 0) <= 0:
continue
if num_devices_per_node is None:
num_devices_per_node = node['Resources']['GPU']
else:
assert num_devices_per_node == node['Resources']['GPU'], (
"The number of GPUs per node is not uniform.")
for key in node['Resources']:
if key.startswith('node:'):
valid_node_resources.append(key)
num_nodes = len(valid_node_resources)
assert (pipeline_parallel_size * tensor_parallel_size
<= num_nodes * num_devices_per_node), (
"The number of required GPUs exceeds the total number of "
"available GPUs.")
    if tensor_parallel_size >= num_devices_per_node:
        assert tensor_parallel_size % num_devices_per_node == 0, (
            "The tensor parallel size is not divisible by the "
            "number of GPUs per node.")
    else:
        assert num_devices_per_node % tensor_parallel_size == 0, (
            "The number of GPUs per node is not divisible by the "
            "tensor parallel size.")
# Assign GPUs to pipeline stages.
rank = 0
current_node_id = 0
current_device_id = 0
distributed_init_method = None
all_stage_devices = []
for i in range(pipeline_parallel_size):
stage_devices = []
for j in range(tensor_parallel_size):
node_resource = valid_node_resources[current_node_id]
stage_devices.append((rank, node_resource, current_device_id))
if distributed_init_method is None:
ip = node_resource.split("node:")[-1]
port = random.randint(10000, 20000)
distributed_init_method = f"tcp://{ip}:{port}"
rank += 1
current_device_id += 1
if current_device_id >= num_devices_per_node:
current_node_id += 1
current_device_id = 0
all_stage_devices.append(stage_devices)
return (num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices)
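# Worked example: with pipeline_parallel_size=2 and tensor_parallel_size=2
# on a single node with 4 GPUs, stage 0 receives ranks 0-1 on devices 0-1
# and stage 1 receives ranks 2-3 on devices 2-3, so all_stage_devices is
# [[(0, node, 0), (1, node, 1)], [(2, node, 2), (3, node, 3)]].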
def add_server_arguments(parser: argparse.ArgumentParser):
# Model arguments
parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights',
help='model path to download and load the weights')
# Parallel arguments
parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
# KV cache arguments
parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size')
# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
parser.add_argument('--dtype', type=str, default='half', choices=['half'], help='data type')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration')
parser.add_argument('--max-num-sequences', type=int, default=256, help='maximum number of sequences per iteration')
parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
return parser
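# Typical usage (a sketch tying together the helpers defined above):
#   parser = argparse.ArgumentParser()
#   parser = add_server_arguments(parser)
#   args = parser.parse_args()
#   num_nodes, num_devices_per_node, init_method, all_stage_devices = (
#       initialize_ray_cluster(
#           pipeline_parallel_size=args.pipeline_parallel_size,
#           tensor_parallel_size=args.tensor_parallel_size))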


@@ -1,69 +0,0 @@
import time
from typing import List, Optional, Set, Tuple
from transformers import AutoTokenizer
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.utils import Counter
class SimpleFrontend:
def __init__(
self,
model_name: str,
block_size: int,
) -> None:
self.block_size = block_size
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.seq_group_counter = Counter()
self.seq_counter = Counter()
self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = []
def add_eos_token(self, sampling_params: SamplingParams) -> SamplingParams:
# Stop generation when we see an EOS token.
sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
return sampling_params
def query(
self,
prompt: str,
sampling_params: SamplingParams,
) -> None:
token_ids = self.tokenizer.encode(prompt)
self._add_query(token_ids, sampling_params)
def _add_query(
self,
token_ids: List[int],
sampling_params: SamplingParams,
arrival_time: Optional[float] = None,
) -> None:
if arrival_time is None:
arrival_time = time.time()
seqs: List[Sequence] = []
for _ in range(sampling_params.n):
seq_id = next(self.seq_counter)
seq = Sequence(seq_id, token_ids, block_size=self.block_size)
seqs.append(seq)
group_id = next(self.seq_group_counter)
seq_group = SequenceGroup(group_id, seqs, arrival_time)
self.inputs.append((seq_group, sampling_params))
def get_inputs(self) -> List[Tuple[SequenceGroup, SamplingParams]]:
inputs = self.inputs
self.inputs = []
return inputs
def print_response(
self,
seq_group: SequenceGroup,
) -> None:
for seq in seq_group.seqs:
token_ids = seq.get_token_ids()
output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
output = output.strip()
print(f'Seq {seq.seq_id}: {output!r}')


@@ -1,10 +0,0 @@
from cacheflow.models.input_metadata import InputMetadata
from cacheflow.models.model_utils import get_memory_analyzer
from cacheflow.models.model_utils import get_model
__all__ = [
'InputMetadata',
'get_memory_analyzer',
'get_model',
]


@@ -1,20 +0,0 @@
import torch
import torch.nn as nn
from cacheflow import activation_ops
class SiluAndMul(nn.Module):
def __init__(self):
super().__init__()
def forward(
self,
x: torch.Tensor, # (num_tokens, 2 * d)
) -> torch.Tensor: # (num_tokens, d)
num_tokens = x.shape[0]
d = x.shape[1] // 2
out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
activation_ops.silu_and_mul(out, x)
return out
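# A pure-PyTorch reference for the fused kernel above (a sketch, useful
# as a correctness check in tests; not part of the original module):
def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[1] // 2
    # SiLU(gate) * up, matching the semantics of activation_ops.silu_and_mul.
    return nn.functional.silu(x[:, :d]) * x[:, d:]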


@@ -1,207 +0,0 @@
from typing import Optional
from flash_attn.flash_attn_interface import _flash_attn_forward
import torch
import torch.nn as nn
from cacheflow import attention_ops
from cacheflow import cache_ops
from cacheflow import pos_encoding_ops
from cacheflow.models import InputMetadata
class GPTCacheFlowAttention(nn.Module):
def __init__(self, scale: float) -> None:
super().__init__()
self.scale = float(scale)
def multi_query_kv_attention(
self,
output: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
query: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
cumulative_prompt_lens: torch.Tensor, # [num_prompts + 1]
max_prompt_len: int,
) -> None:
if query.dtype == torch.float:
raise ValueError('The float data type is not supported by '
'FlashAttention. Use the half data type instead.')
head_size = query.shape[-1]
if head_size > 128:
raise ValueError('FlashAttention does not support head_size > 128.')
# Directly call FlashAttention's internal function to avoid allocating
# a new tensor for the output.
_flash_attn_forward(
query,
key,
value,
output,
cumulative_prompt_lens,
cumulative_prompt_lens,
max_prompt_len,
max_prompt_len,
dropout_p=0.0,
softmax_scale=self.scale,
causal=True,
return_softmax=False,
)
def single_query_cached_kv_attention(
self,
output: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
query: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
input_metadata: InputMetadata,
) -> None:
head_size = value_cache.shape[2]
supported_head_sizes = [32, 64, 80, 96, 128, 160, 192, 256]
if head_size not in supported_head_sizes:
raise ValueError(f'head_size ({head_size}) is not supported by '
'the single_query_cached_kv_attention kernel. '
'Use one of the following head sizes: '
f'{supported_head_sizes}.')
block_size = value_cache.shape[3]
attention_ops.single_query_cached_kv_attention(
output,
query,
key_cache,
value_cache,
self.scale,
input_metadata.block_tables,
input_metadata.context_lens,
block_size,
input_metadata.max_context_len,
)
def forward(
self,
query: torch.Tensor, # [num_tokens, num_heads * head_size]
key: torch.Tensor, # [num_tokens, num_heads * head_size]
value: torch.Tensor, # [num_tokens, num_heads * head_size]
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
# NOTE: The query, key, and value tensors must be sliced from a qkv
# tensor of shape [num_tokens, 3 * num_heads * head_size].
# Reshape the query, key, and value tensors.
num_heads = value_cache.shape[1]
head_size = value_cache.shape[2]
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_heads, head_size)
value = value.view(-1, num_heads, head_size)
# Pre-allocate the output tensor.
output = torch.empty_like(query)
# Compute the attention op for prompts.
num_prompt_tokens = input_metadata.num_prompt_tokens
if num_prompt_tokens > 0:
self.multi_query_kv_attention(
output[:num_prompt_tokens],
query[:num_prompt_tokens],
key[:num_prompt_tokens],
value[:num_prompt_tokens],
input_metadata.cumulative_prompt_lens,
input_metadata.max_prompt_len,
)
# Wait until the cache op is done.
if cache_event is not None:
cache_event.wait()
# Reshape the keys and values and store them in the cache.
num_valid_tokens = input_metadata.num_valid_tokens
if num_valid_tokens > 0:
# The stride is 3 because the key and value are sliced from qkv.
cache_ops.reshape_and_cache(
key[:num_valid_tokens],
value[:num_valid_tokens],
key_cache,
value_cache,
input_metadata.slot_mapping,
)
if input_metadata.num_generation_tokens > 0:
# Compute the attention op for generation tokens.
self.single_query_cached_kv_attention(
output[num_prompt_tokens:num_valid_tokens],
query[num_prompt_tokens:num_valid_tokens],
key_cache,
value_cache,
input_metadata)
# Reshape the output tensor.
# NOTE(woosuk): The output tensor may include paddings.
return output.view(-1, num_heads * head_size)
class OPTCacheFlowAttention(GPTCacheFlowAttention):
"""OPT uses the same attention mechanism as GPT."""
def __init__(self, scale: float) -> None:
super().__init__(scale)
class LlamaCacheFlowAttention(GPTCacheFlowAttention):
"""Llama uses GPT-NeoX style rotary embedding."""
def __init__(
self,
scale: float,
head_size: int,
max_position: int = 8192,
base: int = 10000,
) -> None:
super().__init__(scale)
# Create the cos and sin cache.
inv_freq = 1.0 / (base ** (torch.arange(0, head_size, 2) / head_size))
t = torch.arange(max_position).float()
freqs = torch.einsum('i,j -> ij', t, inv_freq.float())
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
# FIXME(woosuk): This assumes that we configure the default dtype when
# initializing the model. Make it more robust.
torch_dtype = torch.get_default_dtype()
cache = cache.to(torch_dtype)
# Embedding size: [max_position, head_size]
self.register_buffer('cos_sin_cache', cache, persistent=False)
def forward(
self,
positions: torch.LongTensor, # [num_tokens]
query: torch.Tensor, # [num_tokens, num_heads * head_size]
key: torch.Tensor, # [num_tokens, num_heads * head_size]
value: torch.Tensor, # [num_tokens, num_heads * head_size]
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
# Apply rotary embedding to the query and key before passing them
# to the attention op.
pos_encoding_ops.rotary_embedding_neox(
positions,
query,
key,
self.cos_sin_cache,
)
return super().forward(
query,
key,
value,
key_cache,
value_cache,
input_metadata,
cache_event,
)
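# A non-fused reference for the NeoX-style rotary embedding applied by
# pos_encoding_ops.rotary_embedding_neox above (a sketch for testing; it
# expects `x` reshaped to [num_tokens, num_heads, head_size] and the
# cos/sin cache layout built in LlamaCacheFlowAttention.__init__):
def rotary_embedding_neox_ref(
    positions: torch.LongTensor,
    x: torch.Tensor,
    cos_sin_cache: torch.Tensor,
) -> torch.Tensor:
    head_size = x.shape[-1]
    # Each of cos and sin has shape [num_tokens, head_size // 2].
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)
    cos = cos.unsqueeze(1)  # Broadcast over the head dimension.
    sin = sin.unsqueeze(1)
    x1, x2 = x[..., :head_size // 2], x[..., head_size // 2:]
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)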


@@ -1,292 +0,0 @@
"""1D LLaMA model compatible with HuggingFace weights."""
import os
import glob
import filelock
from tqdm import tqdm
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import nn
from transformers import LlamaConfig
from cacheflow.models import InputMetadata
from cacheflow.models.activation import SiluAndMul
from cacheflow.models.attention import LlamaCacheFlowAttention
from cacheflow.models.layernorm import RMSNorm
from cacheflow.models.sample import Sampler
from cacheflow.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from cacheflow.sequence import SequenceOutputs
KVCache = Tuple[torch.Tensor, torch.Tensor]
class LlamaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
):
super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size,
bias=False, gather_output=False,
perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
bias=False, input_is_parallel=True,
perform_initialization=False)
if hidden_act != 'silu':
raise ValueError(f'Unsupported activation: {hidden_act}. '
'Only silu is supported for now.')
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class LlamaAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
):
super().__init__()
self.hidden_size = hidden_size
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
self.head_dim = hidden_size // self.total_num_heads
self.scaling = self.head_dim ** -0.5
self.qkv_proj = ColumnParallelLinear(
hidden_size,
3 * self.total_num_heads * self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
)
self.attn = LlamaCacheFlowAttention(self.scaling, self.head_dim)
def forward(
self,
positions: torch.LongTensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(
positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
output, _ = self.o_proj(attn_output)
return output
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
positions: torch.LongTensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class LlamaModel(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
perform_initialization=False)
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.LongTensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class LlamaForCausalLM(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.model = LlamaModel(config)
self.lm_head = ColumnParallelLinear(config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.LongTensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.model(
input_ids, positions, kv_caches, input_metadata, cache_events)
next_tokens = self.sampler(
self.lm_head.weight, hidden_states, input_metadata)
return next_tokens
_column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
"qkv_proj.weight", "gate_proj.weight",
"up_proj.weight"]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self, weights_path: str):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, param in state_dict.items():
if "qkv_proj" in name or "gate_up_proj" in name:
if "qkv_proj" in name:
original_name = "qkv_proj"
weight_names = ["q_proj", "k_proj", "v_proj"]
shard_size = param.shape[0] // 3
else:
original_name = "gate_up_proj"
weight_names = ["gate_proj", "up_proj"]
shard_size = param.shape[0] // 2
weights_to_concat = []
for weight_name in weight_names:
weight = np.load(os.path.join(
weights_path, name.replace(original_name, weight_name)))
weights_to_concat.append(weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)])
loaded_weight = torch.from_numpy(
np.concatenate(weights_to_concat, axis=0))
else:
loaded_weight = torch.from_numpy(
np.load(os.path.join(weights_path, name)))
for p in self._column_parallel_weights:
if p in name:
shard_size = param.shape[0]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
for p in self._row_parallel_weights:
if p in name:
shard_size = param.shape[1]
loaded_weight = loaded_weight[
:,
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
assert param.shape == loaded_weight.shape
param.data.copy_(loaded_weight)
@staticmethod
def get_weights(model_name: str, path: str):
if not os.path.isfile(os.path.join(model_name, "config.json")):
raise ValueError("LLaMA model's model_name has to be a path"
"to the huggingface model's directory.")
path = os.path.join(model_name, f"np")
path = os.path.abspath(os.path.expanduser(path))
os.makedirs(path, exist_ok=True)
lock_path = os.path.join(path, "file_lock")
lock = filelock.FileLock(lock_path)
with lock:
test_weight_path = os.path.join(path, "model.embed_tokens.weight")
if os.path.exists(test_weight_path):
return path
bin_files = glob.glob(os.path.join(model_name, "*.bin"))
for bin_file in tqdm(bin_files, desc="Convert format"):
state = torch.load(bin_file, map_location="cpu")
for name, param in tqdm(state.items(), leave=False):
param_path = os.path.join(path, name)
with open(param_path, "wb") as f:
np.save(f, param.cpu().detach().numpy())
return path
def initialize_dummy_weights(self) -> None:
for param in self.state_dict().values():
param.data.uniform_(-0.1, 0.1)


@@ -1,240 +0,0 @@
import torch
from transformers import AutoConfig
from cacheflow.models.utils import get_dtype_size
_GiB = 1 << 30
class CacheFlowMemoryAnalyzer:
def get_max_num_gpu_blocks(
self,
max_num_batched_tokens: int,
memory_utilization: float,
) -> int:
raise NotImplementedError()
def get_workspace_size(self) -> int:
return 1 * _GiB
def get_cache_block_size(self) -> int:
raise NotImplementedError()
def get_max_num_cpu_blocks(
self,
swap_space: int,
) -> int:
swap_space = swap_space * _GiB
cpu_memory = self.cpu_memory
if swap_space > 0.8 * cpu_memory:
            raise ValueError(f'The swap space ({swap_space / _GiB:.2f} GiB) '
                             'takes more than 80% of the available memory '
                             f'({cpu_memory / _GiB:.2f} GiB). '
                             'Please check the swap space size.')
if swap_space > 0.5 * cpu_memory:
            print(f'WARNING: The swap space ({swap_space / _GiB:.2f} GiB) '
                  'takes more than 50% of the available memory '
                  f'({cpu_memory / _GiB:.2f} GiB). '
                  'This may slow down the system performance.')
max_num_blocks = swap_space // self.get_cache_block_size()
return max_num_blocks
class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
def __init__(
self,
model_name: str,
block_size: int,
dtype: torch.dtype,
gpu_memory: int,
cpu_memory: int,
tensor_parallel_size: int,
) -> None:
self.model_name = model_name
self.block_size = block_size
self.dtype = dtype
self.gpu_memory = gpu_memory
self.cpu_memory = cpu_memory
self.tensor_parallel_size = tensor_parallel_size
config = AutoConfig.from_pretrained(model_name)
self.num_layers = config.num_hidden_layers
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_size = config.hidden_size // self.num_heads
self.ffn_size = config.ffn_dim
self.embedding_size = config.word_embed_proj_dim
self.vocab_size = config.vocab_size
self.max_position = config.max_position_embeddings
def _get_param_size(self) -> int:
word_embedding = self.vocab_size * self.embedding_size // self.tensor_parallel_size
if self.embedding_size != self.hidden_size:
# Project in/out.
word_embedding += 2 * self.embedding_size * self.hidden_size
position_embedding = self.max_position * self.hidden_size
ln1 = 2 * self.hidden_size
q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
mha = ln1 + q + k + v + out
ln2 = 2 * self.hidden_size
ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size
ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
ffn = ln2 + ffn1 + ffn2
total = (word_embedding + position_embedding +
self.num_layers * (mha + ffn))
dtype_size = get_dtype_size(self.dtype)
return dtype_size * total
def _get_max_act_size(
self,
max_num_batched_tokens: int,
) -> int:
        # NOTE: We approximately calculate the maximum activation size by
# estimating
# 1) the maximum activation tensor size during inference
# 2) the residual tensor size during inference
# Here, we assume that FlashAttention is used and
# thus the attention maps are never materialized in GPU DRAM.
residual = max_num_batched_tokens * self.hidden_size
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
# Double the activation size for input and output.
max_act = 2 * (max(qkv, ffn) + residual)
# Size of output logits.
output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
max_act = max(max_act, output_logits)
dtype_size = get_dtype_size(self.dtype)
return dtype_size * max_act
def get_cache_block_size(self) -> int:
key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
value_cache_block = key_cache_block
total = self.num_layers * (key_cache_block + value_cache_block)
dtype_size = get_dtype_size(self.dtype)
return dtype_size * total
def get_max_num_gpu_blocks(
self,
max_num_batched_tokens: int,
memory_utilization: float = 0.95,
) -> int:
# NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
usable_memory = int(memory_utilization * self.gpu_memory)
param_size = self._get_param_size()
act_size = self._get_max_act_size(max_num_batched_tokens)
workspace_size = self.get_workspace_size()
max_cache_size = usable_memory - (param_size + act_size + workspace_size)
if max_cache_size <= 0:
raise RuntimeError('Not enough GPU memory.')
max_num_blocks = max_cache_size // self.get_cache_block_size()
return max_num_blocks
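# Worked example for get_cache_block_size (illustrative numbers for
# facebook/opt-125m: 12 layers, hidden size 768, fp16, block_size=16,
# tensor_parallel_size=1):
#   key cache per block per layer   = 16 * 768        = 12,288 elements
#   key + value, all layers         = 12 * 2 * 12,288 = 294,912 elements
#   bytes per block (fp16, 2 bytes) = 2 * 294,912     = 589,824 B (~576 KiB)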
class LlamaMemoryAnalyzer(CacheFlowMemoryAnalyzer):
def __init__(
self,
model_name: str,
block_size: int,
dtype: torch.dtype,
gpu_memory: int,
cpu_memory: int,
tensor_parallel_size: int,
) -> None:
self.model_name = model_name
self.block_size = block_size
self.dtype = dtype
self.gpu_memory = gpu_memory
self.cpu_memory = cpu_memory
self.tensor_parallel_size = tensor_parallel_size
config = AutoConfig.from_pretrained(model_name)
self.num_layers = config.num_hidden_layers
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_size = config.hidden_size // self.num_heads
self.ffn_size = config.intermediate_size
self.vocab_size = config.vocab_size
self.max_position = 8192
def _get_param_size(self) -> int:
word_embedding = self.vocab_size * self.hidden_size // self.tensor_parallel_size
position_embedding = self.max_position * self.hidden_size
# NOTE: LLaMA does not have bias terms.
ln1 = self.hidden_size
q = self.hidden_size * self.hidden_size // self.tensor_parallel_size
k = self.hidden_size * self.hidden_size // self.tensor_parallel_size
v = self.hidden_size * self.hidden_size // self.tensor_parallel_size
out = self.hidden_size * self.hidden_size // self.tensor_parallel_size
# Rotary embedding.
# TODO(woosuk): Share the rotary embedding between layers.
rot = self.max_position * self.head_size
mha = ln1 + q + k + v + out + rot
ln2 = self.hidden_size
gate = self.hidden_size * self.ffn_size // self.tensor_parallel_size
down = self.ffn_size * self.hidden_size // self.tensor_parallel_size
up = self.hidden_size * self.ffn_size // self.tensor_parallel_size
ffn = ln2 + gate + down + up
total = (word_embedding + position_embedding + self.num_layers * (mha + ffn))
dtype_size = get_dtype_size(self.dtype)
return dtype_size * total
def _get_max_act_size(
self,
max_num_batched_tokens: int,
) -> int:
        # NOTE: We approximately calculate the maximum activation size by
# estimating
# 1) the maximum activation tensor size during inference
# 2) the residual tensor size during inference
# Here, we assume that FlashAttention is used and
# thus the attention maps are never materialized in GPU DRAM.
residual = max_num_batched_tokens * self.hidden_size
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
ffn = 2 * (max_num_batched_tokens * self.ffn_size) // self.tensor_parallel_size
# Double the activation size for input and output.
max_act = 2 * (max(qkv, ffn) + residual)
# Size of output logits.
output_logits = 2 * (max_num_batched_tokens * self.vocab_size)
max_act = max(max_act, output_logits)
dtype_size = get_dtype_size(self.dtype)
return dtype_size * max_act
def get_cache_block_size(self) -> int:
key_cache_block = self.block_size * self.hidden_size // self.tensor_parallel_size
value_cache_block = key_cache_block
total = self.num_layers * (key_cache_block + value_cache_block)
dtype_size = get_dtype_size(self.dtype)
return dtype_size * total
def get_max_num_gpu_blocks(
self,
max_num_batched_tokens: int,
memory_utilization: float = 0.95,
) -> int:
# NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
gpu_memory = self.gpu_memory
usable_memory = int(memory_utilization * gpu_memory)
param_size = self._get_param_size()
act_size = self._get_max_act_size(max_num_batched_tokens)
workspace_size = self.get_workspace_size()
max_cache_size = usable_memory - (param_size + act_size + workspace_size)
if max_cache_size <= 0:
raise RuntimeError('Not enough GPU memory.')
max_num_blocks = max_cache_size // self.get_cache_block_size()
return max_num_blocks


@@ -1,72 +0,0 @@
from typing import Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoConfig
from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
from cacheflow.models.memory_analyzer import LlamaMemoryAnalyzer
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
from cacheflow.models.llama import LlamaForCausalLM
from cacheflow.models.opt import OPTForCausalLM
from cacheflow.models.utils import get_torch_dtype
_MODELS = {
'llama': LlamaForCausalLM,
'opt': OPTForCausalLM,
}
_MEMORY_ANALYZERS = {
'llama': LlamaMemoryAnalyzer,
'opt': OPTMemoryAnalyzer,
}
def get_model(
model_name: str,
dtype: Union[torch.dtype, str],
path: str,
use_dummy_weights: bool,
) -> Tuple[nn.Module, torch.dtype]:
torch_dtype = get_torch_dtype(dtype)
torch.set_default_dtype(torch_dtype)
config = AutoConfig.from_pretrained(model_name)
for model_class_name, model_class in _MODELS.items():
if model_class_name in model_name:
if use_dummy_weights:
# Create a model instance.
# The weights will be initialized as empty tensors.
model = model_class(config)
model = model.cuda()
# NOTE(woosuk): For precise performance evaluation, we assign
# random values to the weights.
model.initialize_dummy_weights()
else:
# Download model weights if it's not cached.
weights_dir = model_class.get_weights(model_name, path=path)
# Create a model instance.
model = model_class(config)
# Load the weights from the cached or downloaded files.
model.load_weights(weights_dir)
model = model.cuda()
return model.eval(), torch_dtype
raise ValueError(f'Unsupported model name: {model_name}')
def get_memory_analyzer(
model_name: str,
block_size: int,
dtype: Union[torch.dtype, str],
gpu_memory: int,
cpu_memory: int,
tensor_parallel_size: int = 1,
) -> CacheFlowMemoryAnalyzer:
torch_dtype = get_torch_dtype(dtype)
for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
if model_class in model_name:
return memory_analyzer(
model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
tensor_parallel_size)
raise ValueError(f'Unsupported model name: {model_name}')
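# Example usage (a sketch; requires a CUDA device, and the weights are
# random because use_dummy_weights=True):
#   model, torch_dtype = get_model(
#       'facebook/opt-125m', dtype='half',
#       path='~/.cacheflow/model_weights', use_dummy_weights=True)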


@@ -1,287 +0,0 @@
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
from cacheflow.models import InputMetadata
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceOutputs
from cacheflow.parallel_utils.tensor_parallel import gather_from_tensor_model_parallel_region
class Sampler(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(
self,
embedding: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> Dict[int, SequenceOutputs]:
# Get the hidden states that we use for sampling.
hidden_states = _prune_hidden_states(hidden_states, input_metadata)
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
logits = gather_from_tensor_model_parallel_region(logits)
# Apply temperature scaling.
temperatures = _get_temperatures(input_metadata)
assert len(temperatures) == logits.shape[0]
if any(t != 1.0 for t in temperatures):
t = torch.tensor(
temperatures, dtype=logits.dtype, device=logits.device)
# Use in-place division to avoid creating a new tensor.
logits.div_(t.unsqueeze(dim=1))
# We use float32 for probabilities and log probabilities.
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities (before applying top-p).
logprobs = torch.log(probs)
# Apply top-p truncation.
top_ps = _get_top_ps(input_metadata)
assert len(top_ps) == probs.shape[0]
if any(p < 1.0 for p in top_ps):
p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
probs = _apply_top_p(probs, p)
# Sample the next tokens.
return _sample(probs, logprobs, input_metadata)
def _prune_hidden_states(
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
    start_idx = 0
    last_token_indices: List[int] = []
    for prompt_len in input_metadata.prompt_lens:
        last_token_indices.append(start_idx + prompt_len - 1)
        start_idx += prompt_len
    last_token_indices.extend(
        range(start_idx, start_idx + input_metadata.num_generation_tokens))
    return hidden_states[last_token_indices]
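# Example: with prompt_lens = [3, 2] and num_generation_tokens = 2, the
# hidden states are laid out as [p0 p0 p0 p1 p1 g0 g1], so
# last_token_indices = [2, 4, 5, 6]: the last token of each prompt,
# followed by every generation token.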
def _get_temperatures(
input_metadata: InputMetadata,
) -> List[float]:
# Collect the temperatures for the logits.
temperatures: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
temperature = sampling_params.temperature
if temperature == 0.0:
# NOTE: Zero temperature means deterministic sampling
# (i.e., greedy sampling or beam search).
# Set the temperature to 1 to avoid division by zero.
temperature = 1.0
if i < input_metadata.num_prompts:
# A prompt input.
temperatures.append(temperature)
else:
# A generation token.
temperatures += [temperature] * len(seq_ids)
return temperatures
def _get_top_ps(
input_metadata: InputMetadata,
) -> List[float]:
top_ps: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
if i < input_metadata.num_prompts:
# A prompt input.
top_ps.append(sampling_params.top_p)
else:
# A generation token.
top_ps += [sampling_params.top_p] * len(seq_ids)
return top_ps
def _apply_top_p(
probs: torch.Tensor,
p: torch.Tensor,
) -> torch.Tensor:
# TODO(woosuk): Optimize.
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
probs = torch.gather(
probs_sort, dim=-1, index=torch.argsort(probs_idx, dim=-1))
return probs
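# Worked example: probs = [0.5, 0.3, 0.15, 0.05] with p = 0.7. The sorted
# cumulative sums are [0.5, 0.8, 0.95, 1.0]; (cumsum - prob) =
# [0.0, 0.5, 0.8, 0.95], so the mask zeroes the last two tokens.
# Renormalizing the survivors gives [0.625, 0.375, 0.0, 0.0].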
def _get_topk_logprobs(
logprobs: torch.Tensor,
num_logprobs: int,
) -> Dict[int, float]:
if num_logprobs == 0:
return {}
topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
if num_logprobs == 1:
topk_logprobs = [topk_logprobs.item()]
topk_ids = [topk_ids.item()]
else:
topk_logprobs = topk_logprobs.tolist()
topk_ids = topk_ids.tolist()
token_to_logprob: Dict[int, float] = {}
for token_id, logprob in zip(topk_ids, topk_logprobs):
token_to_logprob[token_id] = logprob
return token_to_logprob
def _sample_from_prompt(
prob: torch.Tensor,
sampling_params: SamplingParams,
) -> List[int]:
if sampling_params.use_beam_search:
# Beam search.
beam_width = sampling_params.n
_, next_token_ids = torch.topk(prob, beam_width)
next_token_ids = next_token_ids.tolist()
elif sampling_params.temperature == 0.0:
# Greedy sampling.
assert sampling_params.n == 1
next_token_id = torch.argmax(prob)
next_token_ids = [next_token_id.item()]
else:
        # Nucleus sampling.
# Sample n tokens for the prompt.
n = sampling_params.n
next_token_ids = torch.multinomial(
prob, num_samples=n, replacement=True)
next_token_ids = next_token_ids.tolist()
return next_token_ids
def _sample_from_generation_tokens(
seq_ids: List[int],
probs: torch.Tensor,
logprobs: torch.Tensor,
seq_logprobs: List[float],
sampling_params: SamplingParams,
) -> Tuple[List[int], List[int]]:
    # NOTE(woosuk): sampling_params.n can be greater than
    # len(seq_ids) because some sequences in the group might have
    # already been terminated.
if sampling_params.use_beam_search:
# Beam search.
# Add cumulative logprobs for the sequences in the group.
seq_logprobs = torch.tensor(
seq_logprobs, dtype=torch.float, device=logprobs.device)
logprobs = logprobs + seq_logprobs.unsqueeze(dim=1)
vocab_size = logprobs.size(-1)
beam_width = len(seq_ids)
_, topk_ids = torch.topk(logprobs.flatten(), beam_width)
topk_ids = topk_ids.tolist()
seq_idx = [i // vocab_size for i in topk_ids]
beam_seq_ids = [seq_ids[i] for i in seq_idx]
token_ids = [i % vocab_size for i in topk_ids]
beam_outputs: Dict[int, Tuple[int, int]] = {}
outstanding_beams: List[Tuple[int, int]] = []
# If a beam survives, continue with it.
for seq_id, token_id in zip(beam_seq_ids, token_ids):
if seq_id not in beam_outputs:
beam_outputs[seq_id] = (seq_id, token_id)
else:
outstanding_beams.append((seq_id, token_id))
# If a beam is discarded, fork another beam.
for seq_id in seq_ids:
if seq_id not in beam_outputs:
beam_outputs[seq_id] = outstanding_beams.pop()
assert not outstanding_beams
parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
elif sampling_params.temperature == 0.0:
# Greedy sampling.
assert len(seq_ids) == 1
next_token_id = torch.argmax(probs, dim=-1)
next_token_ids = [next_token_id.item()]
parent_seq_ids = seq_ids
else:
        # Nucleus sampling.
# Sample 1 token for each sequence in the group.
next_token_ids = torch.multinomial(
probs, num_samples=1, replacement=True)
next_token_ids = next_token_ids.squeeze(dim=-1).tolist()
parent_seq_ids = seq_ids
return parent_seq_ids, next_token_ids
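# Example of the beam bookkeeping above: with seq_ids = [0, 1] and both
# top-2 candidates descending from sequence 0, beam_outputs first maps
# seq 0 -> (0, token_a); the second candidate (0, token_b) becomes an
# outstanding beam and is handed to the discarded sequence 1, so
# parent_seq_ids = [0, 0] and sequence 1 is forked from sequence 0.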
def _sample(
probs: torch.Tensor,
logprobs: torch.Tensor,
input_metadata: InputMetadata,
) -> Dict[int, SequenceOutputs]:
seq_outputs: Dict[int, SequenceOutputs] = {}
# TODO(woosuk): Optimize.
idx = 0
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
if i < input_metadata.num_prompts:
# Generate the next tokens for a prompt input.
assert len(seq_ids) == sampling_params.n
prob = probs[idx]
logprob = logprobs[idx]
idx += 1
# Sample the next tokens.
next_token_ids = _sample_from_prompt(prob, sampling_params)
# Get top-k log probabilities for the next tokens.
next_logprobs = _get_topk_logprobs(
logprob, sampling_params.num_logprobs)
# Build the output.
for seq_id, next_token_id in zip(seq_ids, next_token_ids):
output_logprobs = next_logprobs.copy()
output_logprobs[next_token_id] = logprob[next_token_id].item()
seq_outputs[seq_id] = SequenceOutputs(
seq_id, seq_id, next_token_id, output_logprobs)
else:
# Generate the next tokens for generation tokens.
prob = probs[idx:idx + len(seq_ids)]
logprob = logprobs[idx:idx + len(seq_ids)]
idx += len(seq_ids)
# Sample the next tokens.
seq_logprobs = [
input_metadata.seq_logprobs[seq_id] for seq_id in seq_ids]
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
seq_ids, prob, logprob, seq_logprobs, sampling_params)
# Get top-k log probabilities for the next tokens.
next_logprobs: Dict[int, Dict[int, float]] = {}
for i, seq_id in enumerate(seq_ids):
next_logprobs[seq_id] = _get_topk_logprobs(
logprob[i], sampling_params.num_logprobs)
# Build the output.
for seq_id, parent_seq_id, next_token_id in zip(
seq_ids, parent_seq_ids, next_token_ids):
i = seq_ids.index(parent_seq_id)
output_logprobs = next_logprobs[parent_seq_id].copy()
output_logprobs[next_token_id] = logprob[i, next_token_id].item()
seq_outputs[seq_id] = SequenceOutputs(
seq_id,
parent_seq_id,
next_token_id,
output_logprobs,
)
return seq_outputs


@@ -1,24 +0,0 @@
from typing import Union
import torch
_STR_DTYPE_TO_TORCH_DTYPE = {
'half': torch.half,
'float': torch.float,
'float16': torch.float16,
'float32': torch.float32,
}
def get_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:
if isinstance(dtype, str):
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()]
else:
torch_dtype = dtype
return torch_dtype
def get_dtype_size(dtype: Union[torch.dtype, str]) -> int:
torch_dtype = get_torch_dtype(dtype)
return torch.tensor([], dtype=torch_dtype).element_size()
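# Example usage:
#   get_torch_dtype('half')       -> torch.float16
#   get_dtype_size('half')        -> 2
#   get_dtype_size(torch.float32) -> 4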


@@ -1,12 +0,0 @@
from cacheflow.parallel_utils import parallel_state
from cacheflow.parallel_utils import tensor_parallel
from cacheflow.parallel_utils import utils

# Alias parallel_state as mpu, its legacy name.
mpu = parallel_state
__all__ = [
"parallel_state",
"tensor_parallel",
"utils",
]


@@ -1,108 +0,0 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
from typing import List, Sequence
from cacheflow.parallel_utils.utils import divide
from cacheflow.parallel_utils import parallel_state
def split_tensor_along_last_dim(
tensor: torch.Tensor,
num_partitions: int,
contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
""" Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
Returns:
A list of Tensors
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# Note: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
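# Example: splitting a [4, 6] tensor into 3 partitions yields three
# [4, 2] views of the same storage; pass contiguous_split_chunks=True
# when downstream ops require contiguous memory.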
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
""" Break a tensor into equal 1D chunks across tensor parallel ranks.
Returns a Tensor or View with this rank's portion of the data.
Arguments:
tensor: The tensor to split
Keyword Arguments:
new_buffer (bool): If True, returns a new Tensor.
If False, returns a view into the existing Tensor.
Default is False
"""
partition_size = torch.numel(tensor) // \
parallel_state.get_tensor_model_parallel_world_size()
start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
end_index = start_index + partition_size
if new_buffer:
data = torch.empty(partition_size, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
data.copy_(tensor.view(-1)[start_index:end_index])
else:
data = tensor.view(-1)[start_index:end_index]
return data
def gather_split_1d_tensor(tensor):
""" Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor
model parallel ranks.
Returns a new Tensor with the gathered data.
Arguments:
tensor: A Tensor or view of this rank's portion of the data.
"""
numel_gathered = torch.numel(tensor) * \
parallel_state.get_tensor_model_parallel_world_size()
gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
# TODO: This API is experimental in pytorch (as of Feb 2022) and
# this might break in future pytorch releases. We chose this API
# as opposed to torch.distributed.all_gather for efficiency reasons.
# This API calls directly NCCL all-gather versus the former does
# internal copies and can potentially cause slow down.
torch.distributed._all_gather_base(gathered, tensor,
group=parallel_state.get_tensor_model_parallel_group())
return gathered
class VocabUtility:
""" Split the vocabulary into `world_size` chunks and return the first
and last index of the vocabulary belonging to the `rank`
partition: Note that indices in [first, last)
"""
@staticmethod
def vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size: int, rank, world_size: int
) -> Sequence[int]:
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
@staticmethod
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size, rank, world_size
)
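A worked example of the vocabulary partitioning above, assuming a vocabulary size that divides evenly across ranks (divide() asserts this): a 32000-token vocabulary split over 4 tensor-parallel ranks gives each rank a contiguous [first, last) range of 8000 ids:

for rank in range(4):
    first, last = VocabUtility.vocab_range_from_global_vocab_size(
        32000, rank, world_size=4)
    assert (first, last) == (rank * 8000, (rank + 1) * 8000)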

View File

@@ -1,120 +0,0 @@
"""Utility functions used throughout Megatron core"""
from functools import reduce
import operator
import torch
from cacheflow.parallel_utils import parallel_state
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(
numerator, denominator
)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
class GlobalMemoryBuffer:
"""Global buffer to avoid dynamic memory allocations.
Caller should ensure that buffers of the same name
are not used concurrently."""
def __init__(self):
self.buffer = {}
def get_tensor(self, tensor_shape, dtype, name):
required_len = reduce(operator.mul, tensor_shape, 1)
if self.buffer.get((name, dtype), None) is None or \
self.buffer[(name, dtype)].numel() < required_len:
self.buffer[(name, dtype)] = \
torch.empty(required_len,
dtype=dtype,
device=torch.cuda.current_device(),
requires_grad=False)
return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
def _kernel_make_viewless_tensor(inp, requires_grad):
'''Make a viewless tensor.
View tensors have the undesirable side-effect of retaining a reference
to the originally-viewed tensor, even after manually setting the '.data'
field. This method creates a new tensor that links to the old tensor's
data, without linking the viewed tensor, referenced via the '._base'
field.
'''
out = torch.empty(
(1,),
dtype = inp.dtype,
device = inp.device,
requires_grad = requires_grad,
)
out.data = inp.data
return out
class MakeViewlessTensor(torch.autograd.Function):
'''
Autograd function to make a viewless tensor.
This function should be used in cases where the computation graph needs
to be propagated, but we only want a viewless tensor (e.g.,
ParallelTransformer's hidden_states). Call this function by passing
'keep_graph = True' to 'make_viewless_tensor()'.
'''
@staticmethod
def forward(ctx, inp, requires_grad):
return _kernel_make_viewless_tensor(inp, requires_grad)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
def make_viewless_tensor(inp, requires_grad, keep_graph):
'''
Entry-point for creating viewless tensors.
This method should be used, rather than calling 'MakeViewlessTensor'
or '_kernel_make_viewless_tensor' directly. This method acts as a
switch for determining if an autograd function or a regular method
should be used to create the tensor.
'''
# return tensor as-is, if not a 'view'
if inp._base is None:
return inp
# create viewless tensor
if keep_graph:
return MakeViewlessTensor.apply(inp, requires_grad)
else:
return _kernel_make_viewless_tensor(inp, requires_grad)
def assert_viewless_tensor(tensor, extra_msg = None):
'''Assert that a tensor is not a view (i.e., its '._base' field is
not set).'''
if isinstance(tensor, list):
[ assert_viewless_tensor(t) for t in tensor ]
return tensor
if not isinstance(tensor, torch.Tensor):
return tensor
assert tensor._base is None, (
"Ensure tensor._base is None before setting tensor.data or storing "
"tensor to memory buffer. Otherwise, a memory leak will occur (and "
"likely accumulate over iterations). %s"
) % extra_msg
return tensor
def safely_set_viewless_tensor_data(tensor, new_data_tensor):
'''Safely set tensor's '.data' field.
Check first that the tensor is viewless (i.e., '._base' not set). If not,
raise an exception.
'''
assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
tensor.data = new_data_tensor
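The '._base' convention these helpers rely on can be checked directly in PyTorch; a small standalone illustration:

import torch

base = torch.arange(12)
view = base.view(3, 4)[0]      # views keep a reference to the original storage
assert view._base is not None  # so assert_viewless_tensor would reject this
fresh = view.clone()           # a clone owns its storage
assert fresh._base is None     # and passes the viewless check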

View File

@@ -1,84 +0,0 @@
from typing import Optional, Set, Dict
class SamplingParams:
def __init__(
self,
n: int,
temperature: float,
top_p: float,
use_beam_search: bool,
stop_token_ids: Set[int],
max_num_steps: int,
num_logprobs: int,
context_window_size: Optional[int],
) -> None:
if n < 1:
raise ValueError(f'n must be at least 1, got {n}.')
if temperature < 0.0:
raise ValueError(
f'temperature must be non-negative, got {temperature}.')
if not 0.0 < top_p <= 1.0:
raise ValueError(f'top_p must be in (0, 1], got {top_p}.')
if max_num_steps < 1:
raise ValueError(
f'max_num_steps must be at least 1, got {max_num_steps}.')
if num_logprobs < 0:
raise ValueError(
f'num_logprobs must be non-negative, got {num_logprobs}.')
if context_window_size is not None and context_window_size < 0:
raise ValueError(
'context_window_size must be non-negative, '
f'got {context_window_size}.')
if use_beam_search:
if n == 1:
raise ValueError(
'n must be greater than 1 when using beam search.')
if temperature > 0.0:
raise ValueError(
'temperature must be 0 when using beam search.')
if top_p < 1.0:
raise ValueError(
'top_p must be 1 when using beam search.')
elif temperature == 0.0:
# Zero temperature means greedy sampling.
if n > 1:
raise ValueError(
'n must be 1 when using greedy sampling.')
if top_p < 1.0:
raise ValueError(
'top_p must be 1 when using greedy sampling.')
self.n = n
self.temperature = temperature
self.top_p = top_p
self.use_beam_search = use_beam_search
self.stop_token_ids = stop_token_ids
self.max_num_steps = max_num_steps
self.num_logprobs = num_logprobs
self.context_window_size = context_window_size
def __repr__(self) -> str:
return (f'SamplingParams(n={self.n}, '
f'temperature={self.temperature}, '
f'top_p={self.top_p}, '
f'use_beam_search={self.use_beam_search}, '
f'stop_token_ids={self.stop_token_ids}, '
f'max_num_steps={self.max_num_steps}, '
f'num_logprobs={self.num_logprobs}, '
f'context_window_size={self.context_window_size})')
@classmethod
def from_dict(cls, d: Dict) -> 'SamplingParams':
return cls(
n=d.get('n', 1),
temperature=d.get('temperature', 1.0),
top_p=d.get('top_p', 1.0),
use_beam_search=d.get('use_beam_search', False),
stop_token_ids=set(d.get('stop_token_ids', set())),
max_num_steps=d.get('max_num_steps', 16),
num_logprobs=d.get('num_logprobs', 0),
context_window_size=d.get('context_window_size', None),
)
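A usage sketch for the constructor and from_dict above; the beam-search dict is valid because use_beam_search requires n > 1, temperature == 0, and top_p == 1:

params = SamplingParams.from_dict({'n': 4, 'temperature': 0.8, 'top_p': 0.95})
assert params.max_num_steps == 16   # unspecified fields fall back to defaults

beam = SamplingParams.from_dict(
    {'n': 4, 'temperature': 0.0, 'use_beam_search': True})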

View File

@@ -1,169 +0,0 @@
import copy
import enum
from typing import Dict, List, Optional
from cacheflow.block import LogicalTokenBlock
from cacheflow.sampling_params import SamplingParams
class SequenceStatus(enum.Enum):
WAITING = enum.auto()
RUNNING = enum.auto()
SWAPPED = enum.auto()
FINISHED = enum.auto()
class Sequence:
def __init__(
self,
seq_id: int,
token_ids: List[int],
block_size: int,
) -> None:
self.seq_id = seq_id
self.block_size = block_size
self.logical_token_blocks: List[LogicalTokenBlock] = []
# Initialize the logical token blocks with the given token ids.
self.add(token_ids)
self.prompt_len = len(token_ids)
self.status = SequenceStatus.WAITING
self.output_logprobs: List[Dict[int, float]] = []
self.cumulative_logprobs = 0.0
def add_block(self) -> None:
block = LogicalTokenBlock(
block_number=len(self.logical_token_blocks),
block_size=self.block_size,
)
self.logical_token_blocks.append(block)
def add(self, token_ids: List[int]) -> None:
while token_ids:
if not self.logical_token_blocks:
self.add_block()
last_block = self.logical_token_blocks[-1]
if last_block.is_full():
self.add_block()
last_block = self.logical_token_blocks[-1]
num_empty_slots = last_block.get_num_empty_slots()
last_block.append(token_ids[:num_empty_slots])
token_ids = token_ids[num_empty_slots:]
def append(self, token_id: int, logprobs: Dict[int, float]) -> None:
assert token_id in logprobs
self.add([token_id])
self.output_logprobs.append(logprobs)
self.cumulative_logprobs += logprobs[token_id]
def get_len(self) -> int:
return sum(block.num_tokens for block in self.logical_token_blocks)
def get_token_ids(self) -> List[int]:
token_ids: List[int] = []
for block in self.logical_token_blocks:
token_ids.extend(block.get_token_ids())
return token_ids
def get_last_token_id(self) -> int:
return self.logical_token_blocks[-1].get_last_token_id()
def fork(self, child_seq: 'Sequence') -> None:
child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
child_seq.cumulative_logprobs = self.cumulative_logprobs
def __repr__(self) -> str:
return (f'Sequence(seq_id={self.seq_id}, '
f'status={self.status.name}, '
f'num_blocks={len(self.logical_token_blocks)})')
class SequenceGroup:
def __init__(
self,
group_id: int,
seqs: List[Sequence],
arrival_time: float,
) -> None:
self.group_id = group_id
self.seqs = seqs
self.arrival_time = arrival_time
def get_seqs(
self,
status: Optional[SequenceStatus] = None,
) -> List[Sequence]:
if status is None:
return self.seqs
else:
return [seq for seq in self.seqs if seq.status == status]
def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
return len(self.get_seqs(status))
def find(self, seq_id: int) -> Sequence:
for seq in self.seqs:
if seq.seq_id == seq_id:
return seq
raise ValueError(f'Sequence {seq_id} not found.')
def is_finished(self) -> bool:
return all(seq.status == SequenceStatus.FINISHED for seq in self.seqs)
def __repr__(self) -> str:
return (f'SequenceGroup(group_id={self.group_id}, '
f'num_seqs={len(self.seqs)})')
class SequenceGroupInputs:
def __init__(
self,
group_id: int,
is_prompt: bool,
input_tokens: Dict[int, List[int]], # Seq id -> token ids.
context_len: int,
seq_logprobs: Dict[int, float], # Seq id -> cumulative logprobs.
sampling_params: SamplingParams,
block_tables: Dict[int, List[int]], # Seq id -> List of physical block numbers.
) -> None:
self.group_id = group_id
self.is_prompt = is_prompt
self.input_tokens = input_tokens
self.context_len = context_len
self.seq_logprobs = seq_logprobs
self.sampling_params = sampling_params
self.block_tables = block_tables
class SequenceOutputs:
def __init__(
self,
seq_id: int,
parent_seq_id: int,
output_token: int,
logprobs: Dict[int, float], # Token id -> logP(x_i+1 | x_0, ..., x_i).
) -> None:
self.seq_id = seq_id
self.parent_seq_id = parent_seq_id
self.output_token = output_token
self.logprobs = logprobs
def __repr__(self) -> str:
return (f'SequenceOutputs(seq_id={self.seq_id}, '
f'parent_seq_id={self.parent_seq_id}, '
f'output_token={self.output_token}, '
f'logprobs={self.logprobs})')
def __eq__(self, other: 'SequenceOutputs') -> bool:
return (self.seq_id == other.seq_id and
self.parent_seq_id == other.parent_seq_id and
self.output_token == other.output_token and
self.logprobs == other.logprobs)
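A short sketch of the block accounting above (assuming cacheflow.block.LogicalTokenBlock is importable): ten tokens with block_size=4 occupy ceil(10 / 4) = 3 logical blocks, and append() fills the next free slot:

seq = Sequence(seq_id=0, token_ids=list(range(10)), block_size=4)
assert len(seq.logical_token_blocks) == 3
assert seq.get_len() == 10
seq.append(99, {99: -0.25})    # logprob dict must contain the appended token
assert seq.get_last_token_id() == 99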

View File

@@ -1,47 +0,0 @@
import enum
import random
import psutil
import numpy as np
import torch
from cacheflow.parallel_utils.parallel_state import model_parallel_is_initialized
from cacheflow.parallel_utils.tensor_parallel import model_parallel_cuda_manual_seed
class Device(enum.Enum):
GPU = enum.auto()
CPU = enum.auto()
class Counter:
def __init__(self, start: int = 0) -> None:
self.counter = start
def __next__(self) -> int:
id = self.counter
self.counter += 1
return id
def reset(self) -> None:
self.counter = 0
def set_random_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
if model_parallel_is_initialized():
model_parallel_cuda_manual_seed(seed)
def get_gpu_memory(gpu: int = 0) -> int:
return torch.cuda.get_device_properties(gpu).total_memory
def get_cpu_memory() -> int:
return psutil.virtual_memory().total

View File

@@ -1,101 +0,0 @@
from typing import Dict, List, Union, Tuple
import ray
from cacheflow.master.scheduler import Scheduler
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.worker.worker import Worker
DeviceID = Tuple[int, str, int] # rank, node resource (node IP), device id
class Controller:
def __init__(
self,
stage_id: int,
stage_devices: List[DeviceID],
world_size: int,
tensor_parallel_size: int,
pipeline_parallel_size: int,
distributed_init_method: str,
model_name: str,
block_size: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
dtype: str,
seed: int,
model_path: str,
use_dummy_weights: bool,
max_num_batched_tokens: int,
) -> None:
self.stage_id = stage_id
self.stage_devices = stage_devices
self.model_name = model_name
self.block_size = block_size
self.num_gpu_blocks = num_gpu_blocks
self.num_cpu_blocks = num_cpu_blocks
# Which pipeline stage is this node assigned to?
self.is_first_stage = stage_id == 0
self.is_last_stage = False
self.workers: List[Worker] = []
for rank, node_resource, device_id in stage_devices:
worker_cls = ray.remote(num_cpus=0,
num_gpus=1,
resources={node_resource: 1e-5})(Worker)
worker = worker_cls.remote(
model_name=model_name,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
dtype=dtype,
seed=seed,
distributed_init_method=distributed_init_method,
rank=rank,
world_size=world_size,
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
model_path=model_path,
use_dummy_weights=use_dummy_weights,
max_num_batched_tokens=max_num_batched_tokens,
)
self.workers.append(worker)
def set_next(
self,
next_node: Union['Controller', 'Scheduler'],
) -> None:
self.next_node = next_node
self.is_last_stage = isinstance(next_node, Scheduler)
def execute_stage(
self,
input_seq_groups: List[SequenceGroupInputs],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> None:
futures = []
for worker in self.workers:
future = worker.execute_stage.remote(
input_seq_groups,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
)
futures.append(future)
all_outputs = ray.get(futures)
# Make sure all workers have the same results.
output = all_outputs[0]
for other_output in all_outputs[1:]:
assert output == other_output
if self.is_last_stage:
self.next_node.post_step(output)
else:
# TODO: Support pipeline parallelism.
assert False
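execute_stage is a plain scatter/gather over Ray actors: fan the call out with .remote(), block on ray.get, and check that the replicas agree. A stripped-down sketch with a hypothetical EchoWorker actor (assumes a local Ray runtime):

import ray

ray.init()

@ray.remote
class EchoWorker:
    # Hypothetical stand-in for the real Worker actor.
    def execute_stage(self, payload):
        return payload

workers = [EchoWorker.remote() for _ in range(2)]
futures = [w.execute_stage.remote({'step': 1}) for w in workers]  # scatter
all_outputs = ray.get(futures)                                    # gather
assert all(out == all_outputs[0] for out in all_outputs)          # replicas must agree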

View File

@@ -1,264 +0,0 @@
from typing import Dict, List, Tuple
import torch
from cacheflow.models import get_model
from cacheflow.models import InputMetadata
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import SequenceGroupInputs
from cacheflow.sequence import SequenceOutputs
from cacheflow.worker.cache_engine import CacheEngine
from cacheflow.parallel_utils.parallel_state import (
initialize_model_parallel,
initialize_all_reduce_launcher,
get_tensor_model_parallel_world_size)
from cacheflow.utils import set_random_seed
class Worker:
def __init__(
self,
model_name: str,
block_size: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
dtype: str,
seed: int,
distributed_init_method: str,
rank: int,
world_size: int,
model_path: str,
use_dummy_weights: bool,
max_num_batched_tokens: int,
tensor_parallel_size: int = 1,
pipeline_parallel_size: int = 1,
) -> None:
self.init_distributed_environment(distributed_init_method,
rank,
world_size,
tensor_parallel_size,
pipeline_parallel_size)
self.worker_id = rank
self.block_size = block_size
set_random_seed(seed)
# Initialize the model.
self.model, self.dtype = get_model(
model_name, dtype=dtype, path=model_path, use_dummy_weights=use_dummy_weights)
tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
initialize_all_reduce_launcher(
max_num_batched_tokens, self.model.config.hidden_size, self.dtype)
self.num_layers = self.model.config.num_hidden_layers
assert self.model.config.num_attention_heads % tensor_model_parallel_world_size == 0
self.num_heads = self.model.config.num_attention_heads // tensor_model_parallel_world_size
self.head_size = self.model.config.hidden_size // (self.num_heads * tensor_model_parallel_world_size)
# We reset the seed after initializing the model to ensure that
# the random state is not affected by the model initialization.
set_random_seed(seed)
self.cache_engine = CacheEngine(
worker_id=self.worker_id,
num_layers=self.num_layers,
num_heads=self.num_heads,
head_size=self.head_size,
block_size=block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
dtype=self.dtype,
)
self.cache_events = self.cache_engine.events
self.gpu_cache = self.cache_engine.gpu_cache
def init_distributed_environment(self,
distributed_init_method: str,
rank: int,
world_size: int,
tensor_parallel_size: int = 1,
pipeline_parallel_size: int = 1) -> None:
"""Initialize the distributed environment."""
torch.distributed.init_process_group(
backend='nccl',
init_method=distributed_init_method,
world_size=world_size,
rank=rank,
)
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1).cuda())
initialize_model_parallel(tensor_parallel_size,
pipeline_parallel_size)
def prepare_inputs(
self,
input_seq_groups: List[SequenceGroupInputs],
) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
seq_logprobs: Dict[int, float] = {}
input_tokens: List[int] = []
input_positions: List[int] = []
slot_mapping: List[int] = []
# Add prompt tokens.
prompt_lens: List[int] = []
for input_seq_group in input_seq_groups:
if not input_seq_group.is_prompt:
continue
seq_ids = list(input_seq_group.input_tokens.keys())
sampling_params = input_seq_group.sampling_params
seq_groups.append((seq_ids, sampling_params))
seq_logprobs.update(input_seq_group.seq_logprobs)
# Use any sequence in the group.
seq_id = seq_ids[0]
prompt_tokens = input_seq_group.input_tokens[seq_id]
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
input_tokens.extend(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.extend(range(len(prompt_tokens)))
# Compute the slot mapping.
block_table = input_seq_group.block_tables[seq_id]
for i in range(prompt_len):
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
cumulative_prompt_lens: List[int] = [0]
for prompt_len in prompt_lens:
cumulative_prompt_lens.append(
cumulative_prompt_lens[-1] + prompt_len)
# Add generation tokens.
max_context_len = 0
max_num_blocks_per_seq = 0
context_lens: List[int] = []
generation_block_tables: List[List[int]] = []
for input_seq_group in input_seq_groups:
if input_seq_group.is_prompt:
continue
seq_ids = list(input_seq_group.input_tokens.keys())
sampling_params = input_seq_group.sampling_params
seq_groups.append((seq_ids, sampling_params))
seq_logprobs.update(input_seq_group.seq_logprobs)
for seq_id in seq_ids:
assert len(input_seq_group.input_tokens[seq_id]) == 1
generation_token = input_seq_group.input_tokens[seq_id][0]
input_tokens.append(generation_token)
position = input_seq_group.context_len - 1
input_positions.append(position)
block_table = input_seq_group.block_tables[seq_id]
generation_block_tables.append(block_table)
max_context_len = max(
max_context_len, input_seq_group.context_len)
max_num_blocks_per_seq = max(
max_num_blocks_per_seq, len(block_table))
context_lens.append(input_seq_group.context_len)
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
# Optimization: Pad the input length to be a multiple of 8.
# This is required for utilizing the Tensor Cores in NVIDIA GPUs.
input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
input_positions = _pad_to_alignment(input_positions, multiple_of=8)
# Convert to tensors.
tokens_tensor = torch.tensor(
input_tokens, dtype=torch.long, device='cuda')
positions_tensor = torch.tensor(
input_positions, dtype=torch.long, device='cuda')
slot_mapping_tensor = torch.tensor(
slot_mapping, dtype=torch.int, device='cuda')
context_lens_tensor = torch.tensor(
context_lens, dtype=torch.int, device='cuda')
padded_block_tables = [
_pad_to_max(block_table, max_num_blocks_per_seq)
for block_table in generation_block_tables]
block_tables_tensor = torch.tensor(
padded_block_tables, dtype=torch.int, device='cuda')
cumulative_prompt_lens_tensor = torch.tensor(
cumulative_prompt_lens, dtype=torch.int, device='cuda')
input_metadata = InputMetadata(
seq_groups=seq_groups,
seq_logprobs=seq_logprobs,
prompt_lens=prompt_lens,
cumulative_prompt_lens=cumulative_prompt_lens_tensor,
slot_mapping=slot_mapping_tensor,
context_lens=context_lens_tensor,
max_context_len=max_context_len,
block_tables=block_tables_tensor,
)
return tokens_tensor, positions_tensor, input_metadata
@torch.inference_mode()
def execute_stage(
self,
input_seq_groups: List[SequenceGroupInputs],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> Dict[int, SequenceOutputs]:
# Issue cache operations.
command_issued = False
if blocks_to_swap_in:
self.cache_engine.swap_in(blocks_to_swap_in)
command_issued = True
if blocks_to_swap_out:
self.cache_engine.swap_out(blocks_to_swap_out)
command_issued = True
if blocks_to_copy:
self.cache_engine.copy(blocks_to_copy)
command_issued = True
if command_issued:
cache_events = self.cache_events
else:
cache_events = None
# If there is no input, we don't need to execute the model.
if not input_seq_groups:
if cache_events is not None:
for event in cache_events:
event.wait()
return {}
# Prepare input tensors.
input_tokens, input_positions, input_metadata = self.prepare_inputs(
input_seq_groups)
# Execute the model.
output = self.model(
input_ids=input_tokens,
positions=input_positions,
kv_caches=self.gpu_cache,
input_metadata=input_metadata,
cache_events=cache_events,
)
return output
def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
return x + [0] * ((-len(x)) % multiple_of)
def _pad_to_max(x: List[int], max_len: int) -> List[int]:
return x + [0] * (max_len - len(x))
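The slot-mapping arithmetic in prepare_inputs maps a token's logical position to a flat slot in the paged KV cache: slot = physical_block_number * block_size + offset_within_block. A standalone sketch of the same computation (compute_slot_mapping is a name introduced here for illustration):

def compute_slot_mapping(block_table, num_tokens, block_size):
    return [block_table[i // block_size] * block_size + i % block_size
            for i in range(num_tokens)]

# A 5-token prompt with block_size=2 spread over physical blocks 7, 3, 9:
assert compute_slot_mapping([7, 3, 9], 5, 2) == [14, 15, 6, 7, 18]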

View File

@@ -4,9 +4,25 @@ void silu_and_mul(
torch::Tensor& out,
torch::Tensor& input);
void gelu_new(
torch::Tensor& out,
torch::Tensor& input);
void gelu_fast(
torch::Tensor& out,
torch::Tensor& input);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"silu_and_mul",
&silu_and_mul,
"Activation function used in SwiGLU.");
m.def(
"gelu_new",
&gelu_new,
"GELU implementation used in GPT-2.");
m.def(
"gelu_fast",
&gelu_fast,
"Approximate GELU implementation.");
}

View File

@@ -1,7 +1,9 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
namespace cacheflow {
#include "dispatch_utils.h"
namespace vllm {
template<typename T>
__device__ __forceinline__ T silu(const T& x) {
@@ -22,7 +24,7 @@ __global__ void silu_and_mul_kernel(
}
}
} // namespace cacheflow
} // namespace vllm
void silu_and_mul(
torch::Tensor& out, // [num_tokens, d]
@@ -34,13 +36,79 @@ void silu_and_mul(
dim3 grid(num_tokens);
dim3 block(std::min(d, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"silu_and_mul_kernel",
[&] {
cacheflow::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
d);
});
}
namespace vllm {
// Element-wise activation kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
scalar_t* __restrict__ out, // [num_tokens, d]
const scalar_t* __restrict__ input, // [num_tokens, d]
const int d) {
const int token_idx = blockIdx.x;
for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx * d + idx]);
out[token_idx * d + idx] = ACT_FN(x);
}
}
} // namespace vllm
// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int num_tokens = input.size(0); \
int d = input.size(1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
"activation_kernel", \
[&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
});
namespace vllm {
template<typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
const float x3 = (float) (x * x * x);
const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
return ((T) 0.5) * x * (((T) 1.0) + t);
}
template<typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
const float f = (float) x;
const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
return ((T) 0.5) * x * (((T) 1.0) + t);
}
} // namespace vllm
void gelu_new(
torch::Tensor& out, // [num_tokens, d]
torch::Tensor& input) // [num_tokens, d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
}
void gelu_fast(
torch::Tensor& out, // [num_tokens, d]
torch::Tensor& input) // [num_tokens, d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
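Both kernels implement the tanh-based GELU approximation; gelu_fast merely reassociates the same polynomial, so the two agree up to floating-point rounding. An fp64 reference in Python for checking the formulas:

import math

def gelu_new_ref(x: float) -> float:
    return 0.5 * x * (1.0 + math.tanh(0.79788456 * (x + 0.044715 * x ** 3)))

def gelu_fast_ref(x: float) -> float:
    return 0.5 * x * (1.0 + math.tanh(x * 0.79788456 * (1.0 + 0.044715 * x * x)))

assert abs(gelu_new_ref(1.5) - gelu_fast_ref(1.5)) < 1e-12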

View File

@@ -1,15 +1,18 @@
#include <torch/extension.h>
#include <c10/util/Optional.h>
void single_query_cached_kv_attention(
torch::Tensor& out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& head_mapping,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int block_size,
int max_context_len);
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(

View File

@@ -0,0 +1,6 @@
#pragma once
#include "attention_generic.cuh"
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"
#include "dtype_bfloat16.cuh"

View File

@@ -0,0 +1,64 @@
/*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <stdint.h>
namespace vllm {
// A vector type to store Q, K, V elements.
template<typename T, int VEC_SIZE>
struct Vec {};
// A vector type to store FP32 accumulators.
template<typename T>
struct FloatVec {};
// Template vector operations.
template<typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b);
template<typename T>
inline __device__ float sum(T v);
template<typename T>
inline __device__ float dot(T a, T b) {
return sum(mul<T, T, T>(a, b));
}
template<typename A, typename T>
inline __device__ float dot(T a, T b) {
return sum(mul<A, T, T>(a, b));
}
template<typename T>
inline __device__ void zero(T& dst) {
constexpr int WORDS = sizeof(T) / 4;
union {
T raw;
uint32_t words[WORDS];
} tmp;
#pragma unroll
for (int ii = 0; ii < WORDS; ++ii) {
tmp.words[ii] = 0u;
}
dst = tmp.raw;
}
} // namespace vllm

View File

@@ -0,0 +1,519 @@
/*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include <algorithm>
#define WARP_SIZE 32
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
namespace vllm {
// Utility function for attention softmax.
template<int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
// Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
// Compute the sum per warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}
// Warp leaders store the data to shared memory.
if (lane == 0) {
red_smem[warp] = sum;
}
// Make sure the data is in shared memory.
__syncthreads();
// The warps compute the final sums.
if (lane < NUM_WARPS) {
sum = red_smem[lane];
}
// Parallel reduction inside the warp.
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}
// Broadcast to other threads.
return __shfl_sync(uint32_t(-1), sum, 0);
}
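The shuffle pattern above is a butterfly (XOR) reduction; a Python model of the __shfl_xor_sync stages makes the data movement explicit (lane count assumed to be a power of two):

def butterfly_sum(lane_vals):
    vals = list(lane_vals)
    mask = len(vals) // 2
    while mask >= 1:
        # each lane adds the value held by its XOR partner
        vals = [vals[i] + vals[i ^ mask] for i in range(len(vals))]
        mask //= 2
    return vals[0]  # every lane ends up holding the full sum

assert butterfly_sum(range(32)) == sum(range(32))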
// Grid: (num_heads, num_seqs).
template<
typename scalar_t,
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS>
__global__ void single_query_cached_kv_attention_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
const int* __restrict__ head_mapping, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride) {
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x;
const int warp_idx = thread_idx / WARP_SIZE;
const int lane = thread_idx % WARP_SIZE;
const int head_idx = blockIdx.x;
const int num_heads = gridDim.x;
const int kv_head_idx = head_mapping[head_idx];
const int seq_idx = blockIdx.y;
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread group
// fetch or compute 16 bytes at a time.
// For example, if the size of a thread group is 4 and the data type is half,
// then the vector size is 16 / (4 * sizeof(half)) == 2.
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
// Load the query to registers.
// Each thread in a thread group has a different part of the query.
// For example, if the thread group size is 4, then the first thread in the group
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
// th vectors of the query, and so on.
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
// Memory planning.
extern __shared__ char shared_mem[];
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
float* logits = reinterpret_cast<float*>(shared_mem);
// Workspace for reduction.
__shared__ float red_smem[2 * NUM_WARPS];
// x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time.
constexpr int x = 16 / sizeof(scalar_t);
float qk_max = -FLT_MAX;
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
const int context_len = context_lens[seq_idx];
const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
// Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes
// dot product with the query.
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
const int physical_block_number = block_table[block_idx];
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// For example, if the thread group size is 4, then the first thread in the group
// has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
// vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride
+ physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x;
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
}
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= context_len;
logits[token_idx] = mask ? 0.f : qk;
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}
}
}
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
if (lane == 0) {
red_smem[warp_idx] = qk_max;
}
__syncthreads();
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
// Broadcast the max qk value to all threads.
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
// Get the sum of the exp values.
float exp_sum = 0.f;
for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
float val = __expf(logits[i] - qk_max);
logits[i] = val;
exp_sum += val;
}
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
// Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
logits[i] *= inv_sum;
}
__syncthreads();
// Each thread will fetch 16 bytes from the value cache at a time.
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using Float_L_vec = typename FloatVec<L_vec>::Type;
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER;
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[i] = 0.f;
}
scalar_t zero_value;
zero(zero_value);
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
const int physical_block_number = block_table[block_idx];
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx));
const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+ kv_head_idx * kv_head_stride;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
if (block_idx == num_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the context,
// we should explicitly zero out the values since they may contain NaNs.
// See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) {
v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
}
}
accs[i] += dot(logits_vec, v_vec);
}
}
}
// Perform reduction within each warp.
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[i];
#pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
}
accs[i] = acc;
}
// NOTE(woosuk): A barrier is required because the shared memory space for logits
// is reused for the output.
__syncthreads();
// Perform reduction across warps.
float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
dst[row_idx] = accs[i];
}
}
}
__syncthreads();
// Lower warps update the output.
if (warp_idx < mid) {
const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
accs[i] += src[row_idx];
}
}
}
__syncthreads();
}
// Write the final output.
if (warp_idx == 0) {
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
from_float(*(out_ptr + row_idx), accs[i]);
}
}
}
}
} // namespace vllm
#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
<<<grid, block, shared_mem_size, stream>>>( \
out_ptr, \
query_ptr, \
key_cache_ptr, \
value_cache_ptr, \
head_mapping_ptr, \
scale, \
block_tables_ptr, \
context_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
q_stride, \
kv_block_stride, \
kv_head_stride);
// TODO(woosuk): Tune NUM_THREADS.
template<
typename T,
int BLOCK_SIZE,
int NUM_THREADS = 128>
void single_query_cached_kv_attention_launcher(
torch::Tensor& out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& head_mapping,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr = alibi_slopes ?
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_context_len * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
int shared_mem_size = std::max(logits_size, outputs_size);
dim3 grid(num_heads, num_seqs);
dim3 block(NUM_THREADS);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (head_size) {
// NOTE(woosuk): To reduce the compilation time, we omitted head sizes
// 32, 160, 192.
// case 32:
// LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
// break;
case 64:
LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
break;
case 80:
LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS);
break;
case 96:
LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
break;
case 112:
LAUNCH_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS);
break;
case 128:
LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
break;
// case 160:
// LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
// break;
// case 192:
// LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
// break;
case 256:
LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
break;
default:
TORCH_CHECK(false, "Unsupported head size: ", head_size);
break;
}
}
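The shared-memory request is the larger of the softmax logits buffer and the cross-warp output workspace; a worked example of the sizing arithmetic above, using assumed values:

# BLOCK_SIZE=16, NUM_THREADS=128 (4 warps), head_size=128, max_context_len=1000
padded_max_context_len = ((1000 + 16 - 1) // 16) * 16  # 1008 tokens
logits_size = padded_max_context_len * 4               # 4032 bytes of fp32 logits
outputs_size = (4 // 2) * 128 * 4                      # 1024 bytes of fp32 outputs
shared_mem_size = max(logits_size, outputs_size)       # 4032 bytes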
#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
single_query_cached_kv_attention_launcher<T, BLOCK_SIZE>( \
out, \
query, \
key_cache, \
value_cache, \
head_mapping, \
scale, \
block_tables, \
context_lens, \
max_context_len, \
alibi_slopes);
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
switch (block_size) { \
/* case 1: */ \
/* CALL_KERNEL_LAUNCHER(T, 1); */ \
/* break; */ \
/* case 2: */ \
/* CALL_KERNEL_LAUNCHER(T, 2); */ \
/* break; */ \
/* case 4: */ \
/* CALL_KERNEL_LAUNCHER(T, 4); */ \
/* break; */ \
case 8: \
CALL_KERNEL_LAUNCHER(T, 8); \
break; \
case 16: \
CALL_KERNEL_LAUNCHER(T, 16); \
break; \
case 32: \
CALL_KERNEL_LAUNCHER(T, 32); \
break; \
/* case 64: */ \
/* CALL_KERNEL_LAUNCHER(T, 64); */ \
/* break; */ \
/* case 128: */ \
/* CALL_KERNEL_LAUNCHER(T, 128); */ \
/* break; */ \
/* case 256: */ \
/* CALL_KERNEL_LAUNCHER(T, 256); */ \
/* break; */ \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
void single_query_cached_kv_attention(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& head_mapping, // [num_heads]
float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs]
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes) {
if (query.dtype() == at::ScalarType::Float) {
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float);
} else if (query.dtype() == at::ScalarType::Half) {
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t);
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
}
#undef WARP_SIZE
#undef MAX
#undef MIN
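The kernel's softmax is the standard max-subtraction recipe, with a 1e-6 guard in the denominator so a fully masked row cannot divide by zero; the same computation in plain Python:

import math

def stable_softmax(logits):
    qk_max = max(logits)                 # block-wide max, as in the kernel
    exps = [math.exp(v - qk_max) for v in logits]
    inv_sum = 1.0 / (sum(exps) + 1e-6)   # guard against an all-zero row
    return [v * inv_sum for v in exps]

probs = stable_softmax([3.0, 1.0, 0.5])
assert abs(sum(probs) - 1.0) < 1e-5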

View File

@@ -0,0 +1,55 @@
/*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "attention_dtypes.h"
#include <float.h>
#include <type_traits>
namespace vllm {
// Q*K^T operation.
template<int THREAD_GROUP_SIZE, typename Vec, int N>
inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
using A_vec = typename FloatVec<Vec>::Type;
// Compute the parallel products for Q*K^T (treat vector lanes separately).
A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec);
}
// Finalize the reduction across lanes.
float qk = sum(qk_vec);
#pragma unroll
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
}
return qk;
}
template<typename T, int THREAD_GROUP_SIZE>
struct Qk_dot {
template<typename Vec, int N>
static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
return qk_dot_<THREAD_GROUP_SIZE>(q, k);
}
};
} // namespace vllm

View File

@@ -0,0 +1,433 @@
/*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "attention_generic.cuh"
#include "dtype_float32.cuh"
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <stdint.h>
namespace vllm {
// Define custom BF16 vector data types.
struct bf16_4_t {
__nv_bfloat162 x;
__nv_bfloat162 y;
};
struct bf16_8_t {
__nv_bfloat162 x;
__nv_bfloat162 y;
__nv_bfloat162 z;
__nv_bfloat162 w;
};
// BF16 vector types for Q, K, V.
template<>
struct Vec<__nv_bfloat16, 1> {
using Type = __nv_bfloat16;
};
template<>
struct Vec<__nv_bfloat16, 2> {
using Type = __nv_bfloat162;
};
template<>
struct Vec<__nv_bfloat16, 4> {
using Type = bf16_4_t;
};
template<>
struct Vec<__nv_bfloat16, 8> {
using Type = bf16_8_t;
};
// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<__nv_bfloat16> {
using Type = float;
};
template<>
struct FloatVec<__nv_bfloat162> {
using Type = float2;
};
template<>
struct FloatVec<bf16_4_t> {
using Type = Float4_;
};
template<>
struct FloatVec<bf16_8_t> {
using Type = Float8_;
};
// Utility functions for type conversions.
inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
return __bfloat1622float2(val);
#endif
}
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
return __bfloat162bfloat162(val);
#endif
}
// Vector addition.
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
return a + b;
#endif
}
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
return __hadd2(a, b);
#endif
}
inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
bf16_4_t c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
bf16_8_t c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
inline __device__ float2 add(__nv_bfloat162 a, float2 fb) {
float2 fa = bf1622float2(a);
return add(fa, fb);
}
inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
Float4_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
return fc;
}
inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
Float8_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
fc.z = add(a.z, fb.z);
fc.w = add(a.w, fb.w);
return fc;
}
// Vector multiplication.
template<>
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
return __hmul(a, b);
#endif
}
template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
return __hmul2(a, b);
#endif
}
template<>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}
template<>
inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
return c;
}
template<>
inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a);
bf16_4_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
return c;
}
template<>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
return c;
}
template<>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a);
bf16_8_t c;
c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
return c;
}
template<>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
float fa = __bfloat162float(a);
float fb = __bfloat162float(b);
return fa * fb;
}
template<>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b);
return mul<float2, float2, float2>(fa, fb);
}
template<>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}
template<>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
return fc;
}
template<>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
__nv_bfloat162 s = bf162bf162(a);
Float4_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
return fc;
}
template<>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
return fc;
}
template<>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
__nv_bfloat162 s = bf162bf162(a);
Float8_ fc;
fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
return fc;
}
// Vector fused multiply-add.
inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
return __hfma2(a, b, c);
#endif
}
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
return __hfma2(bf162bf162(a), b, c);
#endif
}
inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
bf16_4_t d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) {
__nv_bfloat162 s = bf162bf162(a);
bf16_4_t d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
return d;
}
inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
bf16_8_t d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) {
__nv_bfloat162 s = bf162bf162(a);
bf16_8_t d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
d.z = fma(s, b.z, c.z);
d.w = fma(s, b.w, c.w);
return d;
}
inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) {
return __bfloat162float(a) * __bfloat162float(b) + fc;
}
inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) {
float2 fa = bf1622float2(a);
float2 fb = bf1622float2(b);
return fma(fa, fb, fc);
}
inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) {
return fma(bf162bf162(a), b, fc);
}
inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
Float4_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
return fd;
}
inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) {
__nv_bfloat162 s = bf162bf162(a);
Float4_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
return fd;
}
inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
Float8_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
fd.z = fma(a.z, b.z, fc.z);
fd.w = fma(a.w, b.w, fc.w);
return fd;
}
inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) {
__nv_bfloat162 s = bf162bf162(a);
Float8_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
fd.z = fma(s, b.z, fc.z);
fd.w = fma(s, b.w, fc.w);
return fd;
}
// Vector sum.
template<>
inline __device__ float sum(__nv_bfloat16 v) {
return __bfloat162float(v);
}
template<>
inline __device__ float sum(__nv_bfloat162 v) {
float2 vf = bf1622float2(v);
return vf.x + vf.y;
}
template<>
inline __device__ float sum(bf16_4_t v) {
return sum(v.x) + sum(v.y);
}
template<>
inline __device__ float sum(bf16_8_t v) {
return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
}
// From float32 to bfloat16.
inline __device__ void from_float(__nv_bfloat16& dst, float src) {
dst = __float2bfloat16(src);
}
inline __device__ void from_float(__nv_bfloat162& dst, float2 src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
dst = __float22bfloat162_rn(src);
#endif
}
inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
dst.x = __float22bfloat162_rn(src.x);
dst.y = __float22bfloat162_rn(src.y);
#endif
}
inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
dst.x = __float22bfloat162_rn(src.x);
dst.y = __float22bfloat162_rn(src.y);
dst.z = __float22bfloat162_rn(src.z);
dst.w = __float22bfloat162_rn(src.w);
#endif
}
// Zero-out a variable.
inline __device__ void zero(__nv_bfloat16& dst) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
// Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
dst = __ushort_as_bfloat16((unsigned short)0x0000U);
#endif
}
} // namespace vllm
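A minimal sketch of how these bfloat16 overloads compose (the helper name bf16_dot4 is hypothetical, not part of the diff): products are formed on packed bf16 lanes and widened to FP32, the pattern the Float4_/Float8_ specializations above exist to support.

__device__ float bf16_dot4(bf16_4_t q, bf16_4_t k) {
  // Multiply four bf16 lanes, widening the products to FP32 (Float4_).
  Float4_ acc = mul<Float4_, bf16_4_t, bf16_4_t>(q, k);
  // Reduce the four FP32 lanes to a scalar.
  return acc.x.x + acc.x.y + acc.y.x + acc.y.y;
}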

@ -0,0 +1,444 @@
/*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "attention_generic.cuh"
#include "dtype_float32.cuh"
#include <stdint.h>
namespace vllm {
// FP16 vector types for Q, K, V.
template<>
struct Vec<uint16_t, 1> {
using Type = uint16_t;
};
template<>
struct Vec<uint16_t, 2> {
using Type = uint32_t;
};
template<>
struct Vec<uint16_t, 4> {
using Type = uint2;
};
template<>
struct Vec<uint16_t, 8> {
using Type = uint4;
};
// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<uint16_t> {
using Type = float;
};
template<>
struct FloatVec<uint32_t> {
using Type = float2;
};
template<>
struct FloatVec<uint2> {
using Type = Float4_;
};
template<>
struct FloatVec<uint4> {
using Type = Float8_;
};
// Utility functions for type conversions.
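// h0_h0 replicates a single fp16 value into both 16-bit lanes of a packed half2 (uint32_t).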
inline __device__ uint32_t h0_h0(uint16_t a) {
uint32_t b;
asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
return b;
}
inline __device__ float half_to_float(uint16_t h) {
float f;
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
return f;
}
inline __device__ float2 half2_to_float2(uint32_t v) {
uint16_t lo, hi;
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
return make_float2(half_to_float(lo), half_to_float(hi));
}
inline __device__ uint16_t float_to_half(float f) {
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
return tmp.u16[0];
}
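// Pack two FP32 values into a half2; SM80+ uses a single f16x2 convert,
// older architectures fall back to two scalar converts.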
inline __device__ uint32_t float2_to_half2(float2 f) {
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
#else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif
return tmp.u32;
}
// Vector addition.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
uint16_t c;
asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
return c;
}
inline __device__ uint32_t add(uint32_t a, uint32_t b) {
uint32_t c;
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
inline __device__ uint2 add(uint2 a, uint2 b) {
uint2 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
inline __device__ uint4 add(uint4 a, uint4 b) {
uint4 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
inline __device__ float2 add(uint32_t a, float2 fb) {
float2 fa = half2_to_float2(a);
return add(fa, fb);
}
inline __device__ Float4_ add(uint2 a, Float4_ fb) {
Float4_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
return fc;
}
inline __device__ Float8_ add(uint4 a, Float8_ fb) {
Float8_ fc;
fc.x = add(a.x, fb.x);
fc.y = add(a.y, fb.y);
fc.z = add(a.z, fb.z);
fc.w = add(a.w, fb.w);
return fc;
}
// Vector multiplication.
template<>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
uint16_t c;
asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
return c;
}
template<>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
uint32_t c;
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
template<>
inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
}
template<>
inline __device__ uint2 mul(uint2 a, uint2 b) {
uint2 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
return c;
}
template<>
inline __device__ uint2 mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a);
uint2 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
return c;
}
template<>
inline __device__ uint4 mul(uint4 a, uint4 b) {
uint4 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
return c;
}
template<>
inline __device__ uint4 mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a);
uint4 c;
c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
return c;
}
template<>
inline __device__ float mul(uint16_t a, uint16_t b) {
float fa = half_to_float(a);
float fb = half_to_float(b);
return fa * fb;
}
template<>
inline __device__ float2 mul(uint32_t a, uint32_t b) {
float2 fa = half2_to_float2(a);
float2 fb = half2_to_float2(b);
return mul<float2, float2, float2>(fa, fb);
}
template<>
inline __device__ float2 mul(uint16_t a, uint32_t b) {
return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
}
template<>
inline __device__ Float4_ mul(uint2 a, uint2 b) {
Float4_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
return fc;
}
template<>
inline __device__ Float4_ mul(uint16_t a, uint2 b) {
uint32_t s = h0_h0(a);
Float4_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
return fc;
}
template<>
inline __device__ Float8_ mul(uint4 a, uint4 b) {
Float8_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
return fc;
}
template<>
inline __device__ Float8_ mul(uint16_t a, uint4 b) {
uint32_t s = h0_h0(a);
Float8_ fc;
fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
return fc;
}
// Vector fused multiply-add.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d;
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
return d;
}
inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
return fma(h0_h0(a), b, c);
}
inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
uint2 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
uint32_t s = h0_h0(a);
uint2 d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
return d;
}
inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
uint4 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
uint32_t s = h0_h0(a);
uint4 d;
d.x = fma(s, b.x, c.x);
d.y = fma(s, b.y, c.y);
d.z = fma(s, b.z, c.z);
d.w = fma(s, b.w, c.w);
return d;
}
inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
float fa = half_to_float(a);
float fb = half_to_float(b);
return fa * fb + fc;
}
inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
float2 fa = half2_to_float2(a);
float2 fb = half2_to_float2(b);
return fma(fa, fb, fc);
}
inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) {
return fma(h0_h0(a), b, fc);
}
inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
Float4_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
return fd;
}
inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
uint32_t s = h0_h0(a);
Float4_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
return fd;
}
inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
Float8_ fd;
fd.x = fma(a.x, b.x, fc.x);
fd.y = fma(a.y, b.y, fc.y);
fd.z = fma(a.z, b.z, fc.z);
fd.w = fma(a.w, b.w, fc.w);
return fd;
}
inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
uint32_t s = h0_h0(a);
Float8_ fd;
fd.x = fma(s, b.x, fc.x);
fd.y = fma(s, b.y, fc.y);
fd.z = fma(s, b.z, fc.z);
fd.w = fma(s, b.w, fc.w);
return fd;
}
// Vector sum.
template<>
inline __device__ float sum(uint16_t v) {
return half_to_float(v);
}
template<>
inline __device__ float sum(uint32_t v) {
float2 tmp = half2_to_float2(v);
return tmp.x + tmp.y;
}
template<>
inline __device__ float sum(uint2 v) {
uint32_t c = add(v.x, v.y);
return sum(c);
}
template<>
inline __device__ float sum(uint4 v) {
uint32_t c = add(v.x, v.y);
c = add(c, v.z);
c = add(c, v.w);
return sum(c);
}
// From float32 to float16.
inline __device__ void from_float(uint16_t& dst, float src) {
dst = float_to_half(src);
}
inline __device__ void from_float(uint32_t& dst, float2 src) {
dst = float2_to_half2(src);
}
inline __device__ void from_float(uint2& dst, Float4_ src) {
dst.x = float2_to_half2(src.x);
dst.y = float2_to_half2(src.y);
}
inline __device__ void from_float(uint4& dst, Float8_ src) {
dst.x = float2_to_half2(src.x);
dst.y = float2_to_half2(src.y);
dst.z = float2_to_half2(src.z);
dst.w = float2_to_half2(src.w);
}
// From float16 to float32.
inline __device__ float to_float(uint16_t u) {
return half_to_float(u);
}
inline __device__ float2 to_float(uint32_t u) {
return half2_to_float2(u);
}
inline __device__ Float4_ to_float(uint2 u) {
Float4_ tmp;
tmp.x = half2_to_float2(u.x);
tmp.y = half2_to_float2(u.y);
return tmp;
}
inline __device__ Float8_ to_float(uint4 u) {
Float8_ tmp;
tmp.x = half2_to_float2(u.x);
tmp.y = half2_to_float2(u.y);
tmp.z = half2_to_float2(u.z);
tmp.w = half2_to_float2(u.w);
return tmp;
}
// Zero-out a variable.
inline __device__ void zero(uint16_t& dst) {
dst = uint16_t(0);
}
} // namespace vllm
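A minimal sketch of the fp16 path (the helper name weighted_acc is hypothetical, not part of the diff): values are unpacked from a half2 and accumulated in FP32 via the overloads above and the float fma overloads in dtype_float32.cuh, which this header includes.

__device__ float2 weighted_acc(float w, uint32_t v_half2, float2 acc) {
  // Unpack two fp16 values to FP32.
  float2 v = half2_to_float2(v_half2);
  // Lane-wise acc += w * v, accumulated in FP32.
  return fma(w, v, acc);
}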

@ -0,0 +1,273 @@
/*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "attention_generic.cuh"
#include <stdint.h>
namespace vllm {
// Define custom FP32 vector data types.
struct Float4_ {
float2 x;
float2 y;
};
struct Float8_ {
float2 x;
float2 y;
float2 z;
float2 w;
};
// FP32 vector types for Q, K, V.
template<>
struct Vec<float, 1> {
using Type = float;
};
template<>
struct Vec<float, 2> {
using Type = float2;
};
template<>
struct Vec<float, 4> {
using Type = float4;
};
// FP32 accumulator vector types corresponding to Vec.
template<>
struct FloatVec<float> {
using Type = float;
};
template<>
struct FloatVec<float2> {
using Type = float2;
};
template<>
struct FloatVec<float4> {
using Type = float4;
};
// Vector addition.
inline __device__ float add(float a, float b) {
return a + b;
}
inline __device__ float2 add(float2 a, float2 b) {
float2 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}
inline __device__ float4 add(float4 a, float4 b) {
float4 c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
c.z = add(a.z, b.z);
c.w = add(a.w, b.w);
return c;
}
// Vector multiplication.
template<>
inline __device__ float mul<float, float>(float a, float b) {
return a * b;
}
template<>
inline __device__ float2 mul(float2 a, float2 b) {
float2 c;
c.x = a.x * b.x;
c.y = a.y * b.y;
return c;
}
template<>
inline __device__ float2 mul(float a, float2 b) {
float2 c;
c.x = a * b.x;
c.y = a * b.y;
return c;
}
template<>
inline __device__ float4 mul(float4 a, float4 b) {
float4 c;
c.x = a.x * b.x;
c.y = a.y * b.y;
c.z = a.z * b.z;
c.w = a.w * b.w;
return c;
}
template<>
inline __device__ float4 mul(float a, float4 b) {
float4 c;
c.x = a * b.x;
c.y = a * b.y;
c.z = a * b.z;
c.w = a * b.w;
return c;
}
// Vector fused multiply-add.
inline __device__ float fma(float a, float b, float c) {
return a * b + c;
}
inline __device__ float2 fma(float2 a, float2 b, float2 c) {
float2 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
return d;
}
inline __device__ float2 fma(float a, float2 b, float2 c) {
float2 d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
return d;
}
inline __device__ float4 fma(float4 a, float4 b, float4 c) {
float4 d;
d.x = fma(a.x, b.x, c.x);
d.y = fma(a.y, b.y, c.y);
d.z = fma(a.z, b.z, c.z);
d.w = fma(a.w, b.w, c.w);
return d;
}
inline __device__ float4 fma(float a, float4 b, float4 c) {
float4 d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
d.z = fma(a, b.z, c.z);
d.w = fma(a, b.w, c.w);
return d;
}
inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
Float4_ d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
return d;
}
inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
Float8_ d;
d.x = fma(a, b.x, c.x);
d.y = fma(a, b.y, c.y);
d.z = fma(a, b.z, c.z);
d.w = fma(a, b.w, c.w);
return d;
}
// Vector sum.
template<>
inline __device__ float sum(float v) {
return v;
}
template<>
inline __device__ float sum(float2 v) {
return v.x + v.y;
}
template<>
inline __device__ float sum(float4 v) {
return v.x + v.y + v.z + v.w;
}
template<>
inline __device__ float sum(Float4_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y;
}
template<>
inline __device__ float sum(Float8_ v) {
return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
}
// Vector dot product.
inline __device__ float dot(float a, float b) {
return a * b;
}
inline __device__ float dot(float2 a, float2 b) {
float2 c = mul<float2, float2, float2>(a, b);
return c.x + c.y;
}
inline __device__ float dot(Float4_ a, Float4_ b) {
float2 acc = mul<float2, float2, float2>(a.x, b.x);
acc = fma(a.y, b.y, acc);
return acc.x + acc.y;
}
inline __device__ float dot(Float8_ a, Float8_ b) {
float2 acc = mul<float2, float2, float2>(a.x, b.x);
acc = fma(a.y, b.y, acc);
acc = fma(a.z, b.z, acc);
acc = fma(a.w, b.w, acc);
return acc.x + acc.y;
}
// From float to float.
inline __device__ void from_float(float& dst, float src) {
dst = src;
}
inline __device__ void from_float(float2& dst, float2 src) {
dst = src;
}
inline __device__ void from_float(float4& dst, float4 src) {
dst = src;
}
// From float to float.
inline __device__ float to_float(float u) {
return u;
}
inline __device__ float2 to_float(float2 u) {
return u;
}
inline __device__ float4 to_float(float4 u) {
return u;
}
inline __device__ Float4_ to_float(Float4_ u) {
return u;
}
inline __device__ Float8_ to_float(Float8_ u) {
return u;
}
// Zero-out a variable.
inline __device__ void zero(float& dst) {
dst = 0.f;
}
} // namespace vllm
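A sketch of how the Vec/FloatVec traits are consumed (the template below is hypothetical, not part of the diff, and assumes the dtype headers above are all included): the traits select a packed load type and its FP32 accumulator at compile time, so one kernel body can serve float, fp16, and bf16.

template<typename scalar_t, int VEC_SIZE>
__device__ float vec_dot(const scalar_t* q, const scalar_t* k) {
  using V = typename Vec<scalar_t, VEC_SIZE>::Type;  // packed register type
  using F = typename FloatVec<V>::Type;              // matching FP32 accumulator
  V qv = *reinterpret_cast<const V*>(q);             // packed load
  V kv = *reinterpret_cast<const V*>(k);
  F acc = mul<F, V, V>(qv, kv);                      // widen products to FP32
  return sum(acc);                                   // reduce lanes to a scalar
}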

@ -1,896 +0,0 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "attention_utils.h"
#include "cuda_primitives.h"
#include "reduction_utils.h"
#include <algorithm>
#define WARP_SIZE 32
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
namespace cacheflow {
// Grid: (num_heads, num_seqs).
template<
typename scalar_t,
int HEAD_SIZE,
int BLOCK_SIZE,
int NUM_THREADS>
__global__ void single_query_cached_kv_attention_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const int q_stride) {
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x;
const int warp_idx = thread_idx / WARP_SIZE;
const int lane = thread_idx % WARP_SIZE;
const int head_idx = blockIdx.x;
const int num_heads = gridDim.x;
const int seq_idx = blockIdx.y;
// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread group
// fetch or compute 16 bytes at a time.
// For example, if the size of a thread group is 4 and the data type is half,
// then the vector size is 16 / (4 * sizeof(half)) == 2.
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
// Load the query to registers.
// Each thread in a thread group has a different part of the query.
// For example, if the thread group size is 4, then the first thread in the group
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
// th vectors of the query, and so on.
// NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
Q_vec q_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}
// Memory planning.
extern __shared__ char shared_mem[];
// NOTE(woosuk): We use FP32 logits and accumulation.
float *logits = reinterpret_cast<float*>(shared_mem);
// Workspace for reduction.
__shared__ float red_smem[2 * NUM_WARPS];
// x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time.
constexpr int x = 16 / sizeof(scalar_t);
float qk_max = -FLT_MAX;
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
const int context_len = context_lens[seq_idx];
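// Number of key blocks for this sequence, rounded up
// (e.g., context_len = 100 with BLOCK_SIZE = 16 gives 7 blocks).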
const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
// Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes
// dot product with the query.
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
const int physical_block_number = block_table[block_idx];
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// For example, if the thread group size is 4, then the first thread in the group
// has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
// vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD];
#pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
+ head_idx * HEAD_SIZE * BLOCK_SIZE
+ physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x;
k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
}
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
const float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
const bool mask = token_idx >= context_len;
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
logits[token_idx] = mask ? 0.f : qk;
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}
}
}
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
if (lane == 0) {
red_smem[warp_idx] = qk_max;
}
__syncthreads();
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
}
// Broadcast the max qk value to all threads.
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
// Get the sum of the exp values.
float exp_sum = 0.f;
for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
float val = __expf(logits[i] - qk_max);
logits[i] = val;
exp_sum += val;
}
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
// Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
logits[i] *= inv_sum;
}
__syncthreads();
// Each thread will fetch 16 bytes from the value cache at a time.
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using L_vec = typename FloatVec<V_vec>::Type;
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER;
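// e.g., fp16 with BLOCK_SIZE = 16 and HEAD_SIZE = 128: V_VEC_SIZE = 8, so each
// value row spans 2 vectors, a warp covers 16 rows per iteration, and each
// thread accumulates 8 rows.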
float accs[NUM_ROWS_PER_THREAD];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[i] = 0.f;
}
for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
const int physical_block_number = block_table[block_idx];
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec = *reinterpret_cast<L_vec*>(logits + token_idx);
const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
+ head_idx * HEAD_SIZE * BLOCK_SIZE;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
accs[i] += dot(logits_vec, cast_to_float(v_vec));
}
}
}
// Perform reduction within each warp.
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[i];
#pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
}
accs[i] = acc;
}
// NOTE(woosuk): A barrier is required because the shared memory space for logits
// is reused for the output.
__syncthreads();
// Perform reduction across warps.
float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2;
// Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
dst[row_idx] = accs[i];
}
}
}
__syncthreads();
// Lower warps update the output.
if (warp_idx < mid) {
const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
accs[i] += src[row_idx];
}
}
}
__syncthreads();
}
// Write the final output.
if (warp_idx == 0) {
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
convert_from_float(*(out_ptr + row_idx), accs[i]);
}
}
}
}
} // namespace cacheflow
#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
cacheflow::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
<<<grid, block, shared_mem_size, stream>>>( \
out_ptr, \
query_ptr, \
key_cache_ptr, \
value_cache_ptr, \
scale, \
block_tables_ptr, \
context_lens_ptr, \
max_num_blocks_per_seq, \
query_stride);
// TODO(woosuk): Tune NUM_THREADS.
template<
typename T,
int BLOCK_SIZE,
int NUM_THREADS = 128>
void single_query_cached_kv_attention_launcher(
torch::Tensor& out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int max_context_len) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int query_stride = query.stride(0);
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_context_len * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
int shared_mem_size = std::max(logits_size, outputs_size);
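// e.g., NUM_THREADS = 128 (4 warps), BLOCK_SIZE = 16, max_context_len = 1000,
// head_size = 128: logits take 1008 * 4 B ~= 4 KB, the output workspace takes
// 2 * 128 * 4 B = 1 KB, so ~4 KB of dynamic shared memory is requested.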
dim3 grid(num_heads, num_seqs);
dim3 block(NUM_THREADS);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (head_size) {
case 32:
LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
break;
case 64:
LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
break;
case 80:
LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS);
break;
case 96:
LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
break;
case 128:
LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
break;
case 160:
LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
break;
case 192:
LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
break;
case 256:
LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
break;
default:
assert(false);
break;
}
}
#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \
single_query_cached_kv_attention_launcher<T, BLOCK_SIZE>( \
out, \
query, \
key_cache, \
value_cache, \
scale, \
block_tables, \
context_lens, \
max_context_len);
void single_query_cached_kv_attention(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
float scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs]
int block_size,
int max_context_len) {
// TODO(woosuk): Support BF16.
if (query.element_size() == 2) {
// Half.
if (block_size == 1) {
CALL_KERNEL_LAUNCHER(uint16_t, 1);
} else if (block_size == 2) {
CALL_KERNEL_LAUNCHER(uint16_t, 2);
} else if (block_size == 4) {
CALL_KERNEL_LAUNCHER(uint16_t, 4);
} else if (block_size == 8) {
CALL_KERNEL_LAUNCHER(uint16_t, 8);
} else if (block_size == 16) {
CALL_KERNEL_LAUNCHER(uint16_t, 16);
} else if (block_size == 32) {
CALL_KERNEL_LAUNCHER(uint16_t, 32);
} else if (block_size == 64) {
CALL_KERNEL_LAUNCHER(uint16_t, 64);
} else if (block_size == 128) {
CALL_KERNEL_LAUNCHER(uint16_t, 128);
} else if (block_size == 256) {
CALL_KERNEL_LAUNCHER(uint16_t, 256);
} else {
assert(false);
}
} else {
// Float.
assert(false);
}
}
// namespace cacheflow {
// // Grid: (num_heads, num_query_tokens).
// template<
// typename scalar_t,
// int HEAD_SIZE,
// int BLOCK_SIZE,
// int NUM_THREADS>
// __device__ void multi_query_cached_kv_attention_kernel_unoptimized_(
// scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
// const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
// const int seq_start_idx,
// const int seq_len,
// const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
// const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size]
// const float scale,
// const int* __restrict__ block_table, // [num_seqs, max_num_blocks_per_seq]
// const int context_len,
// const int max_num_blocks_per_seq,
// const int q_stride) {
// constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE;
// constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
// const int thread_idx = threadIdx.x;
// const int warp_idx = thread_idx / WARP_SIZE;
// const int lane = thread_idx % WARP_SIZE;
// const int head_idx = blockIdx.x;
// const int num_heads = gridDim.x;
// const int seq_idx = blockIdx.y;
// // A vector type to store a part of a key or a query.
// // The vector size is configured in such a way that the threads in a thread group
// // fetch or compute 16 bytes at a time.
// // For example, if the size of a thread group is 4 and the data type is half,
// // then the vector size is 16 / (4 * sizeof(half)) == 2.
// constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(scalar_t));
// using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
// using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
// constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
// constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
// const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
// const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
// // Load the query to registers.
// // Each thread in a thread group has a different part of the query.
// // For example, if the thread group size is 4, then the first thread in the group
// // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
// // th vectors of the query, and so on.
// // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
// const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
// Q_vec q_vecs[NUM_VECS_PER_THREAD];
// #pragma unroll
// for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
// q_vecs[i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
// }
// // Memory planning.
// extern __shared__ char shared_mem[];
// // NOTE(woosuk): We use FP32 logits and accumulation.
// float *logits = reinterpret_cast<float*>(shared_mem);
// // Workspace for reduction.
// __shared__ float red_smem[2 * NUM_WARPS];
// // x == THREAD_GROUP_SIZE * VEC_SIZE
// // Each thread group fetches x elements from the key at a time.
// constexpr int x = 16 / sizeof(scalar_t);
// float qk_max = -FLT_MAX;
// const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
// const int mask_boundary = context_len - seq_len + 1 + (seq_idx - seq_start_idx);
// // Iterate over the key blocks.
// // Each warp fetches a block of keys for each iteration.
// // Each thread group in a warp fetches a key from the block, and computes
// // dot product with the query.
// for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
// const int physical_block_number = block_table[block_idx];
// const int physical_block_offset = thread_group_idx % BLOCK_SIZE;
// const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
// // Load a key to registers.
// // Each thread in a thread group has a different part of the key.
// // For example, if the thread group size is 4, then the first thread in the group
// // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
// // vectors of the key, and so on.
// K_vec k_vecs[NUM_VECS_PER_THREAD];
// #pragma unroll
// for (int i = 0; i < NUM_VECS_PER_THREAD; i++) {
// const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
// + head_idx * HEAD_SIZE * BLOCK_SIZE
// + physical_block_offset * x;
// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
// const int offset1 = (vec_idx * VEC_SIZE) / x;
// const int offset2 = (vec_idx * VEC_SIZE) % x;
// k_vecs[i] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
// }
// // Compute dot product.
// // This includes a reduction across the threads in the same thread group.
// const float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
// const bool mask = token_idx >= mask_boundary;
// if (thread_group_offset == 0) {
// // Store the partial reductions to shared memory.
// // NOTE(woosuk): It is required to zero out the masked logits.
// logits[token_idx] = mask ? 0.f : qk;
// // Update the max value.
// qk_max = mask ? qk_max : fmaxf(qk_max, qk);
// }
// }
// // Perform reduction across the threads in the same warp to get the
// // max qk value for each "warp" (not across the thread block yet).
// // The 0-th thread of each thread group already has its max qk value.
// #pragma unroll
// for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
// qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
// }
// if (lane == 0) {
// red_smem[warp_idx] = qk_max;
// }
// __syncthreads();
// // TODO(woosuk): Refactor this part.
// // Get the max qk value for the sequence.
// qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
// #pragma unroll
// for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
// qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
// }
// // Broadcast the max qk value to all threads.
// qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
// // Get the sum of the exp values.
// float exp_sum = 0.f;
// for (int i = thread_idx; i < mask_boundary; i += NUM_THREADS) {
// float val = __expf(logits[i] - qk_max);
// logits[i] = val;
// exp_sum += val;
// }
// exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
// // Compute softmax.
// const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
// for (int i = thread_idx; i < context_len; i += NUM_THREADS) {
// logits[i] *= inv_sum;
// }
// __syncthreads();
// // Each thread will fetch 16 bytes from the value cache at a time.
// constexpr int V_VEC_SIZE = 16 / sizeof(scalar_t);
// using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
// using L_vec = typename FloatVec<V_vec>::Type;
// constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
// constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
// constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER;
// float accs[NUM_ROWS_PER_THREAD];
// #pragma unroll
// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
// accs[i] = 0.f;
// }
// for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) {
// const int physical_block_number = block_table[block_idx];
// const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
// const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
// L_vec logits_vec = *reinterpret_cast<L_vec*>(logits + token_idx);
// const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
// + head_idx * HEAD_SIZE * BLOCK_SIZE;
// #pragma unroll
// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
// if (row_idx < HEAD_SIZE) {
// const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
// V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
// accs[i] += dot(logits_vec, cast_to_float(v_vec));
// }
// }
// }
// // Perform reduction within each warp.
// #pragma unroll
// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
// float acc = accs[i];
// #pragma unroll
// for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
// acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
// }
// accs[i] = acc;
// }
// // NOTE(woosuk): A barrier is required because the shared memory space for logits
// // is reused for the output.
// __syncthreads();
// // Perform reduction across warps.
// float* out_smem = reinterpret_cast<float*>(shared_mem);
// #pragma unroll
// for (int i = NUM_WARPS; i > 1; i /= 2) {
// int mid = i / 2;
// // Upper warps write to shared memory.
// if (warp_idx >= mid && warp_idx < i) {
// float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
// #pragma unroll
// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
// dst[row_idx] = accs[i];
// }
// }
// }
// __syncthreads();
// // Lower warps update the output.
// if (warp_idx < mid) {
// const float* src = &out_smem[warp_idx * HEAD_SIZE];
// #pragma unroll
// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
// accs[i] += src[row_idx];
// }
// }
// }
// __syncthreads();
// }
// // Write the final output.
// if (warp_idx == 0) {
// scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
// #pragma unroll
// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
// convert_from_float(*(out_ptr + row_idx), accs[i]);
// }
// }
// }
// }
// // Grid: (num_heads, num_query_tokens).
// template<
// typename scalar_t,
// int HEAD_SIZE,
// int BLOCK_SIZE,
// int NUM_THREADS>
// __global__ void multi_query_cached_kv_attention_kernel(
// const int* cu_query_lens, // [num_prompts+1]
// const int* seq_prompt_mapping, // [num_seqs] mapping from seq_idx to prompt_idx
// scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
// const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
// const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
// const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size]
// const float scale,
// const int* __restrict__ block_tables, // [num_prompts, max_num_blocks_per_seq]
// const int* __restrict__ context_lens, // [num_prompts]
// const int max_num_blocks_per_seq,
// const int q_stride) {
// const int seq_idx = blockIdx.y;
// const int prompt_idx = seq_prompt_mapping[seq_idx];
// const int seq_start_idx = cu_query_lens[prompt_idx];
// const int seq_len = cu_query_lens[prompt_idx + 1] - seq_start_idx;
// const int* block_table = block_tables + prompt_idx * max_num_blocks_per_seq;
// const int context_len = context_lens[prompt_idx];
// multi_query_cached_kv_attention_kernel_unoptimized_<
// scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
// out,
// q,
// seq_start_idx,
// seq_len,
// k_cache,
// v_cache,
// scale,
// block_table,
// context_len,
// max_num_blocks_per_seq,
// q_stride);
// }
// } // namespace cacheflow
// #define LAUNCH_MULTI_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
// cacheflow::multi_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
// <<<grid, block, shared_mem_size, stream>>>( \
// cu_query_lens_ptr, \
// seq_prompt_mapping_ptr, \
// out_ptr, \
// query_ptr, \
// key_cache_ptr, \
// value_cache_ptr, \
// scale, \
// block_tables_ptr, \
// context_lens_ptr, \
// max_num_blocks_per_seq, \
// query_stride);
// // TODO(woosuk): Tune NUM_THREADS.
// template<
// typename T,
// int BLOCK_SIZE,
// int NUM_THREADS = 128>
// void multi_query_cached_kv_attention_launcher(
// torch::Tensor& cu_query_lens,
// torch::Tensor& seq_prompt_mapping,
// torch::Tensor& out,
// torch::Tensor& query,
// torch::Tensor& key_cache,
// torch::Tensor& value_cache,
// float scale,
// torch::Tensor& block_tables,
// torch::Tensor& context_lens,
// int max_context_len) {
// int num_seqs = query.size(0);
// int num_heads = query.size(1);
// int head_size = query.size(2);
// int max_num_blocks_per_seq = block_tables.size(1);
// int query_stride = query.stride(0);
// int* cu_query_lens_ptr = cu_query_lens.data_ptr<int>();
// int* seq_prompt_mapping_ptr = seq_prompt_mapping.data_ptr<int>();
// T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
// T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
// T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
// T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
// int* block_tables_ptr = block_tables.data_ptr<int>();
// int* context_lens_ptr = context_lens.data_ptr<int>();
// constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
// int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
// int logits_size = padded_max_context_len * sizeof(float);
// int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
// int shared_mem_size = std::max(logits_size, outputs_size);
// dim3 grid(num_heads, num_seqs);
// dim3 block(NUM_THREADS);
// const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// switch (head_size) {
// case 32:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
// break;
// case 64:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS);
// break;
// case 80:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS);
// break;
// case 96:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
// break;
// case 128:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
// break;
// case 160:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS);
// break;
// case 192:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
// break;
// case 256:
// LAUNCH_MULTI_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
// break;
// default:
// assert(false);
// break;
// }
// }
// void multi_query_cached_kv_attention(
// torch::Tensor& cu_query_lens,
// torch::Tensor& out,
// torch::Tensor& query,
// torch::Tensor& key_cache,
// torch::Tensor& value_cache,
// float scale,
// torch::Tensor& block_tables,
// torch::Tensor& context_lens,
// int block_size,
// int max_context_len) {
// torch::Tensor query_lens = cu_query_lens.to(torch::kCPU);
// int num_queries = query_lens.size(0) - 1;
// const int* query_lens_ptr = query_lens.data_ptr<int>();
// int num_seqs = query.size(0);
// torch::Tensor cpu_tensor = torch::empty({num_seqs}, torch::dtype(torch::kInt32));
// auto accessor = cpu_tensor.accessor<int32_t, 1>();
// for (int i = 0, query_cursor = 0; i < num_seqs; ++i) {
// if (i >= query_lens_ptr[query_cursor + 1]) {
// ++query_cursor;
// }
// accessor[i] = query_cursor;
// }
// // TODO(suquark): This can be slow, as to(torch::kCPU) and to(torch::kCUDA)
// // implicitly synchronize the CPU and GPU. We can avoid this by passing the
// // mapping as an input parameter. Let's do this optimization in a later PR.
// torch::Tensor seq_prompt_mapping = cpu_tensor.to(torch::kCUDA);
// // TODO(woosuk): Support BF16.
// if (query.element_size() == 2) {
// // Half.
// if (block_size == 8) {
// multi_query_cached_kv_attention_launcher<uint16_t, 8>(
// cu_query_lens,
// seq_prompt_mapping,
// out,
// query,
// key_cache,
// value_cache,
// scale,
// block_tables,
// context_lens,
// max_context_len);
// } else if (block_size == 16) {
// multi_query_cached_kv_attention_launcher<uint16_t, 16>(
// cu_query_lens,
// seq_prompt_mapping,
// out,
// query,
// key_cache,
// value_cache,
// scale,
// block_tables,
// context_lens,
// max_context_len);
// } else if (block_size == 32) {
// multi_query_cached_kv_attention_launcher<uint16_t, 32>(
// cu_query_lens,
// seq_prompt_mapping,
// out,
// query,
// key_cache,
// value_cache,
// scale,
// block_tables,
// context_lens,
// max_context_len);
// } else {
// assert(false);
// }
// } else if (query.element_size() == 4) {
// // Float.
// if (block_size == 8) {
// multi_query_cached_kv_attention_launcher<float, 8>(
// cu_query_lens,
// seq_prompt_mapping,
// out,
// query,
// key_cache,
// value_cache,
// scale,
// block_tables,
// context_lens,
// max_context_len);
// } else if (block_size == 16) {
// multi_query_cached_kv_attention_launcher<float, 16>(
// cu_query_lens,
// seq_prompt_mapping,
// out,
// query,
// key_cache,
// value_cache,
// scale,
// block_tables,
// context_lens,
// max_context_len);
// } else if (block_size == 32) {
// multi_query_cached_kv_attention_launcher<float, 32>(
// cu_query_lens,
// seq_prompt_mapping,
// out,
// query,
// key_cache,
// value_cache,
// scale,
// block_tables,
// context_lens,
// max_context_len);
// } else {
// assert(false);
// }
// } else {
// assert(false);
// }
// }
#undef WARP_SIZE
#undef MAX
#undef MIN

@ -1,165 +0,0 @@
#pragma once
#include "cuda_primitives.h"
#include <float.h>
#include <type_traits>
#define MMHA_USE_FP32_ACUM_FOR_FMA
#define MMHA_USE_FP32_ACUM_FOR_OUT
namespace cacheflow {
// A vector type to store Q, K, V elements.
template<typename T, int VEC_SIZE>
struct Vec {};
template<>
struct Vec<float, 1> {
using Type = float;
};
template<>
struct Vec<float, 2> {
using Type = float2;
};
template<>
struct Vec<float, 4> {
using Type = float4;
};
template<>
struct Vec<uint16_t, 1> {
using Type = uint16_t;
};
template<>
struct Vec<uint16_t, 2> {
using Type = uint32_t;
};
template<>
struct Vec<uint16_t, 4> {
using Type = uint2;
};
template<>
struct Vec<uint16_t, 8> {
using Type = uint4;
};
template<typename T>
struct FloatVec {};
template<>
struct FloatVec<float> {
using Type = float;
};
template<>
struct FloatVec<float2> {
using Type = float2;
};
template<>
struct FloatVec<float4> {
using Type = float4;
};
template<>
struct FloatVec<uint16_t> {
using Type = float;
};
template<>
struct FloatVec<uint32_t> {
using Type = float2;
};
template<>
struct FloatVec<uint2> {
using Type = Float4_;
};
template<>
struct FloatVec<uint4> {
using Type = Float8_;
};
template<int THREADS_PER_KEY, typename K_vec, int N>
inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N])
{
using K_vec_acum = typename FloatVec<K_vec>::Type;
// Compute the parallel products for Q*K^T (treat vector lanes separately).
K_vec_acum qk_vec = mul<K_vec_acum, K_vec, K_vec>(q[0], k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec);
}
// Finalize the reduction across lanes.
float qk = sum(qk_vec);
#pragma unroll
for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) {
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
}
return qk;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int THREADS_PER_KEY>
struct Qk_dot {
template<typename K_vec, int N>
static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N])
{
return qk_dot_<THREADS_PER_KEY>(q, k);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
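// Issue a single m16n8k8 tensor-core MMA (fp16 inputs, FP32 accumulate). It is
// used below only as a cheap way to reduce the two fp16 lanes of a dot-product
// accumulator.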
inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b)
{
float4 c;
float zero = 0.f;
asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n"
" {%0, %1, %2, %3}, \n"
" {%4, %5}, \n"
" {%6}, \n"
" {%7, %7, %7, %7}; \n"
: "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
: "r"(a.x), "r"(a.y), "r"(b), "f"(zero));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int N>
inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N])
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
using K_vec_acum = typename FloatVec<uint32_t>::Type;
K_vec_acum qk_vec = mul<K_vec_acum, uint32_t, uint32_t>(q[0], k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec);
}
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
uint32_t qk_vec_ = float2_to_half2(qk_vec);
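// 0x3c003c00 packs (1.0, 1.0) in fp16, so the MMA dots qk_vec with (1, 1),
// i.e. sums its two fp16 lanes; the result lands in c.x.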
return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x;
#else
return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x;
#endif
#else
return 0.f;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Qk_dot<uint16_t, 4> {
template<int N>
static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N])
{
#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION)
return qk_hmma_dot_(q, k);
#else
return qk_dot_<4>(q, k);
#endif // defined MMHA_USE_HMMA_FOR_REDUCTION
}
};
} // namespace cacheflow
#undef MMHA_USE_FP32_ACUM_FOR_FMA
#undef MMHA_USE_FP32_ACUM_FOR_OUT

@ -1,6 +1,8 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "dispatch_utils.h"
#include <algorithm>
#include <cassert>
#include <map>
@ -14,14 +16,16 @@ void swap_blocks(
torch::Device dst_device = dst.device();
cudaMemcpyKind memcpy_type;
if (src_device.is_cuda() && dst_device.is_cuda()) {
assert(src_device.index() == dst_device.index());
TORCH_CHECK(
src_device.index() == dst_device.index(),
"src and dst must be on the same GPU");
memcpy_type = cudaMemcpyDeviceToDevice;
} else if (src_device.is_cuda() && dst_device.is_cpu()) {
memcpy_type = cudaMemcpyDeviceToHost;
} else if (src_device.is_cpu() && dst_device.is_cuda()) {
memcpy_type = cudaMemcpyHostToDevice;
} else {
assert(false);
TORCH_CHECK(false, "Invalid device combination");
}
void *src_ptr = src.data_ptr();
@ -29,6 +33,7 @@ void swap_blocks(
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// NOTE(woosuk): This can be slow if the number of blocks is large.
for (const auto& pair : block_mapping) {
int64_t src_block_number = pair.first;
int64_t dst_block_number = pair.second;
@ -43,7 +48,7 @@ void swap_blocks(
}
}
namespace cacheflow {
namespace vllm {
// Grid: (num_layers, num_pairs)
template<typename scalar_t>
@ -74,7 +79,7 @@ __global__ void copy_blocks_kernel(
}
}
} // namespace cacheflow
} // namespace vllm
void copy_blocks(
std::vector<torch::Tensor>& key_caches,
@ -122,9 +127,9 @@ void copy_blocks(
dim3 grid(num_layers, num_pairs);
dim3 block(std::min(1024, numel_per_block));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
VLLM_DISPATCH_FLOATING_TYPES(
key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
cacheflow::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
key_cache_ptrs_tensor.data_ptr<int64_t>(),
value_cache_ptrs_tensor.data_ptr<int64_t>(),
block_mapping_tensor.data_ptr<int>(),
@ -132,7 +137,7 @@ void copy_blocks(
}));
}
namespace cacheflow {
namespace vllm {
template<typename scalar_t>
__global__ void reshape_and_cache_kernel(
@ -176,6 +181,48 @@ __global__ void reshape_and_cache_kernel(
}
}
} // namespace vllm
void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping) // [num_tokens]
{
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(),
"reshape_and_cache_kernel",
[&] {
vllm::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int>(),
key_stride,
value_stride,
num_heads,
head_size,
block_size,
x);
});
}
namespace vllm {
// Grid: (num_blocks, block_size).
template<typename scalar_t>
__global__ void gather_cached_kv_kernel(
@ -294,46 +341,7 @@ __global__ void gather_cached_kv_kernel_optimized(
}
}
} // namespace cacheflow
void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping) // [num_tokens]
{
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
int key_stride = key.stride(0);
int value_stride = value.stride(0);
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
key.scalar_type(),
"reshape_and_cache_kernel",
[&] {
cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),
value_cache.data_ptr<scalar_t>(),
slot_mapping.data_ptr<int>(),
key_stride,
value_stride,
num_heads,
head_size,
block_size,
x);
});
}
} // namespace vllm
void gather_cached_kv(
torch::Tensor& key, // [out] [num_tokens, num_heads, head_size]
@@ -354,11 +362,11 @@ void gather_cached_kv(
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
VLLM_DISPATCH_FLOATING_TYPES(
key.scalar_type(),
"gather_cached_kv_kernel_optimized",
[&] {
cacheflow::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
vllm::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
key_cache.data_ptr<scalar_t>(),

(File diff suppressed because it is too large.)

csrc/dispatch_utils.h (new file)

@@ -0,0 +1,14 @@
/*
* Adapted from
* https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
*/
#include <torch/extension.h>
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

@@ -1,9 +1,10 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "reduction_utils.h"
#include "dispatch_utils.h"
#include "reduction_utils.cuh"
namespace cacheflow {
namespace vllm {
// TODO(woosuk): Further optimize this kernel.
template<typename scalar_t>
@@ -33,7 +34,7 @@ __global__ void rms_norm_kernel(
}
}
} // namespace cacheflow
} // namespace vllm
void rms_norm(
torch::Tensor& out, // [num_tokens, hidden_size]
@@ -46,11 +47,11 @@ void rms_norm(
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(),
"rms_norm_kernel",
[&] {
cacheflow::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
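For reference, rms_norm applies the standard RMSNorm transform per token; the new dispatch_utils.h include adds bfloat16 dispatch, and reduction_utils.cuh supplies the block-wide reduction. A pure-Python sketch of the semantics, assuming the usual definition with a small variance epsilon:

import math

def rms_norm(x, weight, eps=1e-6):  # eps value here is an assumption
    # Root mean square over the hidden dimension, then scale by the weight.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]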

@@ -1,14 +1,16 @@
#include <torch/extension.h>
void rotary_embedding_neox(
void rotary_embedding(
torch::Tensor& positions,
torch::Tensor& query,
torch::Tensor& key,
torch::Tensor& cos_sin_cache);
int head_size,
torch::Tensor& cos_sin_cache,
bool is_neox);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"rotary_embedding_neox",
&rotary_embedding_neox,
"Apply GPT-NeoX style rotary embedding to query and key");
"rotary_embedding",
&rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
}

@@ -1,78 +1,127 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
namespace cacheflow {
#include "dispatch_utils.h"
template<typename scalar_t>
__global__ void rotary_embedding_neox_kernel(
namespace vllm {
template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
scalar_t* __restrict__ arr,
const scalar_t* __restrict__ cos_ptr,
const scalar_t* __restrict__ sin_ptr,
int rot_offset,
int embed_dim)
{
int x_index, y_index;
scalar_t cos, sin;
if (IS_NEOX) {
// GPT-NeoX style rotary embedding.
x_index = rot_offset;
y_index = embed_dim + rot_offset;
cos = __ldg(cos_ptr + x_index);
sin = __ldg(sin_ptr + x_index);
} else {
// GPT-J style rotary embedding.
x_index = 2 * rot_offset;
y_index = 2 * rot_offset + 1;
cos = __ldg(cos_ptr + x_index / 2);
sin = __ldg(sin_ptr + x_index / 2);
}
const scalar_t x = arr[x_index];
const scalar_t y = arr[y_index];
arr[x_index] = x * cos - y * sin;
arr[y_index] = y * cos + x * sin;
}
template<typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
const int64_t* __restrict__ positions, // [num_tokens]
scalar_t* __restrict__ query, // [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, head_size // 2]
const int stride,
scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim,
const int query_stride,
const int key_stride,
const int num_heads,
const int num_kv_heads,
const int head_size) {
// Each thread block is responsible for one token.
const int token_idx = blockIdx.x;
int64_t pos = positions[token_idx];
const scalar_t* cache_ptr = cos_sin_cache + pos * head_size;
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
const int embed_dim = head_size / 2;
const int n = num_heads * embed_dim;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int embed_dim = rot_dim / 2;
const scalar_t* cos_ptr = cache_ptr;
const scalar_t* sin_ptr = cache_ptr + embed_dim;
const int nq = num_heads * embed_dim;
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int token_head = token_idx * stride + head_idx * head_size;
const int token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
const int x_index = rot_offset;
const int y_index = embed_dim + rot_offset;
apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
sin_ptr, rot_offset, embed_dim);
}
const int out_x = token_idx * stride + head_idx * head_size + x_index;
const int out_y = token_idx * stride + head_idx * head_size + y_index;
const scalar_t cos = __ldg(cache_ptr + x_index);
const scalar_t sin = __ldg(cache_ptr + y_index);
const scalar_t q_x = query[token_head + x_index];
const scalar_t q_y = query[token_head + y_index];
query[out_x] = q_x * cos - q_y * sin;
query[out_y] = q_y * cos + q_x * sin;
const scalar_t k_x = key[token_head + x_index];
const scalar_t k_y = key[token_head + y_index];
key[out_x] = k_x * cos - k_y * sin;
key[out_y] = k_y * cos + k_x * sin;
const int nk = num_kv_heads * embed_dim;
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
const int head_idx = i / embed_dim;
const int token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
sin_ptr, rot_offset, embed_dim);
}
}
} // namespace cacheflow
} // namespace vllm
void rotary_embedding_neox(
void rotary_embedding(
torch::Tensor& positions, // [num_tokens]
torch::Tensor& query, // [num_tokens, num_heads * head_size]
torch::Tensor& key, // [num_tokens, num_heads * head_size]
torch::Tensor& cos_sin_cache) // [max_position, head_size]
{
torch::Tensor& key, // [num_tokens, num_kv_heads * head_size]
int head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) {
int num_tokens = query.size(0);
int head_size = cos_sin_cache.size(1);
int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(1) / head_size;
int stride = query.stride(0);
TORCH_CHECK(stride == key.stride(0));
int num_kv_heads = key.size(1) / head_size;
int query_stride = query.stride(0);
int key_stride = key.stride(0);
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size / 2, 512));
dim3 block(std::min(num_heads * rot_dim / 2, 512));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
VLLM_DISPATCH_FLOATING_TYPES(
query.scalar_type(),
"rotary_embedding_neox",
"rotary_embedding",
[&] {
cacheflow::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
stride,
num_heads,
head_size);
if (is_neox) {
vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
} else {
vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
positions.data_ptr<int64_t>(),
query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
rot_dim,
query_stride,
key_stride,
num_heads,
num_kv_heads,
head_size);
}
});
}
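To make the two layouts concrete: NeoX-style rotation pairs dimensions (i, i + rot_dim/2) within each head, while GPT-J-style rotation pairs adjacent dimensions (2i, 2i + 1); both styles read the same cos/sin row for the token's position. A pure-Python reference of what the kernel applies per head, mirroring apply_rotary_embedding above:

def apply_rotary(head, cos, sin, is_neox):
    # head: head_size values; cos/sin: rot_dim // 2 values each.
    out = list(head)
    embed_dim = len(cos)
    for r in range(embed_dim):
        x_i, y_i = (r, embed_dim + r) if is_neox else (2 * r, 2 * r + 1)
        x, y = head[x_i], head[y_i]
        out[x_i] = x * cos[r] - y * sin[r]
        out[y_i] = y * cos[r] + x * sin[r]
    return out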

csrc/reduction_utils.cuh (new file)

@@ -0,0 +1,51 @@
/*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
namespace vllm {
template<typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(0xffffffff, val, mask, 32);
return val;
}
/* Calculate the sum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceSum(T val) {
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if (lane == 0)
shared[wid] = val;
__syncthreads();
// Use float division (blockDim.x / 32.f) instead of an integer shift so that a
// trailing partial warp is still counted when blockDim.x is not a multiple of 32.
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val);
return val;
}
} // namespace vllm
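warpReduceSum is a butterfly (XOR-shuffle) reduction: after log2(32) = 5 steps every lane holds the warp-wide sum, and blockReduceSum then combines one partial sum per warp through shared memory. A Python model of the butterfly pattern, with a list of lane values standing in for registers:

def warp_reduce_sum(lanes):
    # Mirrors __shfl_xor_sync with masks 16, 8, 4, 2, 1 across 32 lanes.
    mask = 16
    while mask > 0:
        lanes = [lanes[i] + lanes[i ^ mask] for i in range(32)]
        mask >>= 1
    return lanes  # every lane now holds the full sum

assert len(set(warp_reduce_sum([float(i) for i in range(32)]))) == 1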

@@ -1,76 +0,0 @@
#pragma once
namespace cacheflow {
template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
inline __device__ float block_sum(float* red_smem, float sum)
{
// Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
// Compute the sum per warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}
// Warp leaders store the data to shared memory.
if (lane == 0) {
red_smem[warp] = sum;
}
// Make sure the data is in shared memory.
__syncthreads();
// The warps compute the final sums.
if (lane < WARPS_PER_BLOCK) {
sum = red_smem[lane];
}
// Parallel reduction inside the warp.
#pragma unroll
for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}
// Broadcast to other threads.
return __shfl_sync(uint32_t(-1), sum, 0);
}
#define FINAL_MASK 0xffffffff
template<typename T>
__inline__ __device__ T warpReduceSum(T val)
{
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
return val;
}
/* Calculate the sum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceSum(T val)
{
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if (lane == 0)
shared[wid] = val;
__syncthreads();
// Use float division (blockDim.x / 32.f) instead of an integer shift so that a
// trailing partial warp is still counted when blockDim.x is not a multiple of 32.
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val);
return val;
}
} // namespace cacheflow

docs/Makefile (new file)

@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = source
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

docs/README.md (new file)

@@ -0,0 +1,19 @@
# vLLM documentation
## Build the docs
```bash
# Install dependencies.
pip install -r requirements-docs.txt
# Build the docs.
make clean
make html
```
## Open the docs with your browser
```bash
python -m http.server -d build/html/
```
Launch your browser and open localhost:8000.

docs/make.bat (new file)

@@ -0,0 +1,35 @@
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)
if "%1" == "" goto help
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd

@@ -0,0 +1,3 @@
sphinx == 6.2.1
sphinx-book-theme == 1.0.1
sphinx-copybutton == 0.5.2

(Eleven binary image files added, docs assets such as logos and figures; contents not shown in the diff.)

docs/source/conf.py (new file)

@@ -0,0 +1,67 @@
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))
# -- Project information -----------------------------------------------------
project = 'vLLM'
copyright = '2023, vLLM Team'
author = 'the vLLM Team'
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
"sphinx.ext.intersphinx",
"sphinx_copybutton",
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
# Exclude the prompt "$" when copying code
copybutton_prompt_text = r"\$ "
copybutton_prompt_is_regexp = True
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_title = project
html_theme = 'sphinx_book_theme'
html_logo = 'assets/logos/vllm-logo-text-light.png'
html_theme_options = {
'logo_only': True,
'path_to_docs': 'docs/source',
'repository_url': 'https://github.com/vllm-project/vllm',
'use_repository_button': True,
}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

@@ -0,0 +1,57 @@
.. _installation:
Installation
============
vLLM is a Python library that also contains some C++ and CUDA code.
This additional code requires compilation on the user's machine.
Requirements
------------
* OS: Linux
* Python: 3.8 or higher
* CUDA: 11.0 -- 11.8
* GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, etc.)
.. note::
As of now, vLLM does not support CUDA 12.
If you are using Hopper or Lovelace GPUs, please use CUDA 11.8 instead of CUDA 12.
.. tip::
If you have trouble installing vLLM, we recommend using the NVIDIA PyTorch Docker image.
.. code-block:: console
$ # Pull the Docker image with CUDA 11.8.
$ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3
Inside the Docker container, please execute :code:`pip uninstall torch` before installing vLLM.
Install with pip
----------------
You can install vLLM using pip:
.. code-block:: console
$ # (Optional) Create a new conda environment.
$ conda create -n myenv python=3.8 -y
$ conda activate myenv
$ # Install vLLM.
$ pip install vllm # This may take 5-10 minutes.
.. _build_from_source:
Build from source
-----------------
You can also build and install vLLM from source:
.. code-block:: console
$ git clone https://github.com/vllm-project/vllm.git
$ cd vllm
$ pip install -e . # This may take 5-10 minutes.

@@ -0,0 +1,131 @@
.. _quickstart:
Quickstart
==========
This guide shows how to use vLLM to:
* run offline batched inference on a dataset;
* build an API server for a large language model;
* start an OpenAI-compatible API server.
Be sure to complete the :ref:`installation instructions <installation>` before continuing with this guide.
Offline Batched Inference
-------------------------
We first show an example of using vLLM for offline batched inference on a dataset. In other words, we use vLLM to generate texts for a list of input prompts.
Import ``LLM`` and ``SamplingParams`` from vLLM. The ``LLM`` class is the main class for running offline inference with the vLLM engine. The ``SamplingParams`` class specifies the parameters for the sampling process.
.. code-block:: python
from vllm import LLM, SamplingParams
Define the list of input prompts and the sampling parameters for generation. The sampling temperature is set to 0.8 and the nucleus sampling probability is set to 0.95. For more information about the sampling parameters, refer to the `class definition <https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py>`_.
.. code-block:: python
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
Initialize vLLM's engine for offline inference with the ``LLM`` class and the `OPT-125M model <https://arxiv.org/abs/2205.01068>`_. The list of supported models can be found at :ref:`supported models <supported_models>`.
.. code-block:: python
llm = LLM(model="facebook/opt-125m")
Call ``llm.generate`` to generate the outputs. It adds the input prompts to the vLLM engine's waiting queue and executes the engine to generate the outputs with high throughput. The outputs are returned as a list of ``RequestOutput`` objects, which include all the output tokens.
.. code-block:: python
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
The code example can also be found in `examples/offline_inference.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference.py>`_.
API Server
----------
vLLM can be deployed as an LLM service. We provide an example `FastAPI <https://fastapi.tiangolo.com/>`_ server. Check `vllm/entrypoints/api_server.py <https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/api_server.py>`_ for the server implementation. The server uses ``AsyncLLMEngine`` class to support asynchronous processing of incoming requests.
Start the server:
.. code-block:: console
$ python -m vllm.entrypoints.api_server
By default, this command starts the server at ``http://localhost:8000`` with the OPT-125M model.
Query the model in shell:
.. code-block:: console
$ curl http://localhost:8000/generate \
$ -d '{
$ "prompt": "San Francisco is a",
$ "use_beam_search": true,
$ "n": 4,
$ "temperature": 0
$ }'
See `examples/api_client.py <https://github.com/vllm-project/vllm/blob/main/examples/api_client.py>`_ for a more detailed client example.
OpenAI-Compatible Server
------------------------
vLLM can be deployed as a server that mimics the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using the OpenAI API.
Start the server:
.. code-block:: console
$ python -m vllm.entrypoints.openai.api_server \
$ --model facebook/opt-125m
By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the above command) and implements `list models <https://platform.openai.com/docs/api-reference/models/list>`_ and `create completion <https://platform.openai.com/docs/api-reference/completions/create>`_ endpoints. We are actively adding support for more endpoints.
This server can be queried in the same format as OpenAI API. For example, list the models:
.. code-block:: console
$ curl http://localhost:8000/v1/models
Query the model with input prompts:
.. code-block:: console
$ curl http://localhost:8000/v1/completions \
$ -H "Content-Type: application/json" \
$ -d '{
$ "model": "facebook/opt-125m",
$ "prompt": "San Francisco is a",
$ "max_tokens": 7,
$ "temperature": 0
$ }'
Since this server is compatible with the OpenAI API, you can use it as a drop-in replacement for any application that uses the OpenAI API. For example, another way to query the server is via the ``openai`` python package:
.. code-block:: python
import openai
# Modify OpenAI's API key and API base to use vLLM's API server.
openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"
completion = openai.Completion.create(model="facebook/opt-125m",
prompt="San Francisco is a")
print("Completion result:", completion)
For a more detailed client example, refer to `examples/openai_client.py <https://github.com/vllm-project/vllm/blob/main/examples/openai_client.py>`_.

docs/source/index.rst (new file)

@@ -0,0 +1,72 @@
Welcome to vLLM!
================
.. figure:: ./assets/logos/vllm-logo-text-light.png
:width: 60%
:align: center
:alt: vLLM
:class: no-scaled-link
.. raw:: html
<p style="text-align:center">
<strong>Easy, fast, and cheap LLM serving for everyone
</strong>
</p>
<p style="text-align:center">
<script async defer src="https://buttons.github.io/buttons.js"></script>
<a class="github-button" href="https://github.com/vllm-project/vllm" data-show-count="true" data-size="large" aria-label="Star">Star</a>
<a class="github-button" href="https://github.com/vllm-project/vllm/subscription" data-icon="octicon-eye" data-size="large" aria-label="Watch">Watch</a>
<a class="github-button" href="https://github.com/vllm-project/vllm/fork" data-icon="octicon-repo-forked" data-size="large" aria-label="Fork">Fork</a>
</p>
vLLM is a fast and easy-to-use library for LLM inference and serving.
vLLM is fast with:
* State-of-the-art serving throughput
* Efficient management of attention key and value memory with **PagedAttention**
* Continuous batching of incoming requests
* Optimized CUDA kernels
vLLM is flexible and easy to use with:
* Seamless integration with popular HuggingFace models
* High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
* Tensor parallelism support for distributed inference
* Streaming outputs
* OpenAI-compatible API server
For more information, check out the following:
* `vLLM announcing blog post <https://vllm.ai>`_ (intro to PagedAttention)
* `How continuous batching enables 23x throughput in LLM inference while reducing p50 latency <https://www.anyscale.com/blog/continuous-batching-llm-inference>`_ by Cade Daniel et al.
Documentation
-------------
.. toctree::
:maxdepth: 1
:caption: Getting Started
getting_started/installation
getting_started/quickstart
.. toctree::
:maxdepth: 1
:caption: Serving
serving/distributed_serving
serving/run_on_sky
.. toctree::
:maxdepth: 1
:caption: Models
models/supported_models
models/adding_model

@@ -0,0 +1,94 @@
.. _adding_a_new_model:
Adding a New Model
==================
This document provides a high-level guide on integrating a `HuggingFace Transformers <https://github.com/huggingface/transformers>`_ model into vLLM.
.. note::
The complexity of adding a new model depends heavily on the model's architecture.
The process is considerably more straightforward if the model shares a similar architecture with an existing model in vLLM.
However, for models that include new operators (e.g., a new attention mechanism), the process can be a bit more complex.
.. tip::
If you are encountering issues while integrating your model into vLLM, feel free to open an issue on our `GitHub <https://github.com/vllm-project/vllm/issues>`_ repository.
We will be happy to help you out!
0. Fork the vLLM repository
--------------------------------
Start by forking our `GitHub <https://github.com/vllm-project/vllm/>`_ repository and then :ref:`build it from source <build_from_source>`.
This gives you the ability to modify the codebase and test your model.
1. Bring your model code
------------------------
Clone the PyTorch model code from the HuggingFace Transformers repository and put it into the `vllm/model_executor/models <https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models>`_ directory.
For instance, vLLM's `OPT model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/opt.py>`_ was adapted from HuggingFace's `modeling_opt.py <https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py>`_ file.
.. warning::
When copying the model code, make sure to review and adhere to the code's copyright and licensing terms.
2. Rewrite the :code:`forward` methods
--------------------------------------
Next, you need to rewrite the :code:`forward` methods of your model by following these steps:
1. Remove any unnecessary code, such as code used only for training.
2. Change the input parameters:
.. code-block:: diff
def forward(
self,
input_ids: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
-) -> Union[Tuple, CausalLMOutputWithPast]:
+ positions: torch.Tensor,
+ kv_caches: List[KVCache],
+ input_metadata: InputMetadata,
+ cache_events: Optional[List[torch.cuda.Event]],
+) -> SamplerOutput:
3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture.
.. note::
Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings.
If your model employs a different attention mechanism, you will need to implement a new attention layer in vLLM.
3. (Optional) Implement tensor parallelism support
--------------------------------------------------
If your model is too large to fit into a single GPU, you can use tensor parallelism to manage it.
To do this, substitute your model's linear and embedding layers with their tensor-parallel versions.
For the embedding layer, you can simply replace :code:`nn.Embedding` with :code:`VocabParallelEmbedding`.
When it comes to the linear layers, you should use either :code:`RowParallelLinear` or :code:`ColumnParallelLinear`.
Typically, :code:`ColumnParallelLinear` is used for QKV linear layers and the first linear layers of the MLP blocks.
For the remaining linear layers, :code:`RowParallelLinear` is used.
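For illustration, a hypothetical attention block might change as follows (the layer names are made up, and constructor arguments beyond the input and output sizes are omitted):

.. code-block:: diff

- self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size)
+ self.qkv_proj = ColumnParallelLinear(hidden_size, 3 * hidden_size)
- self.out_proj = nn.Linear(hidden_size, hidden_size)
+ self.out_proj = RowParallelLinear(hidden_size, hidden_size)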
4. Implement the weight loading logic
-------------------------------------
You now need to implement the :code:`load_weights` method in your :code:`*ForCausalLM` class.
This method should load the weights from the HuggingFace's checkpoint file and assign them to the corresponding layers in your model.
While the process is straightforward for most layers, the tensor-parallel layers necessitate additional care, as their weights must be partitioned across multiple GPUs.
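As a rough sketch of the method's shape (``hf_weights_iterator`` is an assumed helper name, and the sharding logic depends on the model's tensor-parallel layers):

.. code-block:: python

def load_weights(self, model_name_or_path):
    state_dict = dict(self.named_parameters())
    for name, loaded_weight in hf_weights_iterator(model_name_or_path):
        param = state_dict[name]
        # Tensor-parallel parameters must first be narrowed to this rank's
        # shard; replicated parameters are copied as-is.
        param.data.copy_(loaded_weight)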
5. Register your model
----------------------
Finally, include your :code:`*ForCausalLM` class in `vllm/model_executor/models/__init__.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/__init__.py>`_ and register it to the :code:`_MODEL_REGISTRY` in `vllm/model_executor/model_loader.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/model_loader.py>`_.

@@ -0,0 +1,72 @@
.. _supported_models:
Supported Models
================
vLLM supports a variety of generative Transformer models in `HuggingFace Transformers <https://huggingface.co/models>`_.
The following is the list of model architectures that are currently supported by vLLM.
Alongside each architecture, we include some popular models that use it.
.. list-table::
:widths: 25 25 50
:header-rows: 1
* - Architecture
- Models
- Example HuggingFace Models
* - :code:`AquilaForCausalLM`
- Aquila
- :code:`BAAI/Aquila-7B`, :code:`BAAI/AquilaChat-7B`, etc.
* - :code:`BaiChuanForCausalLM`
- Baichuan
- :code:`baichuan-inc/Baichuan-7B`, :code:`baichuan-inc/Baichuan-13B-Chat`, etc.
* - :code:`BloomForCausalLM`
- BLOOM, BLOOMZ, BLOOMChat
- :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
* - :code:`FalconForCausalLM`
- Falcon
- :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
* - :code:`GPT2LMHeadModel`
- GPT-2
- :code:`gpt2`, :code:`gpt2-xl`, etc.
* - :code:`GPTBigCodeForCausalLM`
- StarCoder, SantaCoder, WizardCoder
- :code:`bigcode/starcoder`, :code:`bigcode/gpt_bigcode-santacoder`, :code:`WizardLM/WizardCoder-15B-V1.0`, etc.
* - :code:`GPTJForCausalLM`
- GPT-J
- :code:`EleutherAI/gpt-j-6b`, :code:`nomic-ai/gpt4all-j`, etc.
* - :code:`GPTNeoXForCausalLM`
- GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM
- :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc.
* - :code:`InternLMForCausalLM`
- InternLM
- :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc.
* - :code:`LlamaForCausalLM`
- LLaMA, LLaMA-2, Vicuna, Alpaca, Koala, Guanaco
- :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, etc.
* - :code:`MPTForCausalLM`
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
- :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
* - :code:`OPTForCausalLM`
- OPT, OPT-IML
- :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
* - :code:`QWenLMHeadModel`
- Qwen
- :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model.
Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-project/vllm/issues>`_ project.
.. tip::
The easiest way to check if your model is supported is to run the program below:
.. code-block:: python
from vllm import LLM
llm = LLM(model=...) # Name or path of your model
output = llm.generate("Hello, my name is")
print(output)
If vLLM successfully generates text, it indicates that your model is supported.

@@ -0,0 +1,38 @@
.. _distributed_serving:
Distributed Inference and Serving
=================================
vLLM supports distributed tensor-parallel inference and serving. Currently, we support `Megatron-LM's tensor parallel algorithm <https://arxiv.org/pdf/1909.08053.pdf>`_. We manage the distributed runtime with `Ray <https://github.com/ray-project/ray>`_. To run distributed inference, install Ray with:
.. code-block:: console
$ pip install ray
To run multi-GPU inference with the :code:`LLM` class, set the :code:`tensor_parallel_size` argument to the number of GPUs you want to use. For example, to run inference on 4 GPUs:
.. code-block:: python
from vllm import LLM
llm = LLM("facebook/opt-13b", tensor_parallel_size=4)
output = llm.generate("San Francisco is a")
To run multi-GPU serving, pass in the :code:`--tensor-parallel-size` argument when starting the server. For example, to run the API server on 4 GPUs:
.. code-block:: console
$ python -m vllm.entrypoints.api_server \
$ --model facebook/opt-13b \
$ --tensor-parallel-size 4
To scale vLLM beyond a single machine, start a `Ray runtime <https://docs.ray.io/en/latest/ray-core/starting-ray.html>`_ via CLI before running vLLM:
.. code-block:: console
$ # On head node
$ ray start --head
$ # On worker nodes
$ ray start --address=<ray-head-address>
After that, you can run inference and serving across multiple machines by launching the vLLM process on the head node, setting :code:`tensor_parallel_size` to the total number of GPUs across all machines.
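For example, assuming a Ray cluster of two nodes with 8 GPUs each (16 GPUs in total):

.. code-block:: python

from vllm import LLM

# tensor_parallel_size spans the whole Ray cluster: 2 nodes x 8 GPUs.
llm = LLM("facebook/opt-13b", tensor_parallel_size=16)
output = llm.generate("San Francisco is a")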

@@ -0,0 +1,69 @@
.. _on_cloud:
Running on clouds with SkyPilot
===============================
.. raw:: html
<p align="center">
<img src="https://imgur.com/yxtzPEu.png" alt="vLLM"/>
</p>
vLLM can be run on the cloud to scale to multiple GPUs with `SkyPilot <https://github.com/skypilot-org/skypilot>`__, an open-source framework for running LLMs on any cloud.
To install SkyPilot and set up your cloud credentials, run:
.. code-block:: console
$ pip install skypilot
$ sky check
See the vLLM SkyPilot YAML for serving, `serving.yaml <https://github.com/skypilot-org/skypilot/blob/master/llm/vllm/serve.yaml>`__.
.. code-block:: yaml
resources:
accelerators: A100
envs:
MODEL_NAME: decapoda-research/llama-13b-hf
TOKENIZER: hf-internal-testing/llama-tokenizer
setup: |
conda create -n vllm python=3.9 -y
conda activate vllm
git clone https://github.com/vllm-project/vllm.git
cd vllm
pip install .
pip install gradio
run: |
conda activate vllm
echo 'Starting vllm api server...'
python -u -m vllm.entrypoints.api_server \
--model $MODEL_NAME \
--tensor-parallel-size $SKYPILOT_NUM_GPUS_PER_NODE \
--tokenizer $TOKENIZER 2>&1 | tee api_server.log &
echo 'Waiting for vllm api server to start...'
while ! `cat api_server.log | grep -q 'Uvicorn running on'`; do sleep 1; done
echo 'Starting gradio server...'
python vllm/examples/gradio_webserver.py
Start serving the LLaMA-13B model on an A100 GPU:
.. code-block:: console
$ sky launch serving.yaml
Check the output of the command. It will include a shareable gradio link (like the last line of the following). Open it in your browser to use the LLaMA model for text completion.
.. code-block:: console
(task, pid=7431) Running on public URL: https://<gradio-hash>.gradio.live
**Optional**: Serve the 65B model instead of the default 13B and use more GPUs:
.. code-block:: console
sky launch -c vllm-serve-new -s serving.yaml --gpus A100:8 --env MODEL_NAME=decapoda-research/llama-65b-hf

examples/api_client.py (new file)

@@ -0,0 +1,77 @@
"""Example Python client for vllm.entrypoints.api_server"""
import argparse
import json
from typing import Iterable, List
import requests
def clear_line(n: int = 1) -> None:
LINE_UP = '\033[1A'
LINE_CLEAR = '\x1b[2K'
for _ in range(n):
print(LINE_UP, end=LINE_CLEAR, flush=True)
def post_http_request(prompt: str,
api_url: str,
n: int = 1,
stream: bool = False) -> requests.Response:
headers = {"User-Agent": "Test Client"}
pload = {
"prompt": prompt,
"n": n,
"use_beam_search": True,
"temperature": 0.0,
"max_tokens": 16,
"stream": stream,
}
response = requests.post(api_url, headers=headers, json=pload, stream=True)
return response
def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"]
yield output
def get_response(response: requests.Response) -> List[str]:
data = json.loads(response.content)
output = data["text"]
return output
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--n", type=int, default=4)
parser.add_argument("--prompt", type=str, default="San Francisco is a")
parser.add_argument("--stream", action="store_true")
args = parser.parse_args()
prompt = args.prompt
api_url = f"http://{args.host}:{args.port}/generate"
n = args.n
stream = args.stream
print(f"Prompt: {prompt!r}\n", flush=True)
response = post_http_request(prompt, api_url, n, stream)
if stream:
num_printed_lines = 0
for h in get_streaming_response(response):
clear_line(num_printed_lines)
num_printed_lines = 0
for i, line in enumerate(h):
num_printed_lines += 1
print(f"Beam candidate {i}: {line!r}", flush=True)
else:
output = get_response(response)
for i, line in enumerate(output):
print(f"Beam candidate {i}: {line!r}", flush=True)

@@ -0,0 +1,52 @@
import argparse
import json
import gradio as gr
import requests
def http_bot(prompt):
headers = {"User-Agent": "vLLM Client"}
pload = {
"prompt": prompt,
"stream": True,
"max_tokens": 128,
}
response = requests.post(args.model_url,
headers=headers,
json=pload,
stream=True)
for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"][0]
yield output
def build_demo():
with gr.Blocks() as demo:
gr.Markdown("# vLLM text completion demo\n")
inputbox = gr.Textbox(label="Input",
placeholder="Enter text and press ENTER")
outputbox = gr.Textbox(label="Output",
placeholder="Generated result from the model")
inputbox.submit(http_bot, [inputbox], [outputbox])
return demo
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8001)
parser.add_argument("--model-url",
type=str,
default="http://localhost:8000/generate")
args = parser.parse_args()
demo = build_demo()
demo.queue(concurrency_count=100).launch(server_name=args.host,
server_port=args.port,
share=True)

@@ -0,0 +1,51 @@
import argparse
from vllm import EngineArgs, LLMEngine, SamplingParams
def main(args: argparse.Namespace):
# Parse the CLI argument and initialize the engine.
engine_args = EngineArgs.from_cli_args(args)
engine = LLMEngine.from_engine_args(engine_args)
# Test the following prompts.
test_prompts = [
("A robot may not injure a human being",
SamplingParams(temperature=0.0)),
("To be or not to be,",
SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
("What is the meaning of life?",
SamplingParams(n=2,
best_of=5,
temperature=0.8,
top_p=0.95,
frequency_penalty=0.1)),
("It is only with the heart that one can see rightly",
SamplingParams(n=3, best_of=3, use_beam_search=True,
temperature=0.0)),
]
# Run the engine by calling `engine.step()` manually.
request_id = 0
while True:
# To test continuous batching, we add one request at each step.
if test_prompts:
prompt, sampling_params = test_prompts.pop(0)
engine.add_request(str(request_id), prompt, sampling_params)
request_id += 1
request_outputs = engine.step()
for request_output in request_outputs:
if request_output.finished:
print(request_output)
if not (engine.has_unfinished_requests() or test_prompts):
break
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Demo on using the LLMEngine class directly')
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
main(args)

@@ -0,0 +1,22 @@
from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM.
llm = LLM(model="facebook/opt-125m")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

@@ -0,0 +1,33 @@
import openai
# Modify OpenAI's API key and API base to use vLLM's API server.
openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"
# List models API
models = openai.Model.list()
print("Models:", models)
model = models["data"][0]["id"]
# Chat completion API
chat_completion = openai.ChatCompletion.create(
model=model,
messages=[{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}])
print("Chat completion results:")
print(chat_completion)

@@ -0,0 +1,28 @@
import openai
# Modify OpenAI's API key and API base to use vLLM's API server.
openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"
# List models API
models = openai.Model.list()
print("Models:", models)
model = models["data"][0]["id"]
# Completion API
stream = False
completion = openai.Completion.create(
model=model,
prompt="A robot may not injure a human being",
echo=False,
n=2,
stream=stream,
logprobs=3)
print("Completion results:")
if stream:
for c in completion:
print(c)
else:
print(completion)

format.sh (new executable file)

@@ -0,0 +1,108 @@
#!/usr/bin/env bash
# YAPF formatter, adapted from ray and skypilot.
#
# Usage:
# # Do work and commit your work.
# # Format files that differ from origin/main.
# bash format.sh
# # Commit changed files with message 'Run yapf and pylint'
#
#
# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
# You are encouraged to run this locally before pushing changes for review.
# Cause the script to exit if a single command fails
set -eo pipefail
# this stops git rev-parse from failing if we run this from the .git directory
builtin cd "$(dirname "${BASH_SOURCE:-$0}")"
ROOT="$(git rev-parse --show-toplevel)"
builtin cd "$ROOT" || exit 1
YAPF_VERSION=$(yapf --version | awk '{print $2}')
PYLINT_VERSION=$(pylint --version | head -n 1 | awk '{print $2}')
MYPY_VERSION=$(mypy --version | awk '{print $2}')
# # params: tool name, tool version, required version
tool_version_check() {
if [[ $2 != $3 ]]; then
echo "Wrong $1 version installed: $3 is required, not $2."
exit 1
fi
}
tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "pylint" $PYLINT_VERSION "$(grep "pylint==" requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)"
YAPF_FLAGS=(
'--recursive'
'--parallel'
)
YAPF_EXCLUDES=(
'--exclude' 'build/**'
'--exclude' 'vllm/model_executor/parallel_utils/**'
)
# Format specified files
format() {
yapf --in-place "${YAPF_FLAGS[@]}" "$@"
}
# Format files that differ from main branch. Ignores dirs that are not slated
# for autoformat yet.
format_changed() {
# The `if` guard ensures that the list of filenames is not empty, which
# could cause yapf to receive 0 positional arguments, making it hang
# waiting for STDIN.
#
# `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
# exist on both branches.
MERGEBASE="$(git merge-base origin/main HEAD)"
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \
yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}"
fi
}
# Format all files
format_all() {
yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm
}
## This flag formats individual files. --files *must* be the first command line
## arg to use this option.
if [[ "$1" == '--files' ]]; then
format "${@:2}"
# If `--all` is passed, then any further arguments are ignored and the
# entire python directory is formatted.
elif [[ "$1" == '--all' ]]; then
format_all
else
# Format only the files that changed in last commit.
format_changed
fi
echo 'vLLM yapf: Done'
# Run mypy
# TODO(zhuohan): Enable mypy
# echo 'vLLM mypy:'
# mypy
# Run Pylint
echo 'vLLM Pylint:'
pylint vllm
if ! git diff --quiet &>/dev/null; then
echo 'Reformatted files. Please review and stage the changes.'
echo 'Changes not staged for commit:'
echo
git --no-pager diff --name-only
exit 1
fi

mypy.ini (new file)

@@ -0,0 +1,8 @@
[mypy]
python_version = 3.8
ignore_missing_imports = True
files = vllm
# TODO(woosuk): Include the code from Megatron and HuggingFace.
exclude = vllm/model_executor/parallel_utils/|vllm/model_executor/models/

@@ -1,20 +0,0 @@
import requests
import json
def http_bot():
prompt = "How are you? I'm fine."
headers = {"User-Agent": "Test Client"}
pload = {
"prompt": prompt,
}
response = requests.post("http://localhost:10002", headers=headers, json=pload, stream=True)
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"]
yield output
for h in http_bot():
print(h, end="", flush=True)

@@ -1,40 +0,0 @@
import argparse
import asyncio
import time
from typing import Union
import json
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import uvicorn
app = FastAPI()
async def text_streamer(args):
context = args["prompt"]
words = context.split(" ")
for word in words:
await asyncio.sleep(1)
print("word:", word)
ret = {
"text": word + " ",
"error": 0,
}
yield (json.dumps(ret) + "\0").encode("utf-8")
@app.post("/")
async def read_root(request: Request):
args = await request.json()
return StreamingResponse(text_streamer(args))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=10002)
args = parser.parse_args()
uvicorn.run(app, host=args.host, port=args.port, log_level="info")

(Some files were not shown because too many files have changed in this diff.)