Compare commits

...

242 Commits

Author SHA1 Message Date
ff075ecb92 Merge branch 'final-version' of github.com:huggingface/new-model-addition-meta into sparse-llama4-moe-2 2025-04-07 19:31:46 +00:00
afbd7ab321 Repo consistency 2025-04-05 21:10:57 +02:00
50a8daab79 git push Merge branch 'final-version' of github.com:huggingface/new-model-addition-meta into final-version 2025-04-05 19:03:01 +00:00
538ba2b07b fix styling 2025-04-05 19:02:51 +00:00
7c471ea72a fix 2025-04-05 21:02:44 +02:00
99b6bc8f40 ruff fix fast image processor 2025-04-05 19:02:28 +00:00
44a90c0f66 cleanup removal of slow image processor 2025-04-05 19:01:53 +00:00
8c5093488b Code quality 2025-04-05 21:01:00 +02:00
cbb6e5990a Code quality 2025-04-05 20:59:52 +02:00
b8786474bb Merge branch 'final-version' of github.com:huggingface/new-model-addition-meta into final-version 2025-04-05 18:58:22 +00:00
949b1b7e98 Nuke bunch of failing stuff 2025-04-05 18:57:42 +00:00
71521afb43 Code quality 2025-04-05 20:55:51 +02:00
8167ac4c57 Code quality 2025-04-05 20:54:49 +02:00
0c8624b28e fix more tests for now 2025-04-05 18:54:26 +00:00
ecaa1a7beb Code quality 2025-04-05 20:53:22 +02:00
6ba8ef7fca Code quality 2025-04-05 20:49:36 +02:00
555c4eeb07 Merge branch 'final-version' of github.com:huggingface/new-model-addition-meta into final-version 2025-04-05 18:44:41 +00:00
5ce5746bf3 fix dynamic cache 2025-04-05 18:44:19 +00:00
5b96e5d27b Merge pull request #58 from huggingface/only-fast-image-processor
Only fast image processor is supported
2025-04-05 20:40:09 +02:00
4994729f62 Merge branch 'final-version' of github.com:huggingface/new-model-addition-meta into final-version 2025-04-05 18:39:24 +00:00
6c6e9014b3 fix issue with flex encoder 2025-04-05 18:38:09 +00:00
ab8bbadc61 trigger CI 2025-04-05 20:22:45 +02:00
d73aea8c60 nit 2025-04-05 20:17:34 +02:00
bac11b511f Only fast image processor is supported 2025-04-05 20:16:51 +02:00
34f6e9ef8c remove warning 2025-04-05 17:49:11 +00:00
b97451ea5d Merge branch 'main' of github.com:huggingface/new-model-addition-meta into final-version 2025-04-05 15:32:36 +00:00
7f292e1f17 remove comment 2025-04-05 15:29:32 +00:00
ed669a34b6 fix and clean docstrings 2025-04-05 15:21:35 +00:00
931dad92f8 Merge branch 'final-version' of github.com:huggingface/new-model-addition-meta into final-version 2025-04-05 15:17:54 +00:00
1895d02cbe fix dummies 2025-04-05 15:15:32 +00:00
7d5d5f0d4e Merge pull request #55 from huggingface/tgi_cuda_graph_fix
Cuda graph fix
2025-04-05 17:08:15 +02:00
f87c237832 Merge pull request #51 from huggingface/code-quality
Some cleanup
2025-04-05 17:08:02 +02:00
c53e2595a7 more styling changes 2025-04-05 15:06:43 +00:00
26b56748d6 commit licence, cleanup here and there and style 2025-04-05 15:05:22 +00:00
54785ef229 Merge branch 'code-quality' of github.com:huggingface/new-model-addition-meta into code-quality 2025-04-05 14:45:22 +00:00
695c1e7f6f fixup 2025-04-05 14:44:47 +00:00
688dc5cf3b Merge branch 'final-version' into code-quality 2025-04-05 16:43:11 +02:00
3eab44367c Update src/transformers/models/llama4/modeling_llama4.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
2025-04-05 20:13:05 +05:30
29028393c6 Merge pull request #54 from huggingface/fix-tp-pipeline
Fix tp pipeline
2025-04-05 16:39:54 +02:00
fb495fd935 Merge pull request #44 from huggingface/fix_style
Fix tests
2025-04-05 16:36:05 +02:00
83282a199c styling 2025-04-05 14:34:32 +00:00
9f03f059f3 fixup 2025-04-05 14:34:21 +00:00
ad839d3cb7 revert some stuff 2025-04-05 14:32:23 +00:00
eb9e4afb2b Merge branch 'final-version' of github.com:huggingface/new-model-addition-meta into fix-tp-pipeline 2025-04-05 14:31:42 +00:00
8578252f72 Merge pull request #53 from huggingface/l4-docs
Docs!
2025-04-05 16:28:46 +02:00
aeec2dce4b Merge pull request #49 from huggingface/fix-quantization
Fix quantization
2025-04-05 16:25:17 +02:00
b239675add Merge pull request #38 from huggingface/smol-fix
Fix subconfig
2025-04-05 16:22:22 +02:00
27364dafc3 cuda graph fix 2025-04-05 14:17:42 +00:00
f7756b4e07 Fixes 2025-04-05 16:04:21 +02:00
0849d322f6 Update docs/source/en/model_doc/llama4.md
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
2025-04-05 16:03:48 +02:00
46b081560b Fixes 2025-04-05 16:03:30 +02:00
48b4f5638c Docs! 2025-04-05 15:26:39 +02:00
8dbf7cb9c5 rm print 2025-04-05 12:42:28 +00:00
3d58f8e120 restrict to compressed tensors for now 2025-04-05 12:38:51 +00:00
f642d32cf5 Don't move to cuda:0 in distributed mode 2025-04-05 12:33:48 +00:00
4c4bc81cb3 clean up 2025-04-05 12:24:03 +00:00
a471b10482 update 2025-04-05 11:41:04 +00:00
61f45af60e update 2025-04-05 11:40:06 +00:00
2374ff7102 modulelist 2025-04-05 11:40:03 +00:00
ce91d95e3e patch 2025-04-05 11:39:01 +00:00
9b2e35df68 replace correctly module 2025-04-05 11:20:52 +00:00
66c36a47b5 fix 2025-04-05 10:13:20 +00:00
141da6575e update 2025-04-05 09:20:27 +00:00
c38bf3a87b fix 2025-04-05 10:43:35 +02:00
ef8dbe2bf1 Some cleanup 2025-04-05 10:39:07 +02:00
fd150bb783 Merge pull request #50 from huggingface/fix-eager
eager needs 4D mask
2025-04-05 10:34:29 +02:00
6ab068259f fix 2025-04-05 08:27:36 +00:00
a3e8267d62 eager needs 4D mask 2025-04-05 08:17:23 +00:00
45cf5828ef fix text geneartion pipeline
Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
2025-04-05 00:24:49 -07:00
93022de716 fix fp8 loading
Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
2025-04-04 18:17:27 -07:00
0130b2df91 skipping some tests 2025-04-04 23:22:52 +00:00
e547b10bf8 fix 2025-04-04 21:46:22 +00:00
7bda11f26f update tests 2025-04-04 21:42:16 +00:00
6a8b9f6294 Merge pull request #46 from huggingface/keep-nrope-layers-fix
Use correct no_rope_layers if provided one is empty
2025-04-04 14:42:01 -07:00
7c03c7e020 Use correct no_rope_layers if provided one is empty list 2025-04-04 13:23:11 -07:00
f5dd6fb7c5 fix 2025-04-04 18:54:26 +00:00
a43e0561a4 Fix MoE vs FF (#41) 2025-04-04 20:08:42 +02:00
516e08fbfe Add support for sparse Llama4TextMoe layer from the kernel hub 2025-04-04 17:57:53 +02:00
535030a0b8 Fix typo and remove warning with compiled flex and chunked prefill 2025-04-04 17:57:34 +02:00
6ca6f66c92 Revert "Merge pull request #36 from huggingface/sparse-llama4-moe"
This reverts commit ccda19f050867dd42ea143c5de60f3dec81375f0, reversing
changes made to a515579aed8c0fe9bf529b6c40446a289406d5d6.
2025-04-04 17:54:25 +02:00
fcee23da4b fix config 2025-04-04 15:54:16 +00:00
ec7656a487 Merge pull request #35 from huggingface/fix-context-length
Fix context length
2025-04-04 17:40:25 +02:00
598dded8cc Merge branch 'final-version' into fix-context-length 2025-04-04 17:40:06 +02:00
ccda19f050 Merge pull request #36 from huggingface/sparse-llama4-moe
Add support for sparse `Llama4TextMoe` layer from the kernel hub
2025-04-04 17:39:19 +02:00
a9045fc914 Merge branch 'final-version' of github.com:huggingface/new-model-addition-meta into sparse-llama4-moe 2025-04-04 15:29:12 +00:00
a515579aed Merge pull request #14 from huggingface/norope
Add support for no rope
2025-04-04 17:28:06 +02:00
04b302a73a fixup 2025-04-04 15:26:28 +00:00
7414235469 revert 2025-04-04 15:25:19 +00:00
cb58ceac00 eager should use chunked_attention_mask 2025-04-04 15:24:46 +00:00
51f7cd24d1 fix 2025-04-04 15:22:21 +00:00
206c8aea2f remove dict 2025-04-04 15:19:10 +00:00
f781885878 style 2025-04-04 15:16:28 +00:00
3612b9cc6e Read initializer_range 2025-04-04 15:15:10 +00:00
f03660ad67 Sync eos terminators 2025-04-04 15:14:49 +00:00
fd0f2733db Save attention_chunk_size 2025-04-04 15:14:32 +00:00
1a76267569 Update modeling_llama4.py 2025-04-04 17:14:26 +02:00
bc44b2bee2 Write max_position_embeddings and max_model_length 2025-04-04 15:13:43 +00:00
5da08327e3 revert print 2025-04-04 15:11:08 +00:00
bfc8049c13 nits 2025-04-04 15:10:43 +00:00
85b3c7acc6 still broken fixing now 2025-04-04 14:57:06 +00:00
64c2133da7 update 2025-04-04 14:21:07 +00:00
d7d09a1788 Merge branch 'norope' of github.com:huggingface/new-model-addition-meta into sparse-llama4-moe 2025-04-04 14:14:04 +00:00
373a472e93 better merge 2025-04-04 14:10:56 +00:00
61626d0d8e cleanup 2025-04-04 14:04:33 +00:00
dcb29eb805 Add support for sparse Llama4TextMoe layer from the kernel hub 2025-04-04 14:01:07 +00:00
05cc59e1bf fix SDPA! 2025-04-04 13:53:50 +00:00
aa8daba2be Merge pull request #21 from huggingface/add_fbgemm
Adding fbgemm
2025-04-04 14:45:41 +02:00
bdfb5731ed update tp_plan 2025-04-04 12:08:34 +00:00
174eda3c4b allocate 2025-04-04 12:03:05 +00:00
9f9974b149 Merge branch 'final-version' into add_fbgemm 2025-04-04 13:26:03 +02:00
efb45772c9 fix 2025-04-04 11:15:06 +00:00
c7d4c883ca Fix context length 2025-04-04 11:04:44 +00:00
7990c78f5d commit cache utils cleanup 2025-04-04 10:03:04 +00:00
99f2297e03 Merge pull request #30 from huggingface/fix-causal-lm-loading
Fix causal lm loading
2025-04-04 11:46:45 +02:00
30cacf7038 rm processing 2025-04-04 09:42:38 +00:00
7f8941d298 fix shapes in general 2025-04-04 10:03:02 +02:00
eb167f2864 fix auto factory 2025-04-04 08:00:16 +00:00
e19af4b3ad try to revert the potentially breaking change 2025-04-04 07:48:25 +00:00
60a58cb749 Merge branch 'final-version' into norope 2025-04-04 09:08:58 +02:00
eb535ee04e chunking 2025-04-04 07:07:35 +00:00
96066e09af style 2025-04-04 07:06:54 +00:00
aeaad13a33 Fix flex impl 2025-04-04 06:05:15 +00:00
f2bbb4ba50 add compressed_tensos & fix fbgemm tp 2025-04-03 21:01:50 +00:00
24dbcad649 should work with both short and long 2025-04-03 17:14:41 +00:00
a820dbe5e8 Merge branch 'norope' of github.com:huggingface/new-model-addition-meta into norope 2025-04-03 16:19:14 +00:00
7a001691a0 push current version 2025-04-03 16:18:50 +00:00
4eabf8f28d Merge pull request #32 from huggingface/remove-warning
Remove tied weights warning
2025-04-03 17:06:27 +02:00
ba2e4641ec Merge branch 'norope' of github.com:huggingface/new-model-addition-meta into norope 2025-04-03 14:58:27 +00:00
6decf84454 fix sdpa 2025-04-03 14:57:14 +00:00
ff1df035fd fix tied-weights 2025-04-03 14:42:05 +00:00
6d564d0341 Merge pull request #20 from huggingface/conversion-fixes
Conversion fixes
2025-04-03 14:41:19 +02:00
ed6cba8756 rm 2025-04-03 11:42:03 +00:00
06413dcd3f fix causallml loading 2025-04-03 11:36:49 +00:00
cf83f0b740 Fix pad_token_id
See
https://huggingface.co/ll-re/Llama-4-Scout-17B-16E/discussions/2/files
Confirmed in the original codebase.
2025-04-03 09:56:02 +00:00
29f55d2bdc oups 2025-04-03 09:46:41 +00:00
c06da80c5b push what works, dirty trick for the device synch 2025-04-03 09:43:46 +00:00
5b1721bbe5 add missing attn scales 2025-04-03 05:35:33 +00:00
9e2e0f958f add floor scale 2025-04-03 05:21:01 +00:00
5e87ba9cde add support for attn_temperature_tuning 2025-04-03 05:17:22 +00:00
90e8e2c81d current updates 2025-04-03 05:13:41 +00:00
0a10252492 Merge pull request #28 from huggingface/cleanup-mllama4
Clean up Llama4 vision model
2025-04-03 06:45:27 +02:00
99ec54bf8c Clean up Llama4 vision model 2025-04-02 16:48:22 -07:00
0f5b27ba5f Merge pull request #26 from huggingface/meta/fix-nope
[Fix Nope] Expose no_rope_layer_interval from Llama4TextConfig for vLLM
2025-04-03 01:21:38 +02:00
b25084be23 minor fix
Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
2025-04-02 15:36:03 -07:00
b98cde8397 fix nope
Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
2025-04-02 15:30:53 -07:00
ddf899369a add tp_plan 2025-04-02 19:29:39 +00:00
12451706ac Ignore new key present in base models 2025-04-02 17:07:37 +00:00
ef479fa1eb Merge pull request #23 from huggingface/minor_tgi_fix
Minor fix
2025-04-02 18:01:26 +02:00
196d87ed73 add worldsize and make eager attn work for vision 2025-04-02 15:56:37 +00:00
2f8d05bdb7 nits 2025-04-02 15:38:02 +00:00
e472a4eeaf fix 2025-04-02 15:37:39 +00:00
8b0a8c9f4b update 2025-04-02 15:19:14 +00:00
c29469ce4f update 2025-04-02 14:37:25 +00:00
2133277bd5 add context parallel 2025-04-02 14:24:30 +00:00
f418d0624a missking keys 2025-04-02 13:11:58 +00:00
ab268fb832 updates 2025-04-02 12:42:03 +00:00
85cf8b924d nits 2025-04-02 12:31:54 +00:00
7a2afb3db1 updates 2025-04-02 12:05:39 +00:00
da1e6910bf fixes for now flex seems to work :) 2025-04-02 11:55:43 +00:00
ce5d1ea052 Add boi/eoi tokens
We don't use them.
2025-04-02 11:51:06 +00:00
afcc7ec352 style 2025-04-02 11:28:20 +00:00
ef31789fd0 Save processor, fix chat template 2025-04-02 11:13:01 +00:00
fe240a6a5a adding fp8 2025-04-02 11:08:07 +00:00
37391a3ebb lol don't know how this got here 2025-04-02 10:32:05 +00:00
a417896dc8 Add <|eom|> 2025-04-02 10:13:11 +00:00
822f29610d Read new params from config 2025-04-02 10:10:18 +00:00
b61c859c69 small fix for TP on 1 node 2025-04-02 09:51:51 +00:00
eb677fa9a3 Merge pull request #19 from huggingface/reverting_quantization_fix
Revert changes for_llm_compressor
2025-04-02 11:01:52 +02:00
9679739db3 nits 2025-04-02 08:29:08 +00:00
fa75c34943 revert 2025-04-02 07:58:45 +00:00
cd4a2dae19 Merge pull request #18 from huggingface/meta/fix-model-loading
[WIP] Fix Llama4 128E fp8 loading issue
2025-04-02 07:26:10 +02:00
233c7df8bf fix model loading
Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
2025-04-01 19:41:16 -07:00
6da9409fb2 comment todo small 2025-04-01 19:30:13 +00:00
4047e8651a Merge pull request #15 from huggingface/fix_quantization
Replace Experts layer in quantizer
2025-04-01 21:12:29 +02:00
21eb873c86 Merge pull request #17 from huggingface/tests
[llama4/mm] Fix image processor unit tests
2025-04-01 12:08:38 -07:00
0a9da1b588 sdpa works 2025-04-01 19:04:30 +00:00
2ad69a48c4 add explicit dtype
Signed-off-by: Jon Swenson <jmswen@gmail.com>
2025-04-01 12:03:51 -07:00
f4f9fbce14 [llama4/mm] Fix Llama 4 image processing unit tests 2025-04-01 11:44:25 -07:00
6f63da6205 Merge pull request #16 from huggingface/global-tile
[llama4/mm] Add back <|image|> tag in tokenization corresponding to global tile
2025-04-01 10:51:34 -07:00
5be1b28a76 [llama4/mm] Add back <|image|> token that delimits global tile 2025-04-01 10:33:37 -07:00
725171611a nit 2025-04-01 16:35:38 +00:00
558c096db5 rebase and delete llm_compressor 2025-04-01 16:05:59 +00:00
6529cade32 small updates 2025-04-01 16:05:01 +00:00
c338736bde add layer 2025-04-01 16:00:58 +00:00
31d88f178d fix 2025-04-01 16:00:58 +00:00
d728d06f86 fixes 2025-04-01 15:10:17 +00:00
1f4072b33d Merge branch 'final-version' into norope 2025-04-01 14:50:10 +00:00
0c3dc0c00f support flex attention 2025-04-01 14:43:19 +00:00
c358a1b4de fix post merge with main 2025-04-01 14:25:01 +00:00
ec85fa388b Merge branch 'final-version' of github.com:huggingface/new-model-addition-meta into final-version 2025-04-01 13:50:56 +00:00
0c3f25a52d Merge branch 'main' of github.com:huggingface/new-model-addition-meta into final-version 2025-04-01 13:49:04 +00:00
71385f16a4 nit 2025-04-01 13:46:48 +00:00
87abef5a38 Merge pull request #13 from huggingface/meta_vllm
Meta vllm
2025-04-01 14:43:33 +02:00
5b8dd838ce Update src/transformers/models/llama4/image_processing_llama4_fast.py 2025-04-01 14:41:04 +02:00
e53363d1b5 Add changes for no_rope, moe_layers, chunked attention. Just need to test all 2025-04-01 12:39:44 +00:00
90d587624a Merge branch 'final-version' of github.com:huggingface/new-model-addition-meta into final-version 2025-04-01 08:35:28 +00:00
55a17c58c7 nit 2025-04-01 08:34:53 +00:00
c487c62d54 Merge pull request #4 from huggingface/moe-128
128 experts
2025-04-01 10:32:48 +02:00
b077bb5b19 Update src/transformers/models/llama4/configuration_llama4.py 2025-04-01 10:30:41 +02:00
4af4c778d1 Fix parameter_count name 2025-04-01 06:59:40 +00:00
4a1fec8d3e Merge remote-tracking branch 'origin/final-version' into moe-128 2025-04-01 06:33:12 +00:00
3bf26c269e Merge pull request #11 from huggingface/remove-aspect-ratios
Remove `aspect_ratios` from `Llama4Processor` output
2025-03-31 20:45:39 -07:00
ca64ae5095 [llama4] Pop aspect_ratios from image processor output in Llama4Processor
Signed-off-by: Jon Swenson <jmswen@gmail.com>
2025-03-31 19:27:35 -07:00
7352034288 remove un-used imports 2025-03-31 18:24:23 -07:00
24d4599a06 un-comment write_tokenizer from converting script 2025-03-31 16:18:37 -07:00
189a1032ca Moe 128 rebased (#8)
* 128 experts

* Use default rope

* Unfuse mlp

* Address feedback

* Use None "default" for rope_scaling. Add eot.

* Meta/llama quant compat (#7)

* add quant compatible model & conversion code for llama4

* fix a few issues

* fix a few issues

* minor type mapping fix

---------

Co-authored-by: Lu Fang <fanglu@fb.com>

* use a new config parameter to determine which model definition to use for MoE

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Lu Fang <fanglu@fb.com>
2025-03-31 16:14:19 -07:00
fb748af776 return aspect ratios and bug fixes 2025-03-31 16:12:01 -07:00
b5373e20b2 Merge branch 'main' of github.com:huggingface/new-model-addition-meta into final-version 2025-03-31 15:25:41 +00:00
ed00fb308b set seed 2025-03-31 15:22:11 +00:00
b38318d1df Use None "default" for rope_scaling. Add eot. 2025-03-30 14:23:27 +00:00
54be1a01c1 Address feedback 2025-03-28 17:04:31 +00:00
ca0cd0ea11 Merge branch 'final-version' into moe-128 2025-03-28 09:31:11 +00:00
82004d95f9 Merge pull request #5 from huggingface/fixes_cleanups
Supports multi-image prompting and batching.
2025-03-26 16:55:45 +01:00
ddf7adc269 fix from review 2025-03-26 16:54:48 +01:00
03e993957e remove .item() 👀 2025-03-26 14:00:32 +00:00
9c0ef18c50 Merge branch 'fixes_cleanups' of github.com:huggingface/new-model-addition-meta into fixes_cleanups 2025-03-26 13:57:28 +00:00
52787d5ce3 simplify a lot inputs embeds merging 2025-03-26 13:55:20 +00:00
b06a26b7df Unfuse mlp 2025-03-26 13:27:07 +00:00
347a7620cb Merge branch 'final-version' into fixes_cleanups 2025-03-26 13:29:17 +01:00
1be3ddc3da Use default rope 2025-03-26 08:45:07 +00:00
972c465e06 128 experts 2025-03-25 22:20:17 +00:00
671c37bdc3 fixup size 2025-03-25 20:45:48 +00:00
5bebf97869 multi-image fixes in modeling + processor 2025-03-25 20:30:31 +00:00
d9e3f86a5c [convert] Use num_experts 2025-03-25 16:02:27 +00:00
507857d76c [convert] Minor fixes 2025-03-25 12:38:06 +00:00
aa595de667 [convert] Strip extraneous bytes from shards 2025-03-25 12:37:44 +00:00
5e9d84f376 [convert] Fix typo 2025-03-25 10:34:00 +00:00
6c04e10cb6 update fake image token 2025-03-21 09:50:07 +00:00
db2821e6b2 vllm updates 2025-03-19 13:14:22 +00:00
ba7a8aad28 some cleanups 2025-03-19 13:13:54 +00:00
8da4b6e849 cleanups 2025-03-13 19:29:53 +00:00
693fc474e0 update 2025-03-13 18:13:32 +00:00
0cf2e771c6 nit 2025-03-13 17:50:00 +00:00
2defa9c728 Merge branch 'final-version' of github.com:huggingface/new-model-addition-meta into final-version 2025-03-13 17:42:32 +00:00
660dc8c76c more quality of life improvements 2025-03-13 17:42:22 +00:00
1854fc9034 styling 2025-03-13 17:22:16 +00:00
e3c52a2fe5 update fast image processor after refactor 2025-03-13 16:41:14 +00:00
9a75c63ff9 remove one of the last deps 2025-03-13 16:32:02 +00:00
89f2a0242f fix 2025-03-13 16:06:35 +00:00
4288c9f5eb update 2025-03-13 14:18:53 +00:00
0c4bc3b115 Merge branch 'main' of github.com:huggingface/new-model-addition-meta into final-version 2025-03-13 13:21:05 +00:00
0d55fefa13 Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
Co-authored-by: Pablo Montalvo <39954772+molbap@users.noreply.github.com>
Co-authored-by: pcuenca <pedro@latenitesoft.com>
2025-03-13 13:18:58 +00:00
46 changed files with 5397 additions and 286 deletions

View File

@ -507,6 +507,8 @@
title: Llama2
- local: model_doc/llama3
title: Llama3
- local: model_doc/llama4
title: Llama4
- local: model_doc/longformer
title: Longformer
- local: model_doc/longt5

View File

@ -0,0 +1,446 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Llama4
<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
</div>
</div>
Llama 4, developed by Meta, introduces a new auto-regressive Mixture-of-Experts (MoE) architecture.
This generation includes two models:
- The highly capable Llama 4 Maverick, with 17B active parameters out of ~400B total, using 128 experts.
- The efficient Llama 4 Scout, which also has 17B active parameters out of ~109B total, using just 16 experts.
Both models leverage early fusion for native multimodality, enabling them to process text and image inputs.
Maverick and Scout are both trained on up to 40 trillion tokens of data encompassing 200 languages
(with specific fine-tuning support for 12 languages including Arabic, Spanish, German, and Hindi).
For deployment, Llama 4 Scout is designed for accessibility, fitting on a single server-grade GPU via
on-the-fly 4-bit or 8-bit quantization, while Maverick is available in BF16 and FP8 formats.
These models are released under the custom Llama 4 Community License Agreement, available on the model repositories.
You can find all the original Llama checkpoints under the [meta-llama](https://huggingface.co/meta-llama) organization.
> [!TIP]
> The Llama 4 family of models comes in two flavors: 109B and 402B parameters. Both of these flavors are extremely
> large and won't fit on your run-of-the-mill device. See below for some examples to reduce the memory usage of the
> model.
>
> For faster and more resilient downloads, we recommend installing the `hf_xet` dependency as follows:
> `pip install transformers[hf_xet]`
The examples below demonstrate how to generate with [`Pipeline`] or the [`AutoModel`] classes. We additionally include an example
showing how to toggle the right attributes to enable very long-context generation, as some flavors of Llama 4
have context lengths of up to 10 million tokens.
<hfoptions id="usage">
<hfoption id="Pipeline">
```py
from transformers import pipeline
import torch
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
messages = [
{"role": "user", "content": "what is the recipe of mayonnaise?"},
]
pipe = pipeline(
"text-generation",
model=model_id,
device_map="auto",
torch_dtype=torch.bfloat16
)
output = pipe(messages, do_sample=False, max_new_tokens=200)
print(output[0]["generated_text"][-1]["content"])
```
</hfoption>
<hfoption id="AutoModel - Text only">
```py
from transformers import AutoTokenizer, Llama4ForConditionalGeneration
import torch
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
messages = [
{"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16
)
outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
print(outputs[0])
```
</hfoption>
<hfoption id="AutoModel - Multimodal">
```py
from transformers import AutoProcessor, Llama4ForConditionalGeneration
import torch
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
)
img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
messages = [
{
"role": "user",
"content": [
{"type": "image", "url": img_url},
{"type": "text", "text": "Describe this image in two sentences."},
]
},
]
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=256,
)
response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
print(response)
```
</hfoption>
<hfoption id="AutoModel - Multimodal with multiple images">
```py
from transformers import AutoProcessor, Llama4ForConditionalGeneration
import torch
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
)
url1 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
url2 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
messages = [
{
"role": "user",
"content": [
{"type": "image", "url": url1},
{"type": "image", "url": url2},
{"type": "text", "text": "Can you describe how these two images are similar, and how they differ?"},
]
},
]
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=256,
)
response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])[0]
print(response)
```
</hfoption>
<hfoption id="AutoModel - Long context">
Beware: the example below uses both `device_map="auto"` and flex-attention.
Please use `torchrun` to run this example in tensor-parallel mode (a loading sketch is shown after the example).
We will work to enable running with `device_map="auto"` and flex-attention without
tensor-parallel in the future.
```py
from transformers import Llama4ForConditionalGeneration, AutoTokenizer
import torch
import time
file = "very_long_context_prompt.txt"
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
with open(file, "r") as f:
very_long_text = "\n".join(f.readlines())
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
attn_implementation="flex_attention",
torch_dtype=torch.bfloat16
)
messages = [
{"role": "user", "content": f"Look at the following texts: [{very_long_text}]\n\n\n\nWhat are the books, and who wrote them? Make me a nice list."},
]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
torch.cuda.synchronize()
start = time.time()
out = model.generate(
input_ids.to(model.device),
prefill_chunk_size=2048*8,
max_new_tokens=300,
cache_implementation="hybrid",
)
print(time.time()-start)
print(tokenizer.batch_decode(out[:, input_ids.shape[-1]:]))
print(f"{torch.cuda.max_memory_allocated(model.device) / 1024**3:.2f} GiB")
```
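If you run this long-context example in tensor-parallel mode with `torchrun` as suggested above, only the model load changes. Below is a minimal sketch, assuming one process per GPU is launched by `torchrun` and that `tp_plan="auto"` (the same argument used in the LLM-Compressor example further down) replaces `device_map="auto"`:
```py
# Sketch only: assumes this script is launched with torchrun, one process per GPU.
# `tp_plan="auto"` shards the model across the participating processes instead of
# relying on `device_map="auto"` for placement.
from transformers import Llama4ForConditionalGeneration
import torch

model = Llama4ForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-4-Scout-17B-16E-Instruct",
    tp_plan="auto",
    attn_implementation="flex_attention",
    torch_dtype=torch.bfloat16,
)
```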
</hfoption>
</hfoptions>
## Efficiency: how to get the best out of Llama 4
### The Attention methods
Updating the default attention function can significantly improve compute performance as well as memory usage. Refer to the [Attention Interface](../attention_interface) overview for an in-depth explanation of our interface.
As of release, the Llama 4 models support the following attention methods: `eager`, `flex_attention`, and `sdpa`. We recommend using `flex_attention` for best results.
Switching the attention mechanism is done at model initialization:
<hfoptions id="Attention">
<hfoption id="Flex Attention">
Setting Flex Attention ensures the best results with the very long context the model can handle.
> [!TIP] Beware: the example below uses both `device_map="auto"` and flex-attention.
> Please use `torchrun` to run this example in tensor-parallel mode.
>
> We will work to enable running with `device_map="auto"` and flex-attention without
> tensor-parallel in the future.
```py
from transformers import Llama4ForConditionalGeneration
import torch
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
attn_implementation="flex_attention",
device_map="auto",
torch_dtype=torch.bfloat16,
)
```
</hfoption>
<hfoption id="SDPA">
The `sdpa` attention method is generally more compute-efficient than the `eager` method.
```py
from transformers import Llama4ForConditionalGeneration
import torch
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
attn_implementation="sdpa",
device_map="auto",
torch_dtype=torch.bfloat16,
)
```
</hfoption>
<hfoption id="Eager">
The `eager` attention method is set by default, so nothing else is needed when loading the model:
```py
from transformers import Llama4ForConditionalGeneration
import torch
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
)
```
</hfoption>
</hfoptions>
### Quantization
Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for available quantization backends.
At the time of release, both FBGEMM and LLM-Compressor are supported, with more quantization methods to follow shortly after release.
See below for examples using both.
Here is an example loading a BF16 model in FP8 using the FBGEMM approach:
<hfoptions id="Quantization">
<hfoption id="FBGEMM">
```python
from transformers import AutoTokenizer, Llama4ForConditionalGeneration, FbgemmFp8Config
import torch
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
messages = [
{"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
quantization_config=FbgemmFp8Config()
)
outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
print(outputs[0])
```
</hfoption>
<hfoption id="LLM-Compressor">
To use the LLM-Compressor technique, we recommend leveraging the pre-quantized FP8 checkpoint available with the release:
```python
from transformers import AutoTokenizer, Llama4ForConditionalGeneration
import torch
model_id = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
tokenizer = AutoTokenizer.from_pretrained(model_id)
messages = [
{"role": "user", "content": "Who are you?"},
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True)
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
tp_plan="auto",
torch_dtype=torch.bfloat16,
)
outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:])
print(outputs[0])
```
</hfoption>
</hfoptions>
### Offloading
Enabling CPU offloading means that components of the model may be placed on the CPU instead of the GPU when the available GPU memory isn't sufficient to load the entire model.
At inference, the relevant components are moved to and from the GPU on the fly. This allows the model to run on smaller machines as long as there is enough CPU memory,
although it also slows down inference because of the added communication overhead.
To enable CPU offloading, simply set `device_map="auto"` when loading the model:
```py
from transformers import Llama4ForConditionalGeneration
import torch
model = Llama4ForConditionalGeneration.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
)
```
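If automatic placement alone is not enough, `device_map="auto"` can be combined with a `max_memory` budget to explicitly cap what the GPU may hold, forcing the remainder onto the CPU. The snippet below is a sketch; the memory limits are illustrative placeholders you would adapt to your hardware:
```py
# Sketch: cap GPU 0 at roughly 30 GiB so the remaining weights are offloaded to CPU RAM.
# The limits below are illustrative placeholders, not recommended values.
from transformers import Llama4ForConditionalGeneration
import torch

model = Llama4ForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-4-Scout-17B-16E-Instruct",
    device_map="auto",
    max_memory={0: "30GiB", "cpu": "128GiB"},
    torch_dtype=torch.bfloat16,
)
```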
## Llama4Config
[[autodoc]] Llama4Config
## Llama4TextConfig
[[autodoc]] Llama4TextConfig
## Llama4VisionConfig
[[autodoc]] Llama4VisionConfig
## Llama4Processor
[[autodoc]] Llama4Processor
## Llama4ImageProcessorFast
[[autodoc]] Llama4ImageProcessorFast
## Llama4ImageProcessor
[[autodoc]] Llama4ImageProcessor
## Llama4ForConditionalGeneration
[[autodoc]] Llama4ForConditionalGeneration
- forward
## Llama4ForCausalLM
[[autodoc]] Llama4ForCausalLM
- forward
## Llama4TextModel
[[autodoc]] Llama4TextModel
- forward
## Llama4VisionModel
[[autodoc]] Llama4VisionModel
- forward

View File

@ -562,6 +562,12 @@ _import_structure = {
"models.levit": ["LevitConfig"],
"models.lilt": ["LiltConfig"],
"models.llama": ["LlamaConfig"],
"models.llama4": [
"Llama4Config",
"Llama4Processor",
"Llama4TextConfig",
"Llama4VisionConfig",
],
"models.llava": [
"LlavaConfig",
"LlavaProcessor",
@ -1354,6 +1360,7 @@ else:
_import_structure["models.detr"].append("DetrImageProcessorFast")
_import_structure["models.gemma3"].append("Gemma3ImageProcessorFast")
_import_structure["models.got_ocr2"].append("GotOcr2ImageProcessorFast")
_import_structure["models.llama4"].append("Llama4ImageProcessorFast")
_import_structure["models.llava"].append("LlavaImageProcessorFast")
_import_structure["models.llava_next"].append("LlavaNextImageProcessorFast")
_import_structure["models.llava_onevision"].append("LlavaOnevisionImageProcessorFast")
@ -2510,6 +2517,15 @@ else:
"GlmPreTrainedModel",
]
)
_import_structure["models.llama4"].extend(
[
"Llama4ForCausalLM",
"Llama4ForConditionalGeneration",
"Llama4TextModel",
"Llama4VisionModel",
"Llama4PreTrainedModel",
]
)
_import_structure["models.glpn"].extend(
[
"GLPNForDepthEstimation",
@ -5807,6 +5823,12 @@ if TYPE_CHECKING:
from .models.levit import LevitConfig
from .models.lilt import LiltConfig
from .models.llama import LlamaConfig
from .models.llama4 import (
Llama4Config,
Llama4Processor,
Llama4TextConfig,
Llama4VisionConfig,
)
from .models.llava import (
LlavaConfig,
LlavaProcessor,
@ -6646,6 +6668,7 @@ if TYPE_CHECKING:
from .models.detr import DetrImageProcessorFast
from .models.gemma3 import Gemma3ImageProcessorFast
from .models.got_ocr2 import GotOcr2ImageProcessorFast
from .models.llama4 import Llama4ImageProcessorFast
from .models.llava import LlavaImageProcessorFast
from .models.llava_next import LlavaNextImageProcessorFast
from .models.llava_onevision import LlavaOnevisionImageProcessorFast
@ -7827,6 +7850,13 @@ if TYPE_CHECKING:
LlamaModel,
LlamaPreTrainedModel,
)
from .models.llama4 import (
Llama4ForCausalLM,
Llama4ForConditionalGeneration,
Llama4PreTrainedModel,
Llama4TextModel,
Llama4VisionModel,
)
from .models.llava import (
LlavaForConditionalGeneration,
LlavaPreTrainedModel,

View File

@ -1626,7 +1626,7 @@ class HybridCache(Cache):
device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. If you're using more than 1 computation device, you
should pass the `layer_device_map` argument instead.
dtype (torch.dtype, *optional*, defaults to `torch.float32`):
dtype (torch.dtype, *optional*, defaults to `torch.bfloat16`):
The default `dtype` to use when initializing the layer.
layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*):
Mapping between the layers and its device. This is required when you are manually initializing the cache
@ -1663,84 +1663,73 @@ class HybridCache(Cache):
max_batch_size: int,
max_cache_len: Optional[int] = None,
device: Union[torch.device, str, None] = None,
dtype: torch.dtype = torch.float32,
dtype: torch.dtype = torch.bfloat16,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
super().__init__()
if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError(
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
"sliding window attention, please check if there is a `sliding_window` field in the model "
"config and it's not set to None."
)
self.sliding_window = getattr(config.get_text_config(), "attention_chunk_size", 8192)
else:
self.sliding_window = config.sliding_window
self.max_cache_len = max_cache_len
self.max_batch_size = max_batch_size
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self._dtype = dtype
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)
layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
self.is_sliding = torch.tensor(
[bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
)
if hasattr(config.get_text_config(), "no_rope_layers"):
self.is_sliding = config.no_rope_layers
else:
layer_switch = getattr(config, "sliding_window_pattern", 2)
self.is_sliding = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
self.cumulative_length = [0 for _ in range(config.num_hidden_layers)]
def initialise_cache_layer(self, layer_idx, key_states):
if len(self.key_cache) > layer_idx:
return
num_key_value_heads = key_states.shape[1]
device = key_states.device
global_cache_shape = (self.max_batch_size, num_key_value_heads, self.max_cache_len, self.head_dim)
sliding_cache_shape = (
self.max_batch_size,
self.num_key_value_heads,
min(config.sliding_window, max_cache_len),
num_key_value_heads,
self.sliding_window,
self.head_dim,
)
device = torch.device(device) if device is not None and isinstance(device, str) else None
for i in range(config.num_hidden_layers):
if layer_device_map is not None:
layer_device = layer_device_map[i]
else:
layer_device = device
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
# breaks when updating the cache.
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
# breaks when updating the cache.
cache_shape = sliding_cache_shape if self.is_sliding[layer_idx] else global_cache_shape
new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=device)
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)
def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
if cache_position.shape[0] > max_cache_len:
k_out = key_states[:, :, -max_cache_len:, :]
v_out = value_states[:, :, -max_cache_len:, :]
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
# we should return the whole states instead of k_out, v_out to take the whole prompt
# into consideration when building kv cache instead of just throwing away tokens outside of the window
return key_states, value_states
cumulative_length = self.cumulative_length[layer_idx]
is_full = cumulative_length >= max_cache_len
if is_full:
full_key_states = torch.cat((k_out[:, :, 1:, :], key_states), dim=-2)
full_value_states = torch.cat((v_out[:, :, 1:, :], value_states), dim=-2)
elif not is_full and cumulative_length + key_states.shape[2] > max_cache_len:
full_key_states = torch.cat((k_out[:, :, :cumulative_length, :], key_states), dim=-2)
full_value_states = torch.cat((v_out[:, :, :cumulative_length, :], value_states), dim=-2)
else:
self.key_cache[layer_idx].index_copy_(2, cache_position, key_states)
self.value_cache[layer_idx].index_copy_(2, cache_position, value_states)
self.cumulative_length[layer_idx] += key_states.shape[-2]
return self.key_cache[layer_idx], self.value_cache[layer_idx]
slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
cache_position = cache_position.clamp(0, max_cache_len - 1)
to_shift = cache_position >= max_cache_len - 1
indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
# `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
return k_out, v_out
self.key_cache[layer_idx].copy_(full_key_states[:, :, -max_cache_len:, :])
self.value_cache[layer_idx].copy_(full_value_states[:, :, -max_cache_len:, :])
self.cumulative_length[layer_idx] += key_states.shape[-2]
# we should return the whole states instead of k_out, v_out to take the whole prompt
# into consideration when building kv cache instead of just throwing away tokens outside of the window
return full_key_states, full_value_states
def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
k_out[:, :, cache_position] = key_states
@ -1760,7 +1749,7 @@ class HybridCache(Cache):
if cache_kwargs is None:
cache_kwargs = {}
cache_position = cache_kwargs.get("cache_position")
sliding_window = cache_kwargs.get("sliding_window")
self.initialise_cache_layer(layer_idx, key_states)
# These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
# when the cache is initialized in the forward pass (e.g. Gemma2)
@ -1774,7 +1763,7 @@ class HybridCache(Cache):
key_states = key_states.to(k_out.dtype)
value_states = value_states.to(v_out.dtype)
if sliding_window:
if self.is_sliding[layer_idx]:
update_fn = self._sliding_update
else:
update_fn = self._static_update
@ -1801,6 +1790,8 @@ class HybridCache(Cache):
"`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
"Using the `layer_idx` argument is not supported."
)
if len(self.key_cache) == 0:
return 0
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
def reset(self):
@ -1809,6 +1800,7 @@ class HybridCache(Cache):
# In-place ops prevent breaking the static address
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
self.cumulative_length = [0 for _ in range(len(self.cumulative_length))]
class MambaCache:

View File

@ -801,18 +801,19 @@ class PretrainedConfig(PushToHubMixin):
def to_diff_dict(self) -> dict[str, Any]:
"""
Removes all attributes from config which correspond to the default config attributes for better readability and
serializes to a Python dictionary.
Removes all attributes from the configuration that correspond to the default config attributes for
better readability, while always retaining the `config` attribute from the class. Serializes to a
Python dictionary.
Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
Dict[str, Any]: Dictionary of all the attributes that make up this configuration instance.
"""
config_dict = self.to_dict()
# get the default config dict
# Get the default config dict (from a fresh PreTrainedConfig instance)
default_config_dict = PretrainedConfig().to_dict()
# get class specific config dict
# Get class-specific config dict if not part of a composition
class_config_dict = self.__class__().to_dict() if not self.is_composition else {}
serializable_config_dict = {}
@ -847,8 +848,7 @@ class PretrainedConfig(PushToHubMixin):
if not isinstance(self.quantization_config, dict)
else self.quantization_config
)
# pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
# Pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
_ = serializable_config_dict.pop("_pre_quantization_dtype", None)
self.dict_torch_dtype_to_str(serializable_config_dict)

View File

@ -416,6 +416,7 @@ class GenerationConfig(PushToHubMixin):
if isinstance(self.cache_config, dict):
self.cache_config = cache_config_class.from_dict(self.cache_config)
self.return_legacy_cache = kwargs.pop("return_legacy_cache", None)
self.prefill_chunk_size = kwargs.pop("prefill_chunk_size", None)
# Parameters for manipulation of the model output logits
self.temperature = kwargs.pop("temperature", 1.0)

View File

@ -3405,7 +3405,12 @@ class GenerationMixin:
os.environ["TOKENIZERS_PARALLELISM"] = "0"
model_forward = self.get_compiled_call(generation_config.compile_config)
is_prefill = True
if generation_config.prefill_chunk_size is not None:
model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
is_prefill = False
else:
is_prefill = True
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
@ -4855,6 +4860,45 @@ class GenerationMixin:
else:
return input_ids
def _prefill_chunking(self, input_ids: torch.LongTensor, generation_config: GenerationConfig, **model_kwargs):
# Even if we are not compiling the forward, flex is always compiled when used. With chunk prefill, we may
# end up needing just a bit more graphs than the default (which is 8). Doing this avoids very cryptic warnings
torch._dynamo.config.cache_size_limit = 64
chunk_size = generation_config.prefill_chunk_size
# Only chunk up the token just before last, so that decoding is completely performed outside this function
# (here we simply prefill the cache)
input_chunks = torch.split(input_ids[:, :-1], chunk_size, dim=-1)
if "past_key_values" not in model_kwargs:
raise ValueError("Cannot use prefill chunkink without a cache")
model_forward = self.get_compiled_call(generation_config.compile_config)
attention_mask = model_kwargs.pop("attention_mask", None)
past_length = 0
for input_chunk in input_chunks:
current_length = past_length + input_chunk.shape[-1]
# Prepare inputs
if attention_mask is not None:
model_kwargs["attention_mask"] = attention_mask[:, :current_length]
model_kwargs["cache_position"] = torch.arange(
past_length, current_length, dtype=torch.long, device=input_chunk.device
)
model_kwargs["position_ids"] = model_kwargs["cache_position"].unsqueeze(0)
model_inputs = self.prepare_inputs_for_generation(input_chunk, **model_kwargs)
outputs = model_forward(**model_inputs, return_dict=True)
model_kwargs["past_key_values"] = outputs.past_key_values
past_length = current_length
model_kwargs["attention_mask"] = attention_mask
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1
_ = model_kwargs.pop("position_ids", None)
return model_kwargs
def _speculative_sampling(
candidate_input_ids,

View File

@ -53,7 +53,7 @@ _import_structure = {
"unset_hf_deepspeed_config",
],
"eetq": ["replace_with_eetq_linear"],
"fbgemm_fp8": ["FbgemmFp8Linear", "replace_with_fbgemm_fp8_linear"],
"fbgemm_fp8": ["FbgemmFp8Linear", "FbgemmFp8Llama4TextExperts", "replace_with_fbgemm_fp8_linear"],
"finegrained_fp8": ["FP8Linear", "replace_with_fp8_linear"],
"fsdp": ["is_fsdp_managed_module"],
"ggml": [
@ -192,7 +192,7 @@ if TYPE_CHECKING:
unset_hf_deepspeed_config,
)
from .eetq import replace_with_eetq_linear
from .fbgemm_fp8 import FbgemmFp8Linear, replace_with_fbgemm_fp8_linear
from .fbgemm_fp8 import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts, replace_with_fbgemm_fp8_linear
from .finegrained_fp8 import FP8Linear, replace_with_fp8_linear
from .fsdp import is_fsdp_managed_module
from .ggml import (

View File

@ -0,0 +1,54 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers.utils import is_torch_available
if is_torch_available():
import torch
import torch.nn as nn
from transformers.models.llama4.modeling_llama4 import Llama4TextMLP
def skip(*args, **kwargs):
pass
class CompressedExpertsLinear(nn.Module):
"""
A module that implements a compressed version of a list of expert modules.
This is specifically designed to work with Llama4TextExperts in MoE layers.
"""
def __init__(self, config):
# Skip random weight initialization for experts. Otherwise,
# the init of this module would take over minutes. For a model
# with tens of layers of experts, it would easily take over 20 minutes.
nn.init.kaiming_uniform_ = skip
nn.init.uniform_ = skip
nn.init.normal_ = skip
super().__init__()
self.num_experts = config.num_local_experts
self.expert_modules = nn.ModuleList([Llama4TextMLP(config) for _ in range(self.num_experts)])
def forward(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
hidden_states = hidden_states.reshape(self.num_experts, -1, hidden_states.shape[-1])
expert_routed_out_list = []
for expert_idx in range(self.num_experts):
expert_routed_out_list.append(self.expert_modules[expert_idx](hidden_states[expert_idx]))
routed_out = torch.cat(expert_routed_out_list, dim=0)
return routed_out

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..activations import ACT2FN
from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging
@ -28,36 +29,36 @@ if is_fbgemm_gpu_available():
logger = logging.get_logger(__name__)
class FbgemmFp8Linear(torch.nn.Module):
class FbgemmFp8Linear(torch.nn.Linear):
def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32):
super().__init__()
super().__init__(in_features, out_features, bias)
self.in_features = in_features
self.out_features = out_features
self.register_buffer("weight", torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
self.register_buffer("weight_scale", torch.zeros((out_features, 1), dtype=weight_dtype))
self.weight = torch.nn.Parameter(torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
self.weight_scale = torch.nn.Parameter(torch.zeros((out_features, 1), dtype=weight_dtype))
self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)
if bias:
self.register_buffer("bias", torch.zeros((self.out_features), dtype=weight_dtype))
self.bias = torch.nn.Parameter(torch.zeros((self.out_features), dtype=weight_dtype))
else:
self.bias = None
def forward(self, x):
num_tokens = None
# quantize_fp8_per_row will squash the leading dimensions, so save the desired shape here
output_shape = (*x.shape[:-1], -1)
# x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
# https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
x.view(-1, x.shape[-1]), num_tokens, self.input_scale_ub
x.view(-1, x.shape[-1]), scale_ub=self.input_scale_ub
)
# moving x_quantized, x_scale here creates gibberish output ... However, if we move the output, it works
# x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)
# The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight
weight_scale_float32 = self.weight_scale.to(torch.float32)
output = torch.ops.fbgemm.f8f8bf16_rowwise(
x_quantized, self.weight, x_scale, self.weight_scale, use_fast_accum=True
x_quantized, self.weight, x_scale, weight_scale_float32, use_fast_accum=True
)
output = output + self.bias if self.bias is not None else output
# Hacky for now, we have the output to the device of x
@ -67,6 +68,92 @@ class FbgemmFp8Linear(torch.nn.Module):
return output
class FbgemmFp8Llama4TextExperts(nn.Module):
def __init__(self, config, dtype=torch.float32):
super().__init__()
self.num_experts = config.num_local_experts
self.intermediate_size = config.intermediate_size
self.hidden_size = config.hidden_size
self.expert_dim = self.intermediate_size
self.act_fn = ACT2FN[config.hidden_act]
# Register FP8 buffers for gate_up_proj
self.gate_up_proj = torch.nn.Parameter(
torch.zeros((self.num_experts, self.hidden_size, 2 * self.expert_dim), dtype=torch.float8_e4m3fn)
)
self.gate_up_proj_scale = torch.nn.Parameter(
torch.zeros((self.num_experts, 1, self.expert_dim * 2), dtype=torch.float32)
)
# Register FP8 buffers for down_proj
self.down_proj = torch.nn.Parameter(
torch.zeros((self.num_experts, self.expert_dim, self.hidden_size), dtype=torch.float8_e4m3fn)
)
self.down_proj_scale = torch.nn.Parameter(
torch.zeros((self.num_experts, self.hidden_size, 1), dtype=torch.float32)
)
# Register input scale upper bound
self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)
def forward(self, hidden_states):
"""
Args:
hidden_states (torch.Tensor): (batch_size * token_num, hidden_size)
Returns:
torch.Tensor: (batch_size * token_num, hidden_size)
"""
# Reshape hidden states for expert computation
hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
num_tokens = None
# Pre-allocate tensor for all expert outputs with same shape as hidden_states
next_states = torch.empty_like(hidden_states)
for i in range(self.num_experts):
# Extract expert's hidden states
expert_hidden = hidden_states[i]
expert_hidden_reshaped = expert_hidden.reshape(-1, self.hidden_size)
# Quantize for this expert
expert_quantized, expert_scale = torch.ops.fbgemm.quantize_fp8_per_row(
expert_hidden_reshaped, num_tokens, self.input_scale_ub
)
sharded_expert_dim = self.gate_up_proj.shape[-1] // 2
gate_up_proj_scale_float32 = self.gate_up_proj_scale.to(torch.float32)
gate = torch.ops.fbgemm.f8f8bf16_rowwise(
expert_quantized,
self.gate_up_proj[i].transpose(0, 1)[:sharded_expert_dim].contiguous(),
expert_scale,
gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous(),
use_fast_accum=True,
)
up = torch.ops.fbgemm.f8f8bf16_rowwise(
expert_quantized,
self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous(),
expert_scale,
gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous(),
use_fast_accum=True,
)
activated = up * self.act_fn(gate)
activated_quantized, activated_scale = torch.ops.fbgemm.quantize_fp8_per_row(
activated, num_tokens, self.input_scale_ub
)
down_proj_scale_float32 = self.down_proj_scale.to(torch.float32)
expert_output = torch.ops.fbgemm.f8f8bf16_rowwise(
activated_quantized,
self.down_proj[i].transpose(0, 1).contiguous(),
activated_scale,
down_proj_scale_float32[i].view(-1, 1).contiguous(),
use_fast_accum=True,
)
next_states[i] = expert_output
next_states = next_states.to(hidden_states.device)
return next_states.view(-1, self.hidden_size)
def _replace_with_fbgemm_fp8_linear(
model,
modules_to_not_convert=None,
@ -74,12 +161,17 @@ def _replace_with_fbgemm_fp8_linear(
quantization_config=None,
has_been_replaced=False,
pre_quantized=False,
config=None,
tp_plan=None,
):
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean that indicates if the conversion has been successful or not.
"""
import re
if current_key_name is None:
current_key_name = []
@ -105,9 +197,27 @@ def _replace_with_fbgemm_fp8_linear(
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
# set non persistant buffer outside of init_empty_weights
model._modules[name].input_scale_ub = torch.tensor(
[quantization_config.activation_scale_ub],
dtype=torch.float,
)
if module.__class__.__name__ == "Llama4TextExperts" and name not in modules_to_not_convert:
current_key_name_str = ".".join(current_key_name)
if not any(
(key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
):
with init_empty_weights(include_buffers=True):
tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj_scale")] = tp_plan[
re.sub(r"\d+", "*", current_key_name_str + ".gate_up_proj")
]
tp_plan[re.sub(r"\d+", "*", current_key_name_str + ".down_proj_scale")] = None
model._modules[name] = FbgemmFp8Llama4TextExperts(
config.text_config,
)
model._modules[name].input_scale_ub = torch.tensor(
[quantization_config.activation_scale_ub], dtype=torch.float
)
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_fbgemm_fp8_linear(
module,
@ -116,6 +226,8 @@ def _replace_with_fbgemm_fp8_linear(
quantization_config,
has_been_replaced=has_been_replaced,
pre_quantized=pre_quantized,
config=config,
tp_plan=tp_plan,
)
# Remove the last key for recursion
current_key_name.pop(-1)
@ -123,7 +235,13 @@ def _replace_with_fbgemm_fp8_linear(
def replace_with_fbgemm_fp8_linear(
model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False
model,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
pre_quantized=False,
config=None,
tp_plan=None,
):
"""
A helper function to replace all `torch.nn.Linear` modules by `FbgemmFp8Linear` modules.
@ -151,9 +269,14 @@ def replace_with_fbgemm_fp8_linear(
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
modules_to_not_convert = list(set(modules_to_not_convert))
model, has_been_replaced = _replace_with_fbgemm_fp8_linear(
model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized
model,
modules_to_not_convert,
current_key_name,
quantization_config,
pre_quantized=pre_quantized,
config=config,
tp_plan=tp_plan,
)
if not has_been_replaced:
logger.warning(
"You are loading your model using FP8 quantization but no linear modules were found in your model."

View File

@ -34,10 +34,7 @@ from ..utils import is_torch_flex_attn_available
if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import (
BlockMask,
flex_attention,
)
from torch.nn.attention.flex_attention import BlockMask, flex_attention
from torch.nn.attention.flex_attention import (
create_block_mask as create_block_causal_mask_flex,
)
@ -64,14 +61,23 @@ class WrappedFlexAttention:
Initialize or update the singleton instance.
"""
if self._is_flex_compiled is False:
self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
self._compiled_flex_attention = torch.compile(flex_attention, backend="inductor")
self._is_flex_compiled = True
def __call__(self):
return self._compiled_flex_attention
def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask":
Offset = Union[torch.Tensor, int]
def make_flex_block_causal_mask(
attention_mask_2d: torch.Tensor,
attention_chunk_size: Optional[int] = None,
query_length=None,
key_length=None,
offsets: Optional[Tuple[Offset, Offset]] = None,
) -> "BlockMask":
"""
Create a block causal document mask for a batch of sequences, both packed and unpacked.
Creates the block-causal mask logic and passes it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
@ -94,10 +100,13 @@ def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask":
Returns:
BlockMask
"""
attention_mask_2d = torch.nn.functional.pad(attention_mask_2d, value=0, pad=(0, key_length))
device = attention_mask_2d.device
document_ids = attention_mask_2d.clone()
document_ids = attention_mask_2d
batch_size, total_seq_len = document_ids.shape
if attention_chunk_size is not None:
# we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
document_ids = (document_ids.fill_(1).cumsum(-1) - 1) // (attention_chunk_size)
# Instead of passing a tensor mask, flex attention requires a mask_mod function
# that determines which elements of QK^T should be included in the attention
@ -112,18 +121,30 @@ def make_flex_block_causal_mask(attention_mask_2d: torch.Tensor) -> "BlockMask":
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
for an illustration.
"""
causal_mask = q_idx >= kv_idx
causal_mask = q_idx >= kv_idx # not valid when decoding
document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
padding_mask = document_ids[batch_idx, q_idx] > 0
return causal_mask & document_mask & padding_mask
padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
final_mask = causal_mask & padding_mask & document_mask
return final_mask
if offsets is not None:
q_offset = offsets[0]
kv_offset = offsets[1]
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
offset_q = q_idx + q_offset
offset_kv = kv_idx + kv_offset
return causal_mask_mod(batch_idx, head_idx, offset_q, offset_kv)
else:
mask_mod = causal_mask_mod
return create_block_causal_mask_flex(
mask_mod=causal_mask_mod,
B=batch_size,
mask_mod=mask_mod,
B=1,
H=None, # attention head
Q_LEN=total_seq_len,
KV_LEN=total_seq_len,
Q_LEN=query_length,
KV_LEN=key_length,
device=device,
_compile=True,
)
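When `attention_chunk_size` is set, the mask above additionally requires the query and key positions to fall in the same chunk; the chunk ids are just an arange floor-divided by the chunk size. A small standalone sketch of that id construction and the resulting constraint (padding handling omitted for brevity):
import torch

attention_mask_2d = torch.ones(1, 12, dtype=torch.long)
attention_chunk_size = 4
# Same trick as above: fill with ones, cumsum, subtract one, floor-divide by the chunk size.
document_ids = (attention_mask_2d.fill_(1).cumsum(-1) - 1) // attention_chunk_size
print(document_ids)  # tensor([[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]])

# A (q_idx, kv_idx) pair survives only if it is causal AND both positions share a chunk id.
q_idx, kv_idx = 5, 2
same_chunk = document_ids[0, q_idx] == document_ids[0, kv_idx]
print(bool(q_idx >= kv_idx), bool(same_chunk))  # True False -> causal, but masked out across chunks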
@ -144,6 +165,18 @@ def compile_friendly_flex_attention(
)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
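A quick shape check for the `repeat_kv` helper above, assuming 2 KV heads repeated to match 8 query heads:
import torch
key = torch.randn(1, 2, 10, 64)        # (batch, num_key_value_heads, seq_len, head_dim)
print(repeat_kv(key, n_rep=4).shape)   # torch.Size([1, 8, 10, 64])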
def flex_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
@ -174,14 +207,25 @@ def flex_attention_forward(
score = score + head_mask[batch_idx][head_idx][0][0]
return score
enable_gqa = True
num_local_query_heads = query.shape[1]
# When running TP, the sharded number of query heads may not be a power of two; in that case repeat the KV heads explicitly and skip flex attention's GQA path:
if not ((num_local_query_heads & (num_local_query_heads - 1)) == 0):
key = repeat_kv(key, query.shape[1] // key.shape[1])
value = repeat_kv(value, query.shape[1] // value.shape[1])
enable_gqa = False
kernel_options = kwargs.get("kernel_options", None)
attn_output, attention_weights = compile_friendly_flex_attention(
query,
key,
value,
score_mod=score_mod,
block_mask=block_mask,
enable_gqa=True,
enable_gqa=enable_gqa,
scale=scaling,
kernel_options=kernel_options,
# Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
# For simplification, we thus always return it as no additional computations are introduced.
return_lse=True,

View File

@ -26,6 +26,13 @@ try:
_hub_kernels_available = True
_KERNEL_MAPPING: Dict[str, Dict[Union[Device, str], LayerRepository]] = {
"Llama4TextMoe": {
"cuda": LayerRepository(
# Move to kernels-community/moe once we release.
repo_id="kernels-community/moe-new-models",
layer_name="Llama4TextMoe",
)
},
"MultiScaleDeformableAttention": {
"cuda": LayerRepository(
repo_id="kernels-community/deformable-detr",

View File

@ -31,7 +31,7 @@ def sdpa_attention_forward(
value = repeat_kv(value, module.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
if attention_mask is not None and causal_mask.ndim == 4:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
# SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions

View File

@ -61,6 +61,21 @@ def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> Li
return [single_size] * blocks
str_to_torch_dtype = {
"BOOL": torch.bool,
"U8": torch.uint8,
"I8": torch.int8,
"I16": torch.int16,
"F16": torch.float16,
"BF16": torch.bfloat16,
"I32": torch.int32,
"F32": torch.float32,
"F64": torch.float64,
"I64": torch.int64,
"F8_E4M3": torch.float8_e4m3fn,
}
def get_packed_weights(param, empty_param, device_mesh, rank, dim):
"""
When weights are packed (gate_up_proj), we need to make sure each shard gets its correct share.
@ -106,6 +121,12 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
tensors_slices += range(block_offset + start, block_offset + stop)
block_offset += block_size
slice_dtype = slice_.get_dtype()
# Handle F8_E4M3 dtype by converting to float16 before slicing
# Without upcasting, the slicing causes : RuntimeError: "index_cpu" not implemented for 'Float8_e4m3fn'
if slice_dtype == "F8_E4M3":
slice_ = slice_[...].to(torch.float16)
if dim == 0:
tensor = slice_[tensors_slices, ...]
elif dim == 1 or dim == -2:
@ -114,7 +135,7 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
tensor = slice_[..., tensors_slices]
else:
raise ValueError(f"Unsupported dim {dim}, only dim 0, 1 or 2 are supported")
return tensor
return tensor.to(str_to_torch_dtype[slice_dtype])
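The F8_E4M3 branch above upcasts to float16 only to make the fancy indexing possible; the final cast then restores the dtype recorded in the checkpoint via `str_to_torch_dtype`. Since float16 represents every e4m3 value exactly, the round trip should lose nothing, as this small sketch suggests:
import torch

w8 = torch.randn(8, 4).to(torch.float8_e4m3fn)      # pretend this is a stored F8_E4M3 shard
w16 = w8.to(torch.float16)                           # exact upcast: e4m3 values are representable in fp16
sliced = w16[[0, 2, 5], :]                           # index_select-style slicing works on fp16
back = sliced.to(str_to_torch_dtype["F8_E4M3"])      # cast back to the stored dtype
print(torch.equal(back.to(torch.float16), sliced))   # True: nothing lost in the round trip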
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
@ -199,11 +220,12 @@ class GatherParallel(TensorParallelLayer):
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
if isinstance(inputs[0], DTensor):
inputs[0] = inputs[0].to_local()
inputs = inputs[0].to_local()
return inputs
@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
# this op cannot be async, otherwise it completely breaks the outputs of models
torch.distributed.all_reduce(outputs[0], op=torch.distributed.ReduceOp.SUM, async_op=False)
return outputs
@ -266,7 +288,7 @@ class ColwiseParallel(TensorParallelLayer):
# transform the input layouts to the desired layouts of ColwiseParallel
if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
return input_tensor
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
@ -291,7 +313,7 @@ class ColwiseParallel(TensorParallelLayer):
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
# outputs is a shard on last dimension DTensor, i.e. Shard(-1)
if outputs.placements != output_layouts:
outputs = outputs.redistribute(placements=output_layouts, async_op=True)
outputs = outputs.redistribute(placements=output_layouts, async_op=False)
# back to local tensor
return outputs.to_local() if use_local_output else outputs
@ -343,16 +365,6 @@ class RowwiseParallel(TensorParallelLayer):
self.use_local_output = use_local_output
self.use_dtensor = use_dtensor
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
input_tensor = inputs[0]
if not isinstance(input_tensor, DTensor):
input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
return input_tensor
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
# means Rowwise as nn.Linear is input * weight^T + bias, where
@ -371,6 +383,20 @@ class RowwiseParallel(TensorParallelLayer):
parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
return nn.Parameter(parameter)
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
if hasattr(mod, "bias") and mod.bias is not None:
mod._bias = mod.bias
mod.bias = None
input_tensor = inputs[0]
if not isinstance(input_tensor, DTensor):
input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
return input_tensor
@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
# Rowwise sharding produces partial output, depending on output layouts:
@ -378,6 +404,8 @@ class RowwiseParallel(TensorParallelLayer):
# 2. to shard -> reduce_scatter
if outputs.placements != output_layouts:
outputs = outputs.redistribute(placements=output_layouts, async_op=True)
if hasattr(mod, "_bias"):
outputs += mod._bias
# back to local tensor if use_local_output is True
return outputs.to_local() if use_local_output else outputs
@ -418,6 +446,90 @@ class PackedRowwiseParallel(RowwiseParallel):
return nn.Parameter(parameter)
class SequenceParallel(TensorParallelLayer):
"""
SequenceParallel replicates a compatible ``nn.Module`` parameters and runs the sharded computation with
input sharded on the sequence dimension. This currently supports ``nn.LayerNorm``, ``nn.Dropout``, and the
`RMSNorm python implementation <https://github.com/facebookresearch/llama/blob/main/llama/model.py#L34>`__
This style implements the operation that is described in the paper
`Reducing Activation Recomputation in Large Transformer Models <https://arxiv.org/abs/2205.05198>`__
If the input passed in to this ``nn.Module`` is a :class:`torch.Tensor`, it assumes that the input is already sharded
on the sequence dimension and converts the input to a :class:`DTensor` sharded on the sequence dimension. If the input
passed in to this ``nn.Module`` is already a :class:`DTensor` but is not sharded on the sequence dimension, it would
redistribute the input to be sharded on the sequence dimension.
The output of the ``nn.Module`` will be sharded on the sequence dimension.
Keyword Args:
sequence_dim (int, optional):
The sequence dimension of the input tensor for the ``nn.Module``, this is used to annotate the input tensor to
become a DTensor that is sharded on the sequence dimension, default: 1.
use_local_output (bool, optional):
Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: False.
Returns:
A :class:`ParallelStyle` object that represents Sequence Parallel of the ``nn.Module``.
Example::
>>> # xdoctest: +SKIP(failing)
>>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel
>>> from torch.distributed.device_mesh import init_device_mesh
>>> ...
>>> m = Model(...) # m is a nn.Module that contains a "norm" nn.LayerNorm submodule
>>> tp_mesh = init_device_mesh("cuda", (8,))
>>>
>>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim
>>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`.
>>>
>>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}),
>>> ...
.. note:: SequenceParallel style assumes ones initialization if there are weights in the nn.Module (i.e.
``nn.LayerNorm`` or ``RMSNorm``, and they by default have ones initialization). If you have custom
inits for the weights on those modules, you need to broadcast the weights before/after parallelizing
to ensure that they are replicated.
"""
def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False):
super().__init__()
self.input_layouts = (Replicate(),)
self.desired_input_layouts = (Shard(1),)
self.output_layouts = (Replicate(),)
self.use_local_output = use_local_output
self.use_dtensor = True
self.sequence_sharding = (Shard(sequence_dim),)
self.use_local_output = use_local_output
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
input_tensor = inputs[0]
if not isinstance(input_tensor, DTensor):
input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
if input_layouts != desired_input_layouts:
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
return input_tensor
@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
outputs = outputs.redistribute(
placements=(Replicate(),), async_op=True
) # maybe we have to replicate ? because next layer is not sharded
return outputs.to_local() # if use_local_output else outputs
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
# means Colwise as Linear is input * weight^T + bias, where
# weight would become Shard(1)
parameter = param[:]
parameter = parameter.to(param_casting_dtype)
if to_contiguous:
parameter = parameter.contiguous()
if self.use_dtensor:
parameter = DTensor.from_local(parameter, device_mesh, [Replicate()], run_check=False)
return nn.Parameter(parameter)
SUPPORTED_TP_STYLES = {
"colwise",
"rowwise",
@ -428,6 +540,7 @@ SUPPORTED_TP_STYLES = {
"local",
"gather",
"local_packed_rowwise",
"sequence_parallel",
}
@ -459,6 +572,8 @@ def translate_to_torch_parallel_style(style: str):
return GatherParallel()
elif style == "local_packed_rowwise":
return PackedRowwiseParallel(use_dtensor=False)
elif style == "sequence_parallel":
return SequenceParallel()
else:
raise ValueError(f"Unsupported parallel style value: {style}")
@ -518,6 +633,7 @@ def shard_and_distribute_module(
tp_plan = model._tp_plan
module_to_tp = model.get_submodule(param_name)
current_module_plan = None
rank = int(rank)
generic_param_name = re.sub(r"\d+", "*", parameter_name)
if generic_param_name in tp_plan:
current_module_plan = tp_plan[generic_param_name]
@ -531,12 +647,18 @@ def shard_and_distribute_module(
module_to_tp._is_hooked = True
if current_module_plan is not None:
tp_layer = translate_to_torch_parallel_style(current_module_plan)
param = tp_layer.partition_tensor(
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
)
try:
tp_layer = translate_to_torch_parallel_style(current_module_plan)
param = tp_layer.partition_tensor(
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
)
except NotImplementedError as e:
print(
f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp}. Fix its TP plan, current layer: {tp_layer}: {e}"
)
else:
# TODO log no plan modules in set
# print("No plan for", parameter_name,end ="\n")
param = param[...].to(param_casting_dtype)
if is_contiguous:
param = param.contiguous()

View File

@ -484,6 +484,7 @@ str_to_torch_dtype = {
"F32": torch.float32,
"F64": torch.float64,
"I64": torch.int64,
"F8_E4M3": torch.float8_e4m3fn,
}
if is_torch_greater_or_equal("2.1.0"):
@ -1777,7 +1778,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
model_tags = None
_auto_class = None
_no_split_modules = None
_no_split_modules = []
_skip_keys_device_placement = None
_keep_in_fp32_modules = None
@ -1914,16 +1915,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
# If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
if self.base_model is self:
self._pp_plan = (
self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
)
self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
else:
self._tp_plan = self._tp_plan or {}
for name, module in self.named_children():
if plan := getattr(module, "_tp_plan", None):
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.items()})
self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else None
self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
for name, module in self.named_children():
if plan := getattr(module, "_tp_plan", None):
self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
if self._tp_plan is not None and is_torch_greater_or_equal("2.3"):
for _, v in self._tp_plan.items():
@ -4050,6 +4046,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
import sys
sys.stdout = open(os.devnull, "w")
sys.stderr = open(os.devnull, "w")
# This is the easiest way to dispatch to the current process device
device_map = tp_device
# Assuming sharding the model onto the world
@ -4234,6 +4231,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
device_map = hf_quantizer.update_device_map(device_map)
config = hf_quantizer.update_tp_plan(config)
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
if hasattr(hf_quantizer.quantization_config.quant_method, "value"):
@ -4366,9 +4364,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if hf_quantizer is not None:
hf_quantizer.preprocess_model(
model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules
model=model, device_map=device_map, keep_in_fp32_modules=model._keep_in_fp32_modules, config=config
)
# We store the original dtype for quantized models as we cannot easily retrieve it
# once the weights have been quantized
# Note that once you have loaded a quantized model, you can't change its dtype so this will
@ -4897,7 +4894,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
name,
casting_dtype,
to_contiguous,
tp_device.index,
os.environ["RANK"],
device_mesh,
)
@ -5170,6 +5167,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
want to use compiled version to avoid recomputing the graph with new shapes) and iterative decoding
(where we want the speed-ups of compiled version with static shapes)."""
# Only reset it if not present or different from previous config
if "llama4" in self.config.model_type: # TODO try to enable for FULL COMPILE HYBRID CACHE SUPPORT
return self.__call__
default_config = getattr(self.generation_config, "compile_config", CompileConfig())
if (
not hasattr(self, "_compiled_call")

View File

@ -148,6 +148,7 @@ from . import (
levit,
lilt,
llama,
llama4,
llava,
llava_next,
llava_next_video,

View File

@ -544,10 +544,6 @@ class _BaseAutoModelClass:
if kwargs_orig.get("quantization_config", None) is not None:
kwargs["quantization_config"] = kwargs_orig["quantization_config"]
# AutoClass-specific config manipulation
config = copy.deepcopy(config)
config = cls._prepare_config_for_auto_class(config)
has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map
has_local_code = type(config) in cls._model_mapping.keys()
trust_remote_code = resolve_trust_remote_code(
@ -570,6 +566,8 @@ class _BaseAutoModelClass:
)
elif type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping)
if model_class.config_class == config.sub_configs.get("text_config", None):
config = config.get_text_config()
return model_class.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
)

View File

@ -170,6 +170,8 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("levit", "LevitConfig"),
("lilt", "LiltConfig"),
("llama", "LlamaConfig"),
("llama4", "Llama4Config"),
("llama4_text", "Llama4TextConfig"),
("llava", "LlavaConfig"),
("llava_next", "LlavaNextConfig"),
("llava_next_video", "LlavaNextVideoConfig"),
@ -519,6 +521,8 @@ MODEL_NAMES_MAPPING = OrderedDict(
("llama", "LLaMA"),
("llama2", "Llama2"),
("llama3", "Llama3"),
("llama4", "Llama4"),
("llama4_text", "Llama4ForCausalLM"),
("llava", "LLaVa"),
("llava_next", "LLaVA-NeXT"),
("llava_next_video", "LLaVa-NeXT-Video"),
@ -776,6 +780,7 @@ SPECIAL_MODEL_TYPE_TO_MODULE_NAME = OrderedDict(
("rt_detr_resnet", "rt_detr"),
("granitevision", "llava_next"),
("sam_vision_model", "sam"),
("llama4_text", "llama4"),
]
)

View File

@ -104,6 +104,7 @@ else:
("layoutlmv2", ("LayoutLMv2ImageProcessor",)),
("layoutlmv3", ("LayoutLMv3ImageProcessor",)),
("levit", ("LevitImageProcessor",)),
("llama4", ("Llama4ImageProcessor", "Llama4ImageProcessorFast")),
("llava", ("LlavaImageProcessor", "LlavaImageProcessorFast")),
("llava_next", ("LlavaNextImageProcessor", "LlavaNextImageProcessorFast")),
("llava_next_video", ("LlavaNextVideoImageProcessor",)),

View File

@ -17,7 +17,6 @@
import warnings
from collections import OrderedDict
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from .auto_factory import (
_BaseAutoBackboneClass,
@ -161,6 +160,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("levit", "LevitModel"),
("lilt", "LiltModel"),
("llama", "LlamaModel"),
("llama4", "Llama4ForConditionalGeneration"),
("longformer", "LongformerModel"),
("longt5", "LongT5Model"),
("luke", "LukeModel"),
@ -547,6 +547,8 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("jamba", "JambaForCausalLM"),
("jetmoe", "JetMoeForCausalLM"),
("llama", "LlamaForCausalLM"),
("llama4", "Llama4ForCausalLM"),
("llama4_text", "Llama4ForCausalLM"),
("mamba", "MambaForCausalLM"),
("mamba2", "Mamba2ForCausalLM"),
("marian", "MarianForCausalLM"),
@ -634,6 +636,7 @@ MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
("ijepa", "IJepaModel"),
("imagegpt", "ImageGPTModel"),
("levit", "LevitModel"),
("llama4", "Llama4VisionModel"),
("mllama", "MllamaVisionModel"),
("mobilenet_v1", "MobileNetV1Model"),
("mobilenet_v2", "MobileNetV2Model"),
@ -849,6 +852,7 @@ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = OrderedDict(
("idefics3", "Idefics3ForConditionalGeneration"),
("instructblip", "InstructBlipForConditionalGeneration"),
("kosmos-2", "Kosmos2ForConditionalGeneration"),
("llama4", "Llama4ForConditionalGeneration"),
("llava", "LlavaForConditionalGeneration"),
("llava_next", "LlavaNextForConditionalGeneration"),
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
@ -1492,6 +1496,7 @@ MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
("emu3", "Emu3TextModel"),
("flaubert", "FlaubertModel"),
("ibert", "IBertModel"),
("llama4", "Llama4TextModel"),
("longformer", "LongformerModel"),
("mllama", "MllamaTextModel"),
("mobilebert", "MobileBertModel"),
@ -1678,30 +1683,6 @@ _AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="languag
class AutoModelForCausalLM(_BaseAutoModelClass):
_model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
@classmethod
def _prepare_config_for_auto_class(cls, config: PretrainedConfig) -> PretrainedConfig:
"""
Additional autoclass-specific config post-loading manipulation. In this specific autoclass, if the config has
a nested text decoder section, uses that section instead.
Under the hood, multimodal models mapped by AutoModelForCausalLM assume the text decoder receives its own
config, rather than the config for the whole model. This is used e.g. to load the text-only part of a VLM.
"""
possible_text_config_names = ("decoder", "generator", "text_config")
text_config_names = []
for text_config_name in possible_text_config_names:
if hasattr(config, text_config_name):
text_config_names += [text_config_name]
text_config = config.get_text_config(decoder=True)
if text_config_names and type(text_config) in cls._model_mapping.keys():
warnings.warn(
"Loading a multimodal model with `AutoModelForCausalLM` is deprecated and will be removed in v5. "
"`AutoModelForCausalLM` will be used to load only the text-to-text generation module.",
FutureWarning,
)
return config
AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling")

View File

@ -77,6 +77,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("kosmos-2", "Kosmos2Processor"),
("layoutlmv2", "LayoutLMv2Processor"),
("layoutlmv3", "LayoutLMv3Processor"),
("llama4", "Llama4Processor"),
("llava", "LlavaProcessor"),
("llava_next", "LlavaNextProcessor"),
("llava_next_video", "LlavaNextVideoProcessor"),

View File

@ -292,6 +292,20 @@ else:
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"llama4",
(
"LlamaTokenizer" if is_sentencepiece_available() else None,
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"llama4_text",
(
"LlamaTokenizer" if is_sentencepiece_available() else None,
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
("llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("llava_next", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("llava_next_video", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),

View File

@ -0,0 +1,29 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure
if TYPE_CHECKING:
from .configuration_llama4 import *
from .image_processing_llama4_fast import *
from .modeling_llama4 import *
from .processing_llama4 import *
else:
import sys
_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

View File

@ -0,0 +1,432 @@
# coding=utf-8
# Copyright 2025 The LLAMA4 and HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class Llama4VisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Llama4VisionModel`]. It is used to instantiate a
Llama4 vision model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the Llama4 109B.
e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
num_hidden_layers (`int`, *optional*, defaults to 34):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
num_channels (`int`, *optional*, defaults to 3):
Number of channels in the input image.
intermediate_size (`int`, *optional*, defaults to 5632):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
vision_output_dim (`int`, *optional*, defaults to 7680):
Dimensionality of the vision model output. Includes output of transformer
encoder with intermediate layers and global transformer encoder.
image_size (`int`, *optional*, defaults to 448):
The size (resolution) of each image *tile*.
patch_size (`int`, *optional*, defaults to 14):
The size (resolution) of each patch.
norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the layer normalization layers.
vision_feature_layer (`int`, *optional*, defaults to -1): TODO
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): TODO
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
pixel_shuffle_ratio (`float`, *optional*, defaults to 0.5): TODO
projector_input_dim (`int`, *optional*, defaults to 4096): TODO
projector_output_dim (`int`, *optional*, defaults to 4096): TODO
multi_modal_projector_bias (`bool`, *optional*, defaults to `False`): TODO
projector_dropout (`float`, *optional*, defaults to 0.0): TODO
attention_dropout (`float`, *optional*, defaults to 0.0): TODO
rope_theta (`float`, *optional*, defaults to 10000): TODO
```"""
base_model_tp_plan = {
"model.layers.*.self_attn.q_proj": "colwise",
"model.layers.*.self_attn.k_proj": "colwise",
"model.layers.*.self_attn.v_proj": "colwise",
"model.layers.*.self_attn.o_proj": "rowwise",
"vision_adapter.mlp.fc1": "colwise",
"vision_adapter.mlp.fc2": "rowwise",
"patch_embedding.linear": "colwise_rep",
}
model_type = "llama4_vision_model"
base_config_key = "vision_config"
def __init__(
self,
hidden_size: int = 768,
hidden_act: str = "gelu",
num_hidden_layers: int = 34,
num_attention_heads: int = 16,
num_channels: int = 3,
intermediate_size: int = 5632,
vision_output_dim: int = 7680,
image_size: int = 448,
patch_size: int = 14,
norm_eps: float = 1e-5,
vision_feature_layer=-1,
vision_feature_select_strategy="default",
initializer_range: float = 0.02,
pixel_shuffle_ratio=0.5,
projector_input_dim=4096,
projector_output_dim=4096,
multi_modal_projector_bias=False,
projector_dropout=0.0,
attention_dropout=0.0,
rope_theta=10000,
**kwargs,
):
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.num_hidden_layers = num_hidden_layers
self.num_channels = num_channels
self.intermediate_size = intermediate_size
self.image_size = image_size
self.vision_output_dim = vision_output_dim
self.patch_size = patch_size
self.norm_eps = norm_eps
self.num_attention_heads = num_attention_heads
self.initializer_range = initializer_range
self.pixel_shuffle_ratio = pixel_shuffle_ratio
self.projector_input_dim = projector_input_dim
self.projector_output_dim = projector_output_dim
self.multi_modal_projector_bias = multi_modal_projector_bias
self.projector_dropout = projector_dropout
self.attention_dropout = attention_dropout
self.vision_feature_layer = vision_feature_layer
self.vision_feature_select_strategy = vision_feature_select_strategy
self.rope_theta = rope_theta
super().__init__(**kwargs)
class Llama4TextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Llama4TextModel`]. It is used to instantiate a
Llama4 text model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the Llama4 109B.
e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 202048):
Vocabulary size of the Llama4 text model. Defines the maximum number of different tokens that can be represented
by the `input_ids` passed when calling [`Llama4TextModel`].
hidden_size (`int`, *optional*, defaults to 5120):
Dimensionality of the embeddings and hidden states.
intermediate_size (`int`, *optional*, defaults to 8192):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
intermediate_size_mlp (`int`, *optional*, defaults to 16384): TODO
num_hidden_layers (`int`, *optional*, defaults to 48):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 40):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If not
specified, will default to `num_attention_heads`.
head_dim (`int`, *optional*, defaults to 128): TODO
hidden_act (`str` or `Callable`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the encoder and pooler.
max_position_embeddings (`int`, *optional*, defaults to 131072):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions.
pad_token_id (`int`, *optional*, defaults to 128004):
The id of the padding token.
bos_token_id (`int`, *optional*, defaults to 1):
The id of the beginning of sentence token.
eos_token_id (`int`, *optional*, defaults to 2):
The id of the end of sentence token.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to `500000.0`):
The base period of the RoPE embeddings.
attention_dropout (`float`, *optional*, defaults to 0.0): TODO
num_experts_per_tok (`int`, *optional*, defaults to 1): TODO
num_local_experts (`int`, *optional*, defaults to 16): TODO
moe_layers (`List[int]`, *optional*): TODO
interleave_moe_layer_step (`int`, *optional*, defaults to 1): TODO
use_qk_norm (`bool`, *optional*, defaults to `True`): TODO
output_router_logits (`bool`, *optional*, defaults to `False`): TODO
router_aux_loss_coef (`float`, *optional*, defaults to 0.001): TODO
router_jitter_noise (`float`, *optional*, defaults to 0.0): TODO
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (>
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
<TODO>
<TODO>
no_rope_layers (`List[int]`, *optional*): TODO
no_rope_layer_interval (`int`, *optional*, defaults to 4): TODO
attention_chunk_size (`int`, *optional*, defaults to 8192):
<TODO>
attn_temperature_tuning (`int`, *optional*, defaults to 4): TODO
floor_scale (`int`, *optional*, defaults to 8192): TODO
attn_scale (`float`, *optional*, defaults to 0.1): TODO
Example:
```"""
model_type = "llama4_text"
keys_to_ignore_at_inference = ["past_key_values"]
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.input_layernorm.weight": "sequence_parallel",
"layers.*.post_attention_layernorm.weight": "sequence_parallel",
"norm.weight": "sequence_parallel",
"layers.*.feed_forward.shared_expert.gate_proj": "local_colwise",
"layers.*.feed_forward.shared_expert.up_proj": "local_colwise",
"layers.*.feed_forward.shared_expert.down_proj": "local_rowwise",
"layers.*.feed_forward.experts.gate_up_proj": "local_packed_rowwise", # row because not linear
"layers.*.feed_forward.experts.down_proj": "local_colwise", # col because not linear
"layers.*.feed_forward.experts": "local",
"layers.*.feed_forward.gate_proj": "local_colwise",
"layers.*.feed_forward.up_proj": "local_colwise",
"layers.*.feed_forward.down_proj": "local_rowwise",
"layers.*.feed_forward": "gather",
}
def __init__(
self,
vocab_size=202048,
hidden_size=5120,
intermediate_size=8192,
intermediate_size_mlp=16384,
num_hidden_layers=48,
num_attention_heads=40,
num_key_value_heads=8,
head_dim=128,
hidden_act="silu",
max_position_embeddings=4096 * 32,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_theta=500000,
attention_dropout=0.0,
num_experts_per_tok=1,
num_local_experts=16,
moe_layers=None,
interleave_moe_layer_step=1,
use_qk_norm=True,
output_router_logits=False,
router_aux_loss_coef=0.001,
router_jitter_noise=0.0,
rope_scaling=None,
no_rope_layers=None,
no_rope_layer_interval=4,
attention_chunk_size=8192,
attn_temperature_tuning=4,
floor_scale=8192,
attn_scale=0.1,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.attn_temperature_tuning = attn_temperature_tuning
self.attn_scale = attn_scale
self.floor_scale = floor_scale
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.intermediate_size_mlp = intermediate_size_mlp
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.rope_scaling = rope_scaling
self.attention_bias = False
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
self.use_qk_norm = use_qk_norm
self.num_experts_per_tok = num_experts_per_tok
self.num_local_experts = num_local_experts
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.router_jitter_noise = router_jitter_noise
default_no_rope_layers = [
int((layer_idx + 1) % no_rope_layer_interval != 0) for layer_idx in range(self.num_hidden_layers)
]
# no_rope_layers == [] is invalid as we cannot have 0 layers
self.no_rope_layers = no_rope_layers if no_rope_layers else default_no_rope_layers
self.interleave_moe_layer_step = interleave_moe_layer_step
self.moe_layers = (
moe_layers
if moe_layers is not None
else list(range(interleave_moe_layer_step - 1, num_hidden_layers, interleave_moe_layer_step))
)
self.attention_chunk_size = attention_chunk_size
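The two derived defaults at the end of `__init__` are easy to misread; for a hypothetical 8-layer configuration they expand as follows:
num_hidden_layers = 8
no_rope_layer_interval = 4
interleave_moe_layer_step = 2

no_rope_layers = [int((i + 1) % no_rope_layer_interval != 0) for i in range(num_hidden_layers)]
moe_layers = list(range(interleave_moe_layer_step - 1, num_hidden_layers, interleave_moe_layer_step))

print(no_rope_layers)  # [1, 1, 1, 0, 1, 1, 1, 0] -> a 0 on every `no_rope_layer_interval`-th layer
print(moe_layers)      # [1, 3, 5, 7] -> every `interleave_moe_layer_step`-th layer uses the MoE feed-forward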
class Llama4Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Llama4Model`]. It is used to instantiate a
Llama4 model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the Llama4 109B.
e.g. [meta-llama/Llama-4-Scout-17B-16E](https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vision_config (`Llama4VisionConfig`, *optional*):
The Llama4 Vision config.
text_config (`Llama4TextConfig`, *optional*):
The Llama4 Text config.
boi_token_index (`int`, *optional*, defaults to 200080):
The begin-of-image token index to wrap the image prompt.
eoi_token_index (`int`, *optional*, defaults to 200081):
The end-of-image token index to wrap the image prompt.
image_token_index (`int`, *optional*, defaults to 200092):
The image token index to encode the image prompt.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
```python
>>> from transformers import Llama4Model, Llama4Config
>>> # Initializing a Llama4 7B style configuration
>>> configuration = Llama4Config()
>>> # Initializing a model from the Llama4 7B style configuration
>>> model = Llama4Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "llama4"
sub_configs = {"text_config": Llama4TextConfig, "vision_config": Llama4VisionConfig}
base_model_tp_plan = {
"multi_modal_projector.linear_1": "colwise_rep",
}
def __init__(
self,
vision_config=None,
text_config=None,
boi_token_index=200080,
eoi_token_index=200081,
image_token_index=200092,
tie_word_embeddings=False,
**kwargs,
):
if vision_config is None:
self.vision_config = Llama4VisionConfig()
logger.info("vision_config is None, using default llama4 vision config")
elif isinstance(vision_config, dict):
self.vision_config = Llama4VisionConfig(**vision_config)
elif isinstance(vision_config, Llama4VisionConfig):
self.vision_config = vision_config
self.boi_token_index = boi_token_index
self.eoi_token_index = eoi_token_index
self.image_token_index = image_token_index
if text_config is None:
self.text_config = Llama4TextConfig()
logger.info("text_config is None, using default llama4 text config")
elif isinstance(text_config, dict):
self.text_config = Llama4TextConfig(**text_config)
elif isinstance(text_config, Llama4TextConfig):
self.text_config = text_config
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
__all__ = ["Llama4Config", "Llama4TextConfig", "Llama4VisionConfig"]
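A short sketch of how the composite config and its sub-configs fit together, using toy sizes rather than the released checkpoints and assuming the classes are exported as listed in `__all__` above (`get_text_config` is the generic `PretrainedConfig` accessor):
from transformers import Llama4Config, Llama4TextConfig, Llama4VisionConfig

config = Llama4Config(
    text_config=Llama4TextConfig(num_hidden_layers=4, hidden_size=256, num_attention_heads=8, num_key_value_heads=2),
    vision_config=Llama4VisionConfig(num_hidden_layers=2, hidden_size=128),
)
text_config = config.get_text_config()
print(type(text_config).__name__, text_config.num_hidden_layers)  # Llama4TextConfig 4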

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,480 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast Image processor class for Got-OCR-2."""
import math
from collections import defaultdict
from functools import lru_cache
from typing import List, Optional, Set, Tuple, Union
from ...image_processing_utils import BatchFeature
from ...image_processing_utils_fast import (
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
ImageInput,
PILImageResampling,
SizeDict,
)
from ...processing_utils import Unpack
from ...utils import (
TensorType,
add_start_docstrings,
is_torch_available,
is_torchvision_available,
is_torchvision_v2_available,
)
if is_torch_available():
import torch
if is_torchvision_available():
if is_torchvision_v2_available():
from torchvision.transforms.v2 import functional as F
else:
from torchvision.transforms import functional as F
def get_factors(dividend: int) -> Set[int]:
"""
Calculate all factors of a given number, i.e. a divisor that leaves
no remainder. For example, if dividend=12, it will return {1, 2, 3, 4, 6, 12}.
Args:
dividend (int): The number to find factors for.
Returns:
set: A set containing all factors of the number.
"""
factors_set = set()
for i in range(1, int(dividend**0.5) + 1):
if dividend % i == 0:
factors_set.add(i)
factors_set.add(dividend // i)
return factors_set
def get_max_res_without_distortion(
image_size: Tuple[int, int],
target_size: Tuple[int, int],
) -> Tuple[int, int]:
"""
Determines the maximum resolution to which an image can be resized without distorting its
aspect ratio, based on the target resolution.
Args:
image_size (Tuple[int, int]): The original resolution of the image (height, width).
target_size (Tuple[int, int]): The desired resolution to fit the image into (height, width).
Returns:
Tuple[int, int]: The optimal dimensions (height, width) to which the image should be resized.
Example:
>>> get_max_res_without_distortion([200, 300], target_size = [450, 200])
(134, 200)
>>> get_max_res_without_distortion([800, 600], target_size = [450, 1300])
(450, 338)
"""
original_height, original_width = image_size
target_height, target_width = target_size
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.floor(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.floor(original_width * scale_h), target_width)
return new_height, new_width
class Llama4ImageProcessorKwargs(DefaultFastImageProcessorKwargs):
max_patches: Optional[int]
resize_to_max_canvas: Optional[bool]
def split_to_tiles(images: torch.Tensor, num_tiles_height: int, num_tiles_width: int) -> torch.Tensor:
# Split image into number of required tiles (width x height)
batch_size, num_channels, height, width = images.size()
images = images.view(
batch_size,
num_channels,
num_tiles_height,
height // num_tiles_height,
num_tiles_width,
width // num_tiles_width,
)
# Permute dimensions to reorder the axes
image = images.permute(0, 2, 4, 1, 3, 5).contiguous()
# Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
image = image.view(
batch_size,
num_tiles_width * num_tiles_height,
num_channels,
height // num_tiles_height,
width // num_tiles_width,
)
return image
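A shape check for `split_to_tiles`, assuming one 3-channel image cut into a 3x2 grid of tiles:
import torch
images = torch.randn(1, 3, 672, 448)  # (batch, channels, height, width)
tiles = split_to_tiles(images, num_tiles_height=3, num_tiles_width=2)
print(tiles.shape)  # torch.Size([1, 6, 3, 224, 224]) -> 6 tiles of 224x224 per image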
@lru_cache(maxsize=1)
def find_supported_resolutions(max_num_chunks: int, patch_size: SizeDict) -> torch.Tensor:
"""
Computes all of the allowed resolutions for a fixed number of chunks
and patch_size. Useful for when dividing an image into chunks.
Args:
max_num_chunks (int): Maximum number of chunks for processing.
patch_size (int): Size of the side of the patch.
Returns:
torch.Tensor: List of possible resolutions as tuples (height, width).
Example:
>>> max_num_chunks = 4
>>> patch_size = 224
>>> find_supported_resolutions(max_num_chunks, patch_size)
tensor([(224, 896), (448, 448), (224, 224), (896, 224), (224, 672),
(672, 224), (224, 448), (448, 224)])
Given max_num_chunks=4, patch_size=224, it will create a dictionary:
{
0.25: [(1, 4)],
1.0: [(2, 2), (1, 1)],
4.0: [(4, 1)],
0.33: [(1, 3)],
3.0: [(3, 1)],
0.5: [(1, 2)],
2.0: [(2, 1)]
}
and return the resolutions multiplied by the patch_size:
[(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)]
"""
height, width = patch_size.height, patch_size.width
if height != width:
raise ValueError("`size` must be square.")
patch_size = height
asp_dict = defaultdict(list)
for chunk_size in range(max_num_chunks, 0, -1):
_factors = sorted(get_factors(chunk_size))
_asp_ratios = [(factor, chunk_size // factor) for factor in _factors]
for height, width in _asp_ratios:
ratio_float = height / width
asp_dict[ratio_float].append((height, width))
# get the resolutions multiplied by the patch_size
possible_resolutions = []
for key, value in asp_dict.items():
for height, depth in value:
possible_resolutions.append((height * patch_size, depth * patch_size))
return possible_resolutions
def pad_to_best_fit(
images: "torch.Tensor",
target_size: Tuple[int, int],
background_color: Union[int, Tuple[int, int, int]] = 0,
) -> "torch.Tensor":
"""
Pads an image to fit the target size.
Args:
images (`torch.Tensor`):
The images to pad.
target_size (`Tuple[int, int]`):
The target (height, width) to pad the images to.
background_color (`int` or `Tuple[int, int, int]`, *optional*, defaults to 0):
The color to use for the padding. Can be an integer for single-channel images or a
tuple of integers for multi-channel images. If passed as an integer in
multi-channel mode, the remaining channels default to `0`.
Returns:
`torch.Tensor`: The padded images.
"""
num_channels = images.shape[1] if len(images.shape) == 4 else images.shape[0]
if isinstance(background_color, int):
background_color = [background_color] + [0] * (num_channels - 1)
elif len(background_color) != num_channels:
raise ValueError(
f"background_color must have no more than {num_channels} elements to match the number of channels"
)
height, width = images.shape[-2:]
target_height, target_width = target_size
paste_x_right = target_width - width
paste_y_right = target_height - height
padded_images = F.pad(images, padding=[0, 0, paste_x_right, paste_y_right], fill=background_color)
return padded_images
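`pad_to_best_fit` only pads on the right and bottom, so the image content stays anchored at the top-left of the canvas. A small sketch, assuming the torchvision v2 functional import above is available:
import torch
images = torch.zeros(1, 3, 300, 400)                   # smaller than the target canvas
padded = pad_to_best_fit(images, target_size=(448, 448))
print(padded.shape)                                     # torch.Size([1, 3, 448, 448])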
def get_best_fit(
image_size: Tuple[int, int],
possible_resolutions: torch.Tensor,
resize_to_max_canvas: bool = False,
) -> Tuple[int, int]:
"""
Determines the best canvas from a list of possible resolutions to resize an image to without distortion.
For each possible resolution, calculates the scaling factors for
width and height, and selects the smallest one, which is the limiting side.
E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.
If upscaling is possible (any of the scaling factors is greater than 1),
then picks the smallest upscaling factor > 1, unless resize_to_max_canvas is True.
If upscaling is not possible, then picks the largest scaling factor <= 1, i.e.
reduce downscaling as much as possible.
If there are multiple resolutions with the same max scale, we pick the one with the lowest area,
to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
has more padding.
Args:
image_size (Tuple[int, int]): A tuple containing the height and width of the image.
possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
row represents a possible resolution (height, width).
resize_to_max_canvas (bool): If True, will return the largest upscaling resolution.
Returns:
List[int]: The best resolution [height, width] for the given image.
Example:
>>> image_size = (200, 300)
>>> possible_resolutions = torch.tensor([[224, 672],
... [672, 224],
... [224, 448],
... [448, 224],
... [224, 224]])
>>> get_best_fit(image_size, possible_resolutions)
[224, 448]
We have:
scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
Only one of the scales > 1:
upscaling_possible = tensor([1.1200, 1.1200])
smallest_rescale = tensor(1.1200)
So we pick the resolution with the smallest area:
areas = tensor([150528, 100352]) # [672, 224], [224, 448]
optimal_canvas = tensor([224, 448])
"""
original_height, original_width = image_size
# get all possible resolutions heights/widths
target_heights, target_widths = (
possible_resolutions[:, 0],
possible_resolutions[:, 1],
)
# get scaling factors to resize the image without distortion
scale_w = target_widths / original_width
scale_h = target_heights / original_height
# get the min scale between width and height (limiting side -> no distortion)
scales = torch.where(scale_h > scale_w, scale_w, scale_h)
# filter only scales that allow upscaling
upscaling_options = scales[scales >= 1]
if len(upscaling_options) > 0:
if resize_to_max_canvas:
selected_scale = torch.max(upscaling_options)
else:
selected_scale = torch.min(upscaling_options)
else:
# no upscaling possible,
# get the minimum downscaling (max scale for scales<1)
downscaling_options = scales[scales < 1]
selected_scale = torch.max(downscaling_options)
# get all resolutions that support this scaling factor,
# e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
chosen_canvas = possible_resolutions[scales == selected_scale]
# if there are multiple resolutions,
# get the one with minimum area to reduce padding
if len(chosen_canvas) > 1:
areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
optimal_idx = torch.argmin(areas)
optimal_canvas = chosen_canvas[optimal_idx]
else:
optimal_canvas = chosen_canvas[0]
return tuple(optimal_canvas.tolist())
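# A quick check of the selection logic above, reusing the docstring example plus the
# `resize_to_max_canvas` branch (the resolutions below are illustrative, not a real grid):
# >>> get_best_fit((200, 300), torch.tensor([[224, 448], [448, 672]]))
# (224, 448)
# >>> get_best_fit((200, 300), torch.tensor([[224, 448], [448, 672]]), resize_to_max_canvas=True)
# (448, 672)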
@add_start_docstrings(
"Constructs a fast Llama4 image processor.",
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING,
"""
max_patches (`int`, *optional*, defaults to 16):
The maximum number of patches to be extracted from the image.
Can be overridden by the `max_patches` parameter in the `preprocess` method.
resize_to_max_canvas (`bool`, *optional*, defaults to `False`):
Whether to resize the image to the maximum canvas size.
If `True`, picks the canvas that allows the largest distortion-free resizing.
If `False`, downsamples as little as possible (possibly not resizing at all),
and never upsamples unless the image is smaller than the patch size.
""",
)
class Llama4ImageProcessorFast(BaseImageProcessorFast):
resample = PILImageResampling.BILINEAR
image_mean = [0.5, 0.5, 0.5]
image_std = [0.5, 0.5, 0.5]
size = {"height": 336, "width": 336}
do_resize = True
do_rescale = True
do_normalize = True
do_convert_rgb = True
max_patches = 16
resize_to_max_canvas = False
valid_kwargs = Llama4ImageProcessorKwargs
def __init__(self, **kwargs: Unpack[Llama4ImageProcessorKwargs]):
super().__init__(**kwargs)
def rescale_and_normalize(
self,
images: "torch.Tensor",
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Union[float, List[float]],
image_std: Union[float, List[float]],
) -> "torch.Tensor":
"""
Rescale and normalize images.
Override to rescale and normalize the images in torch.bfloat16 as in the original implementation
"""
if do_rescale and do_normalize:
images = images.to(dtype=torch.bfloat16) * rescale_factor
images = self.normalize(images, image_mean, image_std)
elif do_rescale:
images = images * rescale_factor
elif do_normalize:
images = self.normalize(images, image_mean, image_std)
return images
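# Rough sketch of the bfloat16 fast path above (input values are illustrative):
# >>> x = torch.randint(0, 256, (1, 3, 336, 336), dtype=torch.uint8)
# >>> out = Llama4ImageProcessorFast().rescale_and_normalize(x, True, 1 / 255, True, [0.5] * 3, [0.5] * 3)
# >>> out.dtype   # rescaling casts to torch.bfloat16 before normalization
# torch.bfloat16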
@add_start_docstrings(
BASE_IMAGE_PROCESSOR_FAST_DOCSTRING_PREPROCESS,
"""
max_patches (`int`, *optional*, defaults to 16):
The maximum number of patches to be extracted from the image.
Can be overridden by the `max_patches` parameter in the `preprocess` method.
resize_to_max_canvas (`bool`, *optional*, defaults to `False`):
Whether to resize the image to the maximum canvas size.
If `True`, picks the canvas that allows the largest distortion-free resizing.
If `False`, downsamples as little as possible (possibly not resizing at all),
and never upsamples unless the image is smaller than the patch size.
""",
)
def preprocess(self, images: ImageInput, **kwargs: Unpack[Llama4ImageProcessorKwargs]) -> BatchFeature:
return super().preprocess(images, **kwargs)
def _preprocess(
self,
images: List["torch.Tensor"],
size: SizeDict,
max_patches: int,
resize_to_max_canvas: bool,
interpolation: Optional["F.InterpolationMode"],
do_rescale: bool,
rescale_factor: float,
do_normalize: bool,
image_mean: Optional[Union[float, List[float]]],
image_std: Optional[Union[float, List[float]]],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
possible_resolutions = find_supported_resolutions(max_num_chunks=max_patches, patch_size=size)
possible_resolutions = torch.tensor(possible_resolutions)
# process images by batch, grouped by shape
grouped_images, grouped_images_index = group_images_by_shape(images)
grouped_processed_images = {}
grouped_aspect_ratios = {}
for shape, stacked_images in grouped_images.items():
image_size = stacked_images.shape[-2:]
target_size = get_best_fit(image_size, possible_resolutions, resize_to_max_canvas=resize_to_max_canvas)
# If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size
max_upscaling_size = None if resize_to_max_canvas else size.height
if max_upscaling_size is not None:
new_target_height = min(max(image_size[0], max_upscaling_size), target_size[0])
new_target_width = min(max(image_size[1], max_upscaling_size), target_size[1])
target_size_without_distortion = (new_target_height, new_target_width)
# resize to target_size while preserving aspect ratio
new_size_without_distortion = get_max_res_without_distortion(image_size, target_size_without_distortion)
new_size_without_distortion = SizeDict(
height=max(new_size_without_distortion[0], 1), width=max(new_size_without_distortion[1], 1)
)
processed_images = self.resize(
stacked_images,
new_size_without_distortion,
interpolation=interpolation,
)
# pad to target_size to be able to split into tiles
processed_images = pad_to_best_fit(processed_images, target_size)
processed_images = self.rescale_and_normalize(
processed_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
ratio_h, ratio_w = (
target_size[0] // size.height,
target_size[1] // size.width,
)
# split into tiles
processed_images = split_to_tiles(processed_images, ratio_h, ratio_w)
grouped_processed_images[shape] = processed_images
grouped_aspect_ratios[shape] = torch.tensor([[ratio_h, ratio_w]] * stacked_images.shape[0])
# add a global tile to the processed tiles if there is more than one tile
if ratio_h * ratio_w > 1:
global_tiles = self.resize(
stacked_images,
size,
interpolation=interpolation,
)
global_tiles = self.rescale_and_normalize(
global_tiles, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
grouped_processed_images[shape] = torch.cat([processed_images, global_tiles.unsqueeze(1)], dim=1)
processed_images = reorder_images(grouped_processed_images, grouped_images_index)
aspect_ratios_list = reorder_images(grouped_aspect_ratios, grouped_images_index)
processed_images = torch.cat(processed_images, dim=0) if return_tensors else processed_images
aspect_ratios = torch.stack(aspect_ratios_list, dim=0) if return_tensors else aspect_ratios_list
return BatchFeature(
data={"pixel_values": processed_images, "aspect_ratios": aspect_ratios}, tensor_type=return_tensors
)
__all__ = ["Llama4ImageProcessorFast"]
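# End-to-end sketch of the tiling pipeline implemented in `_preprocess` (the 1008x1008 canvas below
# assumes `find_supported_resolutions` enumerates 336-pixel grids of at most `max_patches` tiles):
# >>> from PIL import Image
# >>> processor = Llama4ImageProcessorFast()
# >>> outputs = processor(Image.new("RGB", (1000, 1000)), return_tensors="pt")
# A 1000x1000 image is fitted to a 1008x1008 canvas (3x3 tiles of 336), padded, and split into
# 9 local tiles plus 1 global tile, so `outputs["aspect_ratios"]` is tensor([[3, 3]]) and
# 10 tiles of 336x336 are produced for the image.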

File diff suppressed because it is too large

File diff suppressed because one or more lines are too long

View File

@ -981,6 +981,8 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
else:
self.device = device if device is not None else -1
if torch.distributed.is_initialized():
self.device = self.model.device
logger.warning(f"Device set to use {self.device}")
self.binary_output = binary_output

View File

@ -1178,10 +1178,6 @@ class ProcessorMixin(PushToHubMixin):
unused_kwargs = {}
unused_keys = set(kwargs_from_config) - set(valid_kwargs)
if unused_keys:
unused_key_str = ", ".join(unused_keys)
logger.warning(
f"Some kwargs in processor config are unused and will not have any effect: {unused_key_str}. "
)
unused_kwargs = {k: processor_config[k] for k in unused_keys}
return unused_kwargs

View File

@ -43,8 +43,7 @@ is_torch_greater_or_equal_than_1_12 = is_torch_greater_or_equal("1.12", accept_d
_torch_distributed_available = torch.distributed.is_available()
if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
pass
def softmax_backward_data(parent, grad_output, output, dim, self):
@ -335,29 +334,6 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
return torch.isin(elements, test_elements)
# TODO need to add the __repr__ that shows that it is a colwise parallel
# See https://github.com/pytorch/pytorch/issues/145726
def translate_to_torch_parallel_style(style: str):
"""
In model configurations, we use a neutral type (string) to specify parallel
styles, here we translate them into torch.distributed tensor-parallel
types.
"""
if not isinstance(style, str):
raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")
if style == "colwise":
return ColwiseParallel()
elif style == "rowwise":
return RowwiseParallel()
elif style == "colwise_rep":
return ColwiseParallel(output_layouts=Replicate())
elif style == "rowwise_rep":
return RowwiseParallel(input_layouts=Replicate())
else:
raise ValueError(f"Unsupported parallel style value: {style}")
def compile_compatible_method_lru_cache(*lru_args, **lru_kwargs):
"""
LRU cache decorator from standard functools library, but with a workaround to disable
@ -382,88 +358,3 @@ def compile_compatible_method_lru_cache(*lru_args, **lru_kwargs):
return wrapper
return decorator
def distribute_module(
module: nn.Module,
device_mesh=None,
partition_fn=None,
input_fn=None,
output_fn=None,
) -> nn.Module:
"""
This function expose three functions to control the parameters/inputs/outputs of the module:
1. To perform sharding on the module before runtime execution by specifying the
``partition_fn`` (i.e. allow user to convert Module parameters to :class:`DTensor`
parameters according to the `partition_fn` specified).
2. To control the inputs or outputs of the module during runtime execution by
specifying the ``input_fn`` and ``output_fn``. (i.e. convert the input to
:class:`DTensor`, convert the output back to ``torch.Tensor``)
Args:
module (:class:`nn.Module`): user module to be partitioned.
device_mesh (:class:`DeviceMesh`): the device mesh to place the module.
partition_fn (Callable): the function to partition parameters (i.e. shard certain
parameters across the ``device_mesh``). If ``partition_fn`` is not specified,
by default we replicate all module parameters of ``module`` across the mesh.
input_fn (Callable): specify the input distribution, i.e. could control how the
input of the module is sharded. ``input_fn`` will be installed as a module
``forward_pre_hook`` (pre forward hook).
output_fn (Callable): specify the output distribution, i.e. could control how the
output is sharded, or convert it back to torch.Tensor. ``output_fn`` will be
installed as a module ``forward_hook`` (post forward hook).
Returns:
A module that contains parameters/buffers that are all ``DTensor`` s.
.. note::
When initialize the DeviceMesh with the ``xla`` device_type, ``distribute_module``
return nn.Module with PyTorch/XLA SPMD annotated parameters. See
`this issue <https://github.com/pytorch/pytorch/issues/92909>`__
for more details. The XLA integration is experimental and subject to change.
"""
torch._C._log_api_usage_once("torch.dtensor.distribute_module")
device_mesh = device_mesh
# register input_fn as module forward pre hook
if input_fn is not None:
# check the input_fn signature
num_args = len(inspect.signature(input_fn).parameters)
if num_args == 2:
# input_fn only takes in inputs and device mesh
logger.warning(
"Deprecating input_fn that takes two arguments (inputs, device_mesh), "
"please use input_fn that takes in (module, inputs, device_mesh) instead!",
FutureWarning,
stacklevel=2,
)
module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh)) # type: ignore[call-arg]
elif num_args == 3:
# input_fn takes in module, inputs, device mesh
module.register_forward_pre_hook(lambda mod, inputs: input_fn(mod, inputs, device_mesh))
else:
raise ValueError(f"input_fn should take in 3 arguments, but got {num_args} arguments!")
# register output_fn as module forward hook
if output_fn is not None:
num_args = len(inspect.signature(output_fn).parameters)
if num_args == 2:
# output_fn only takes in outputs and device mesh
logger.warning(
"Deprecating output_fn that takes two arguments (inputs, device_mesh), "
"please use output_fn that takes in (module, inputs, device_mesh) instead!",
FutureWarning,
stacklevel=2,
)
module.register_forward_hook(
lambda mod, inputs, outputs: output_fn(outputs, device_mesh) # type: ignore[call-arg]
)
elif num_args == 3:
module.register_forward_hook(lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh))
else:
raise ValueError(f"output_fn should take in 3 arguments, but got {num_args} arguments!")
return module

src/transformers/quantizers/base.py Executable file → Normal file
View File

@ -15,7 +15,8 @@ from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from ..utils import is_torch_available
from ..utils.quantization_config import QuantizationConfigMixin
from ..utils.quantization_config import QuantizationConfigMixin, QuantizationMethod
from .quantizers_utils import get_module_from_name
if TYPE_CHECKING:
@ -23,6 +24,9 @@ if TYPE_CHECKING:
if is_torch_available():
import torch
from torch.nn import ModuleList
else:
ModuleList = str
class HfQuantizer(ABC):
@ -198,6 +202,10 @@ class HfQuantizer(ABC):
"""
return
def update_tp_plan(self, config):
"updates the tp plan for the scales"
return config
def preprocess_model(self, model: "PreTrainedModel", **kwargs):
"""
Setting model attributes and/or converting model before weights loading. At this point
@ -212,6 +220,7 @@ class HfQuantizer(ABC):
"""
model.is_quantized = True
model.quantization_method = self.quantization_config.quant_method
self._convert_model_for_quantization(model)
return self._process_model_before_weight_loading(model, **kwargs)
def postprocess_model(self, model: "PreTrainedModel", **kwargs):
@ -288,3 +297,44 @@ class HfQuantizer(ABC):
@property
@abstractmethod
def is_trainable(self): ...
def _convert_model_for_quantization(self, model):
from accelerate import init_empty_weights
for name, module in model.named_modules():
module_class_name = module.__class__.__name__
if (
module_class_name in MODULES_TO_PATCH_FOR_QUANTIZATION.keys()
and self.quantization_config.quant_method == QuantizationMethod.COMPRESSED_TENSORS
):
with init_empty_weights():
parent_module, name = get_module_from_name(model, name)
parent_module._modules[name] = MODULES_TO_PATCH_FOR_QUANTIZATION[module_class_name](
model.config.get_text_config()
)
class SequentialLlama4TextExperts(ModuleList):
"""
A module that implements a compressed version of a list of expert modules.
This is specifically designed to work with Llama4TextExperts in MoE layers.
"""
def __init__(self, config):
from transformers.models.llama4.modeling_llama4 import Llama4TextMLP
super().__init__([Llama4TextMLP(config) for _ in range(config.num_local_experts)])
self.num_experts = config.num_local_experts
def forward(
self,
hidden_states: "torch.Tensor",
) -> "torch.Tensor":
hidden_states = hidden_states.reshape(self.num_experts, -1, hidden_states.shape[-1])
routed_out = torch.zeros_like(hidden_states)
for expert_idx in range(self.num_experts):
routed_out[expert_idx] = self[expert_idx](hidden_states[expert_idx])
return routed_out
MODULES_TO_PATCH_FOR_QUANTIZATION = {"Llama4TextExperts": SequentialLlama4TextExperts}
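# Rough shape sketch for the sequential experts above (4 tokens per expert is illustrative;
# `text_config` stands in for the model's Llama4TextConfig):
# >>> experts = SequentialLlama4TextExperts(text_config)
# >>> x = torch.randn(text_config.num_local_experts * 4, text_config.hidden_size)
# >>> experts(x).shape   # (num_local_experts, 4, hidden_size): one MLP applied per expert slice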

View File

@ -146,6 +146,19 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
self.compressor.quantization_config.quantization_status = QuantizationStatus.FROZEN
self.compressor.decompress(model_path=cache_path, model=model)
def update_tp_plan(self, config):
additional_plan = {
"layers.*.feed_forward.experts.*.gate_proj.weight": "local_colwise",
"layers.*.feed_forward.experts.*.gate_proj.weight_scale": "local_colwise",
"layers.*.feed_forward.experts.*.up_proj.weight": "local_colwise",
"layers.*.feed_forward.experts.*.up_proj.weight_scale": "local_colwise",
"layers.*.feed_forward.experts.*.down_proj.weight": "local_rowwise",
}
if config.get_text_config() is not None and config.get_text_config().base_model_tp_plan is not None:
config.get_text_config().base_model_tp_plan.update(additional_plan)
return config
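# Rough sketch of the effect (`quantizer` and `checkpoint` below are illustrative placeholders):
# >>> config = AutoConfig.from_pretrained(checkpoint)   # a compressed-tensors Llama 4 checkpoint
# >>> config = quantizer.update_tp_plan(config)
# >>> config.get_text_config().base_model_tp_plan["layers.*.feed_forward.experts.*.down_proj.weight"]
# 'local_rowwise'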
@property
def is_trainable(self):
return True

View File

@ -116,7 +116,7 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
state_dict: Dict[str, Any],
**kwargs,
):
from ..integrations import FbgemmFp8Linear
from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts
module, tensor_name = get_module_from_name(model, param_name)
@ -129,6 +129,13 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
if tensor_name == "weight_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
return True
if isinstance(module, FbgemmFp8Llama4TextExperts):
if self.pre_quantized or tensor_name == "bias":
return False
else:
if tensor_name == "gate_up_proj_scale" or tensor_name == "down_proj_scale":
raise ValueError("Expect unquantized weights but got a quantized weight_scale")
return True
return False
def create_quantized_param(
@ -143,12 +150,52 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
"""
Quantizes weights into weight and weight_scale
"""
new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(param_value)
from ..integrations import FbgemmFp8Llama4TextExperts
module, tensor_name = get_module_from_name(model, param_name)
module._buffers[tensor_name] = new_value.to(target_device)
# to have the right output shape -> (out_features, 1)
module._buffers["weight_scale"] = weight_scale.view(weight_scale.shape[0], 1).to(target_device)
if isinstance(module, FbgemmFp8Llama4TextExperts):
if tensor_name == "gate_up_proj":
# Process each expert separately
# Transpose the second and third dimension
transposed_param = param_value.transpose(1, 2)
# Reshape to 2D for quantization
original_shape = transposed_param.shape
flattened_param = transposed_param.reshape(-1, original_shape[-1])
# Quantize using per row instead of per column
new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
# Reshape back to original dimensions
new_value = new_value_flat.reshape(original_shape)
new_value = new_value.transpose(1, 2)
weight_scale = weight_scale_flat.reshape(original_shape[0], 1, original_shape[1])
elif tensor_name == "down_proj":
# Process each expert separately
# Transpose the weights for proper quantization
transposed_param = param_value.transpose(1, 2)
# Reshape to 2D for quantization
original_shape = transposed_param.shape
flattened_param = transposed_param.reshape(-1, original_shape[-1])
# Quantize per row of the transposed tensor (one scale per output channel of the original weights)
new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
# Reshape back to original dimensions
new_value = new_value_flat.reshape(original_shape)
new_value = new_value.transpose(1, 2)
weight_scale = weight_scale_flat.reshape(original_shape[0], original_shape[1], 1)
module._parameters[f"{tensor_name}_scale"] = torch.nn.Parameter(weight_scale.to(target_device))
else:
new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(param_value)
module._parameters[f"{tensor_name}_scale"] = torch.nn.Parameter(
weight_scale.view(weight_scale.shape[0], 1).to(target_device)
)
module._parameters[tensor_name] = torch.nn.Parameter(new_value.to(target_device))
if unexpected_keys is not None and param_name in unexpected_keys:
unexpected_keys.remove(param_name)
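# Rough shape flow for the expert weights handled above (the (num_experts, in_features, out_features)
# layout per projection is assumed from Llama4TextExperts):
#   param_value.transpose(1, 2)   -> (num_experts, out_features, in_features)
#   .reshape(-1, in_features)     -> 2D view so fbgemm can quantize one row at a time
#   quantize_fp8_per_row(...)     -> fp8 weights plus one scale per row, i.e. per output channel per expert
#   reshape / transpose back      -> the fp8 weights return to their original layout and the scales are
#                                    registered as `gate_up_proj_scale` / `down_proj_scale` parameters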
@ -165,25 +212,29 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
):
from ..integrations import replace_with_fbgemm_fp8_linear
tp_plan = model._tp_plan
self.modules_to_not_convert = self.get_modules_to_not_convert(
model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
)
config = model.config
model = replace_with_fbgemm_fp8_linear(
model,
modules_to_not_convert=self.modules_to_not_convert,
quantization_config=self.quantization_config,
pre_quantized=self.pre_quantized,
config=config,
tp_plan=tp_plan,
)
model.config.quantization_config = self.quantization_config
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
from ..integrations import FbgemmFp8Linear
from ..integrations import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts
not_missing_keys = []
for name, module in model.named_modules():
if isinstance(module, FbgemmFp8Linear):
if isinstance(module, FbgemmFp8Linear) or isinstance(module, FbgemmFp8Llama4TextExperts):
for missing in missing_keys:
if (
(name in missing or name in f"{prefix}.{missing}")

View File

@ -3950,7 +3950,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
verbose (`bool`): Whether or not to print more information and warnings.
"""
if max_length is None and len(ids) > self.model_max_length and verbose:
if max_length is None and len(ids) > self.model_max_length and verbose and self.model_max_length != 0:
if not self.deprecation_warnings.get("sequence-length-is-longer-than-the-specified-maximum", False):
logger.warning(
"Token indices sequence length is longer than the specified maximum sequence length "

View File

@ -5823,6 +5823,41 @@ class LlamaPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class Llama4ForCausalLM(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Llama4ForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Llama4PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Llama4TextModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class Llama4VisionModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LlavaForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]

View File

@ -72,6 +72,13 @@ class GotOcr2ImageProcessorFast(metaclass=DummyObject):
requires_backends(self, ["torchvision"])
class Llama4ImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torchvision"])
class LlavaImageProcessorFast(metaclass=DummyObject):
_backends = ["torchvision"]

View File

@ -408,6 +408,13 @@ class LevitImageProcessor(metaclass=DummyObject):
requires_backends(self, ["vision"])
class Llama4ImageProcessor(metaclass=DummyObject):
_backends = ["vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["vision"])
class LlavaImageProcessor(metaclass=DummyObject):
_backends = ["vision"]

View File

@ -0,0 +1,128 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_torch_available():
pass
if is_vision_available() and is_torchvision_available():
from transformers import Llama4ImageProcessorFast
class Llama4ImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
image_size=18,
min_resolution=30,
max_resolution=400,
max_patches=1,
do_resize=True,
size=None,
do_normalize=True,
do_pad=False,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
do_convert_rgb=True,
):
super().__init__()
size = size if size is not None else {"height": 20, "width": 20}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.image_size = image_size
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.max_patches = max_patches
self.do_resize = do_resize
self.size = size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.do_pad = do_pad
self.do_convert_rgb = do_convert_rgb
def prepare_image_processor_dict(self):
return {
"max_patches": self.max_patches,
"do_resize": self.do_resize,
"size": self.size,
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
"do_convert_rgb": self.do_convert_rgb,
"do_pad": self.do_pad,
}
def expected_output_image_shape(self, images):
return self.num_channels, self.size["height"], self.size["width"]
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
return prepare_image_inputs(
batch_size=self.batch_size,
num_channels=self.num_channels,
min_resolution=self.min_resolution,
max_resolution=self.max_resolution,
equal_resolution=equal_resolution,
numpify=numpify,
torchify=torchify,
)
@require_torch
@require_vision
class Llama4ImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
test_slow_image_processor = False
fast_image_processing_class = Llama4ImageProcessorFast if is_torchvision_available() else None
def setUp(self):
super().setUp()
self.image_processor_tester = Llama4ImageProcessingTester(self)
@property
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processor, "do_resize"))
self.assertTrue(hasattr(image_processor, "size"))
self.assertTrue(hasattr(image_processor, "do_normalize"))
self.assertTrue(hasattr(image_processor, "image_mean"))
self.assertTrue(hasattr(image_processor, "image_std"))
self.assertTrue(hasattr(image_processor, "do_convert_rgb"))
def test_split_tiles(self):
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class(**self.image_processor_dict)
image = self.image_processor_tester.prepare_image_inputs(equal_resolution=True)[0]
processed_images = image_processor(
image,
max_patches=16,
)
self.assertEqual(len(processed_images.pixel_values), 1)
self.assertEqual(processed_images.pixel_values[0].shape[0], 17)
self.assertEqual(processed_images.pixel_values[0].shape[-2:], (20, 20))

View File

@ -0,0 +1,121 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Llama4 model."""
import unittest
from transformers import is_torch_available
from transformers.testing_utils import (
require_read_token,
require_torch_large_gpu,
slow,
torch_device,
)
if is_torch_available():
import torch
from transformers import (
Llama4ForConditionalGeneration,
Llama4Processor,
)
@slow
@require_torch_large_gpu
@require_read_token
class Llama4IntegrationTest(unittest.TestCase):
model_id = "ll-re/Llama-4-17B-Omni-Instruct"
# This variable is used to determine which CUDA device we are using for our runners (A10 or T4)
# Depending on the hardware we get different logits / generations
cuda_compute_capability_major_version = None
@classmethod
def setUpClass(cls):
if is_torch_available() and torch.cuda.is_available():
# 8 is for A100 / A10 and 7 for T4
cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
cls.model = Llama4ForConditionalGeneration.from_pretrained(
"ll-re/Llama-4-17B-Omni-Instruct", device_map="auto", torch_dtype=torch.float32
)
def setUp(self):
self.processor = Llama4Processor.from_pretrained("ll-re/Llama-4-17B-Omni-Instruct", padding_side="left")
url = "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png"
self.messages = [
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
{
"role": "user",
"content": [
{"type": "image", "url": url},
{"type": "text", "text": "What is shown in this image?"},
],
},
]
def test_model_17b_16e_fp16(self):
EXPECTED_TEXT = [
"The capital of France is Paris, which is located in the north-central part of the country. Paris is known for its iconic landmarks such as the",
"Roses are red, violets are blue, and this poem is about you. Roses are red, violets are blue, and I love",
]
messages = [
{"role": "user", "content": "Who are you?"},
]
inputs = self.processor.apply_chat_template(
messages, add_generation_prompt=True, return_tensors="pt", return_dict=True
).to(torch_device)
output = self.model.generate(**inputs, max_new_tokens=100)
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
print(output_text)
self.assertEqual(output_text, EXPECTED_TEXT)
def test_model_17b_16e_batch(self):
messages_2 = [
{"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
{
"role": "user",
"content": [
{
"type": "image",
"url": "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
},
{"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
{"type": "text", "text": "Are these images identical?"},
],
},
]
inputs = self.processor.apply_chat_template(
[self.messages, messages_2],
tokenize=True,
return_dict=True,
return_tensors="pt",
padding=True,
add_generation_prompt=True,
).to(torch_device)
output = self.model.generate(**inputs, max_new_tokens=30, do_sample=False)
output_text = self.processor.batch_decode(output, skip_special_tokens=True)
EXPECTED_TEXTS = [
'user\nYou are a helpful assistant.\n\n\n\n\n\nWhat is shown in this image?\nmodel\nCertainly! \n\nThe image shows a brown cow standing on a sandy beach with clear turquoise water and a blue sky in the background. It looks like',
"user\nYou are a helpful assistant.\n\n\n\n\n\n\n\n\n\nAre these images identical?\nmodel\nNo, these images are not identical. \n\nHere's a breakdown of the differences:\n\n* **Image 1:** Shows a cow"
] # fmt: skip
self.assertEqual(output_text, EXPECTED_TEXTS)

View File

@ -0,0 +1,65 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import tempfile
import unittest
from typing import Optional
from transformers import AutoProcessor, Llama4Processor, PreTrainedTokenizerFast
from transformers.testing_utils import require_vision
from transformers.utils import is_vision_available
from ...test_processing_common import ProcessorTesterMixin
if is_vision_available():
from transformers import Llama4ImageProcessor
@require_vision
class Llama4ProcessorTest(ProcessorTesterMixin, unittest.TestCase):
processor_class = Llama4Processor
def setUp(self):
self.tmpdirname = tempfile.mkdtemp()
image_processor = Llama4ImageProcessor(max_patches=1, size={"height": 20, "width": 20})
tokenizer = PreTrainedTokenizerFast.from_pretrained("unsloth/Llama-3.2-11B-Vision-Instruct-unsloth-bnb-4bit")
processor_kwargs = self.prepare_processor_dict()
processor = Llama4Processor(image_processor, tokenizer, **processor_kwargs)
processor.save_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).tokenizer
def get_image_processor(self, **kwargs):
return AutoProcessor.from_pretrained(self.tmpdirname, **kwargs).image_processor
def tearDown(self):
shutil.rmtree(self.tmpdirname)
# Override as Llama4Processor needs image tokens in prompts
def prepare_text_inputs(self, batch_size: Optional[int] = None):
if batch_size is None:
return "lower newer <image>"
if batch_size < 1:
raise ValueError("batch_size must be greater than 0")
if batch_size == 1:
return ["lower newer <image>"]
return ["lower newer <image>", "<image> upper older longer string"] + ["<image> lower newer"] * (
batch_size - 2
)

View File

@ -236,6 +236,16 @@ SPECIAL_CASES_TO_ALLOW = {
"text_config",
"vision_config",
],
"Llama4Config": ["boi_token_index", "eoi_token_index"],
"Llama4TextConfig": [
"interleave_moe_layer_step",
"no_rope_layer_interval",
"no_rope_layers",
"output_router_logits",
"router_aux_loss_coef",
"router_jitter_noise",
],
"Llama4VisionConfig": ["multi_modal_projector_bias", "norm_eps"],
}
@ -358,6 +368,8 @@ def check_attribute_being_used(config_class, attributes, default_value, source_s
"rope_theta",
"partial_rotary_factor",
"pretraining_tp",
"boi_token_index",
"eoi_token_index",
]
attributes_used_in_generation = ["encoder_no_repeat_ngram_size"]

View File

@ -67,6 +67,7 @@ _re_parse_description = re.compile(r"\*optional\*, defaults to (.*)$")
# docstrings instead. If formatting should be ignored for the docstring, you can put a comment # no-format on the
# line before the docstring.
OBJECTS_TO_IGNORE = [
"Llama4Processor",
# Deprecated
"InputExample",
"InputFeatures",

View File

@ -226,10 +226,16 @@ def check_dummies(overwrite: bool = False):
for _actual, _dummy in zip(
actual_dummies["torch"].split("class"), dummy_files["torch"].split("class")
):
found = False
if _actual != _dummy:
actual_broken = _actual
dummy_broken = _dummy
found = True
break
if found:
print("A transient error was found with the dummies, please investigate.")
raise ValueError(
"The main __init__ has objects that are not present in "
f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py.\n"

View File

@ -144,6 +144,8 @@ IGNORE_NON_TESTED = (
"Qwen2_5_VLModel", # Building part of bigger (tested) model. Tested implicitly through Qwen2_5_VLForConditionalGeneration.
"MllamaTextModel", # Building part of bigger (tested) model. # TODO: add tests
"MllamaVisionModel", # Building part of bigger (tested) model. # TODO: add tests
"Llama4TextModel", # Building part of bigger (tested) model. # TODO: add tests
"Llama4VisionModel", # Building part of bigger (tested) model. # TODO: add tests
"Emu3VQVAE", # Building part of bigger (tested) model
"Emu3TextModel", # Building part of bigger (tested) model
]