Compare commits

..

290 Commits

Author SHA1 Message Date
055aef8690 don't initialize 2025-11-17 11:34:51 +00:00
0103627278 first 2025-11-14 09:03:57 +00:00
7d8df526e6 fix 2025-11-10 19:48:40 +01:00
00b00448e7 Merge branch 'refactor-weight-loading' into fix-bnb 2025-11-10 19:02:27 +01:00
3651460288 Merge branch 'refactor-weight-loading' into fix-bnb 2025-11-10 18:56:54 +01:00
f5a7c33dce mt5 fuck 2025-11-10 18:22:10 +01:00
3c8c7572e6 ty @SunMarc this fixes the buffers
Co-authored-by: SunMarc <SunMarc@users.noreply.github.com>
2025-11-10 18:21:59 +01:00
2f0a6aed58 big revert, don't break this behaviour 2025-11-10 17:39:06 +01:00
f93f35709c update some models 2025-11-10 17:38:54 +01:00
5881d8eb91 deal with buffers 2025-11-10 17:24:19 +01:00
ea5822db85 Merge branch 'refactor-weight-loading' into fix-bnb 2025-11-10 16:11:29 +01:00
9fa1b7a2c4 guard needed for compressed-tensors 2025-11-10 16:10:40 +01:00
e033947a5c shared todo? 2025-11-10 15:24:38 +01:00
7b457fd04c fix init weights for non param gate up projs 2025-11-10 14:47:07 +01:00
09bcd2ee11 fixes 2025-11-10 14:30:27 +01:00
86a4e51647 fix deformable detr 2025-11-10 14:18:05 +01:00
5be67b96fc update error message 2025-11-10 13:58:03 +01:00
8755a4beef fix hunyuan 2025-11-10 13:47:23 +01:00
db4fe31ddf Merge branch 'refactor-weight-loading' into fix-bnb 2025-11-10 13:44:28 +01:00
94a53d4c66 uupdate 2025-11-10 12:18:02 +01:00
de74aebbc7 checkout 2025-11-10 11:30:55 +01:00
7b7c990364 [build-ci-image] 2025-11-10 11:14:47 +01:00
c137ea3323 fix data-2-vec 2025-11-10 11:09:24 +01:00
0412832432 fix smart apply 2025-11-10 10:45:04 +01:00
bbf5b000e2 asyncio? 2025-11-10 10:32:25 +01:00
f7d0183d2b fix xcodex 2025-11-10 10:02:14 +01:00
2a00e493c2 nits 2025-11-10 09:24:20 +01:00
3ffc59ef92 fix resize token embeddings 2025-11-10 09:23:29 +01:00
d176b48973 Merge branch 'refactor-weight-loading' of github.com:huggingface/transformers into refactor-weight-loading 2025-11-07 18:45:33 +01:00
76d66be5e5 fix prophnet 2025-11-07 18:45:30 +01:00
a052513335 should fix it 2025-11-07 18:39:23 +01:00
3e69622256 Fix tests single gpu 2025-11-07 18:28:26 +01:00
44943fb87d fix big faileurs 2025-11-07 18:19:54 +01:00
a0029f207b Merge branch 'main' into refactor-weight-loading 2025-11-07 18:07:56 +01:00
5c9d56cb07 fixup 2025-11-07 17:35:58 +01:00
f8f0973415 more changes to untangle old hardcoded ting 2025-11-07 17:34:01 +01:00
e4df75269a rm report 2025-11-07 17:13:44 +01:00
e235eeddb7 Merge remote-tracking branch 'upstream/fix-bnb' into fix-bnb 2025-11-07 17:09:11 +01:00
d841a04b3e Fix loadedparam 2025-11-07 17:08:38 +01:00
443573aeb8 moe case 2025-11-07 16:22:43 +01:00
85ab08590a update decoder.bias 2025-11-07 16:19:50 +01:00
75d3afcb48 remove explict sharing of some tied keys. 2025-11-07 15:56:25 +01:00
72eff97c4d Update src/transformers/core_model_loading.py
Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
2025-11-07 14:21:04 +01:00
9788014a93 Merge remote-tracking branch 'upstream/refactor-weight-loading' into fix-bnb 2025-11-07 14:18:56 +01:00
9c0db728bd push 2025-11-07 14:14:55 +01:00
386e259b85 update 2025-11-07 14:12:48 +01:00
e16da231ca rm import 2025-11-07 14:12:41 +01:00
18b02eea94 update 2025-11-07 13:58:24 +01:00
0e7d2d052d propnet is dumb 2025-11-07 12:24:44 +01:00
dde5500d80 just push this for now 2025-11-07 12:11:49 +01:00
074a449f6b properly fix qwen init 2025-11-07 11:59:34 +01:00
9fde9f7893 fix qwen and long cat flash 2025-11-07 11:54:41 +01:00
9a76a6eee3 fix long cat flash 2025-11-07 11:53:09 +01:00
32226787a9 fix led 2025-11-07 11:49:31 +01:00
8ff4ad56a5 Ouiiii 2025-11-07 11:09:50 +01:00
78d46227f8 lol 2025-11-07 11:03:34 +01:00
2fa058fe8a up 2025-11-07 09:59:01 +01:00
f692f4bdcb subclass nn.Parameters 2025-11-07 08:55:33 +01:00
acbeeae720 Merge branch 'refactor-weight-loading' into fix-bnb 2025-11-06 18:27:58 +01:00
399388d1fe rm print 2025-11-06 18:22:37 +01:00
bdbc01a6a4 Fix bnb loading ! 2025-11-06 18:21:45 +01:00
d22363560f nits 2025-11-06 12:48:13 +01:00
c48e1edb49 up? 2025-11-06 10:53:07 +01:00
0c2b667d13 an attempt 2025-11-06 10:20:44 +01:00
1dabb4c334 here we go again 2025-11-06 00:30:48 +01:00
0f022b59d9 Revert "tied weight first shot to the fiiiixxxxxx"
This reverts commit 3fea865810e4dc832919e0a7f853ca5d3d426c72.
2025-11-06 00:20:50 +01:00
0e51decd6d last update 2025-11-05 22:57:14 +01:00
e341529210 nits and fixes 2025-11-05 22:41:08 +01:00
f72f96d400 fixes for more models torch_bc 2025-11-05 22:02:04 +01:00
cc0819540b fix some ppolry defined tied_weights_keys for now 2025-11-05 21:50:54 +01:00
84dd6eb26e :) 2025-11-05 20:28:24 +01:00
82f94b8ae0 does this help? 2025-11-05 20:21:38 +01:00
3fea865810 tied weight first shot to the fiiiixxxxxx 2025-11-05 19:08:06 +01:00
e4cadfb1c2 Merge branch 'refactor-weight-loading' of github.com:huggingface/transformers into refactor-weight-loading 2025-11-05 18:19:13 +01:00
6cb3794080 try less strict source check 2025-11-05 18:18:23 +01:00
2526cc5d91 mixtral init 2025-11-05 17:29:01 +01:00
710b1fffcf fix 2025-11-05 17:24:16 +01:00
07574dddd4 Merge branch 'refactor-weight-loading' of github.com:huggingface/transformers into refactor-weight-loading 2025-11-05 17:21:54 +01:00
a228fd0ad2 revert change to init scheme (no need for params) 2025-11-05 16:32:00 +01:00
ef8b6c3548 small update 2025-11-05 16:20:52 +01:00
b57d7897c4 remove ALL custom tie weights 2025-11-05 15:44:22 +01:00
d9e7fe65c8 Merge branch 'refactor-weight-loading' of github.com:huggingface/transformers into refactor-weight-loading 2025-11-05 15:37:39 +01:00
92c0229af4 more fixes related to strict matching regex 2025-11-05 15:29:46 +01:00
58389a1ff0 remove some tie_weights custome funcs when not needed 2025-11-05 15:11:59 +01:00
acc5b2452a remove all buffering -> much faster without it 2025-11-05 15:08:36 +01:00
5146dec408 Merge branch 'refactor-weight-loading' of github.com:huggingface/transformers into refactor-weight-loading 2025-11-05 14:06:03 +01:00
d91701f7ee improve 2025-11-05 14:05:31 +01:00
8baa3fe987 Merge branch 'refactor-weight-loading' of github.com:huggingface/transformers into refactor-weight-loading 2025-11-05 14:03:52 +01:00
2733ff69c4 some updates 2025-11-05 14:03:38 +01:00
b8927d67ef oupsi 2025-11-05 14:02:58 +01:00
8c16de161f cleanup a bit 2025-11-05 13:51:11 +01:00
57988f25a2 improve tqdm bar 2025-11-05 13:28:04 +01:00
c43495a51a fox umt5 2025-11-05 13:23:33 +01:00
912562c08a fix? 2025-11-05 13:12:26 +01:00
2ff765e9ed fix whisper as well 2025-11-05 12:05:00 +01:00
e7165da04d fix more individual models 2025-11-05 12:03:35 +01:00
ead2ac3776 Merge branch 'refactor-weight-loading' of github.com:huggingface/transformers into refactor-weight-loading 2025-11-05 11:59:05 +01:00
74a0e9c71b try to remove custom tiing logic when its stupid 2025-11-05 11:59:01 +01:00
e2aefee7fc fix import to avoid jit execution 2025-11-05 11:37:50 +01:00
bd36211210 remove semaphores 2025-11-05 11:26:43 +01:00
45271710d0 update 2025-11-05 11:14:06 +01:00
42fd4c4325 update 2025-11-05 10:43:06 +01:00
db02b9d716 nit 2025-11-05 10:42:01 +01:00
9601b82ce7 up 2025-11-05 10:19:43 +01:00
ff108789ca _dtype nit 2025-11-05 10:13:20 +01:00
5c54332e3b more fixes 2025-11-05 10:11:12 +01:00
8936cc408f fix asjusting 2025-11-05 10:07:18 +01:00
50714d8ca7 Merge branch 'refactor-weight-loading' of github.com:huggingface/transformers into refactor-weight-loading 2025-11-05 10:04:12 +01:00
c921cedee7 lol 2025-11-05 10:04:06 +01:00
dcad7030b2 fix 2025-11-05 09:45:05 +01:00
1652c9c52f make it fast 2025-11-05 09:25:27 +01:00
a581fd75e7 hoey 2025-11-05 08:59:12 +01:00
89846e7d81 up 2025-11-05 08:58:11 +01:00
20d1b340c4 AI UPDATE 2025-11-04 18:07:40 +01:00
5e71bd4ae7 more updates 2025-11-04 17:18:08 +01:00
5794d27d1c fix more 2025-11-04 17:11:44 +01:00
0b95826c97 fix-copies 2025-11-04 16:55:31 +01:00
32b9273893 up 2025-11-04 16:48:34 +01:00
0fb23403e4 up 2025-11-04 16:39:29 +01:00
5d7507b16d more fixes 2025-11-04 15:57:25 +01:00
8fd255c7f0 up more 2025-11-04 15:47:19 +01:00
ba3de5add4 fix xopies 2025-11-04 15:28:34 +01:00
ba1a8b64c0 fix ernie 2025-11-04 15:13:48 +01:00
76b6a92d74 more up 2025-11-04 14:16:51 +01:00
f85f2397ec mllama 2025-11-04 14:11:37 +01:00
675b2bca69 more 2025-11-04 14:10:13 +01:00
dc5a22c2af more 2025-11-04 13:47:35 +01:00
4f212de424 more 2025-11-04 13:46:07 +01:00
e088408964 more update 2025-11-04 13:30:56 +01:00
d7c81717ae more 2025-11-04 12:31:53 +01:00
da7dc100ac ship validated ones 2025-11-04 11:58:09 +01:00
4894a25774 current shitty changes 2025-11-03 18:52:08 +01:00
93862177d8 glubs 2025-11-03 18:01:22 +01:00
8f7b1d02bb fix a test 2025-11-03 17:57:12 +01:00
8b924a3b12 fix and fix 2025-11-03 17:44:06 +01:00
a170f290a8 update 2025-11-03 16:26:11 +01:00
00b95ee009 nit 2025-11-03 13:54:47 +01:00
1c87945a3c small fixes 2025-11-03 13:45:40 +01:00
02386ce7c6 fix more tie weights keys 2025-11-03 13:29:37 +01:00
8a8beff73e rev 2025-11-03 13:16:34 +01:00
77ccbb17fd up 2025-11-03 13:15:57 +01:00
ab6ee8aed4 ish 2025-11-03 12:49:04 +01:00
a8fb5540c9 nits 2025-11-03 12:35:05 +01:00
23e3ed7489 small fix 2025-11-03 12:33:20 +01:00
ce8c1c1978 fix hunyuan 2025-11-03 12:27:07 +01:00
d923061e63 removeunused 2025-11-03 12:23:49 +01:00
2ff85326fc ups 2025-11-03 12:07:22 +01:00
80517f5322 dik why we tie weights twice but,..,,. 2025-11-03 11:55:29 +01:00
7d78aa1b37 up 2025-11-03 11:07:52 +01:00
22fcdaf9c6 up 2025-11-03 10:54:26 +01:00
f2938df853 small fixes 2025-11-03 10:48:08 +01:00
d1e84db344 fix some tests 2025-11-03 10:39:01 +01:00
4d7970991c ah actually we don't discard lm head if missing -> needs to be moved to correct device and etc 2025-11-03 10:25:45 +01:00
6c88206d3b fix the init of param 2025-11-03 10:11:05 +01:00
82a35bcc89 nits 2025-11-03 10:01:38 +01:00
3baf4b7f6b lol so much time lost on this shit 2025-11-03 09:51:09 +01:00
9b6a7a445b fixup 2025-11-03 09:35:24 +01:00
85973fc9ad fix triton import error 2025-11-03 09:35:08 +01:00
c515eb6d91 Merge branch 'main' of github.com:huggingface/transformers into refactor-weight-loading 2025-11-03 09:33:02 +01:00
4d34cedff5 Merge branch 'refactor-weight-loading' of github.com:huggingface/transformers into refactor-weight-loading 2025-11-03 09:32:22 +01:00
9cb0432c2d qol 2025-11-03 09:32:16 +01:00
0da6e92757 upsates 2025-11-01 18:37:53 +00:00
9022bc293e Merge branch 'refactor-weight-loading' of github.com:huggingface/transformers into refactor-weight-loading 2025-11-01 17:48:51 +01:00
b148577e3c up 2025-11-01 17:48:17 +01:00
7eda8aa764 dtype 2025-11-01 13:54:45 +00:00
606452d69e nits 2025-11-01 12:37:40 +00:00
a79de84819 cleanup what is no longer used 2025-11-01 09:56:30 +00:00
20b6142aa7 small updates? 2025-11-01 09:52:55 +00:00
29aa0515a0 Merge branch 'main' of github.com:huggingface/transformers into refactor-weight-loading 2025-11-01 09:37:55 +00:00
52d85e0fb4 merge 2025-11-01 10:05:40 +01:00
e59b1fffab Merge branch 'refactor-weight-loading' of github.com:huggingface/transformers into refactor-weight-loading 2025-10-31 16:58:15 +01:00
6b398e149f nit 2025-10-31 16:54:39 +01:00
0ebb1b6219 fixup 2025-10-31 15:43:56 +00:00
7061956922 qol + nits 2025-10-31 15:21:48 +00:00
29e017d50a cleanup 2025-10-31 15:14:44 +01:00
19f94d0f40 how many tests does this fix? 2025-10-31 09:43:00 +01:00
e465bc0ae0 fak 2025-10-31 08:20:20 +01:00
913171a9d8 did not know glob was only 3.13 2025-10-31 08:13:38 +01:00
07e265d10d up 2025-10-31 08:04:54 +01:00
3e4d8ea958 up 2025-10-31 08:04:47 +01:00
1d4411aa17 update 2025-10-31 08:02:47 +01:00
e848ab6165 up 2025-10-31 07:53:36 +01:00
573af7594c fix import and error 2025-10-31 07:50:47 +01:00
2d84aba1da fix glob import 2025-10-31 07:45:42 +01:00
9f5ec4ac90 nit 2025-10-31 07:36:42 +01:00
ef5123b8ad Merge branch 'refactor-weight-loading' of github.com:huggingface/transformers into refactor-weight-loading 2025-10-31 07:35:46 +01:00
6d0aa66327 fix tie weight keys? 2025-10-31 07:35:43 +01:00
7f196f9313 small nits 2025-10-31 07:22:36 +01:00
b225885f58 Apply suggestion from @LysandreJik
Co-authored-by: Lysandre Debut <hi@lysand.re>
2025-10-31 07:21:15 +01:00
904283dd1c Apply suggestion from @LysandreJik
Co-authored-by: Lysandre Debut <hi@lysand.re>
2025-10-31 07:20:43 +01:00
d34482c6a0 Merge branch 'refactor-weight-loading' of github.com:huggingface/transformers into refactor-weight-loading 2025-10-31 07:17:49 +01:00
e0fd1e42e3 better error handling (Am I too rust-y) ? 2025-10-31 07:15:54 +01:00
f4775fcac4 updates based on review 2025-10-31 06:30:39 +01:00
00846a2ef4 Apply suggestion from @LysandreJik
Co-authored-by: Lysandre Debut <hi@lysand.re>
2025-10-31 06:09:59 +01:00
5d4d27e6e2 up 2025-10-30 18:34:42 +01:00
b320474eae update 2025-10-30 18:27:46 +01:00
630934707d smal nit 2025-10-30 17:21:58 +01:00
c3c534fe67 licence 2025-10-30 17:20:35 +01:00
edf96f8451 update conversion mapping! 2025-10-30 17:18:04 +01:00
00e36042a8 yups 2025-10-30 17:04:46 +01:00
48c85c78da revert small granite moe stuff 2025-10-30 17:01:37 +01:00
912dd2f7ba updates 2025-10-30 16:37:58 +01:00
9bed48862c more fixups 2025-10-30 16:36:28 +01:00
50a85efdcd fix ernie 2025-10-30 16:05:22 +01:00
d9bb0e340e fix olmoe 2025-10-30 15:57:15 +01:00
fe9b047899 up 2025-10-30 15:34:07 +01:00
9f615bcc1c update 2025-10-30 15:29:28 +01:00
a01ad8d63e small nits 2025-10-30 14:37:11 +01:00
8cf96946e7 nit 2025-10-30 14:21:10 +01:00
28a1d22526 add qwen2_moe to the mapping! 2025-10-30 14:01:46 +01:00
6c9fda4e0e small updates 2025-10-30 12:02:59 +01:00
4443658942 Merge branch 'main' of github.com:huggingface/transformers into refactor-weight-loading 2025-10-30 11:43:53 +01:00
0402e564ce yups 2025-10-30 11:43:09 +01:00
134959c142 styling 2025-10-30 11:35:54 +01:00
17f25f9f3b fix copies 2025-10-30 11:35:19 +01:00
3cde7b0606 fix bunch of tests 2025-10-30 11:34:07 +01:00
22145750da ship most fixes 2025-10-30 10:36:20 +01:00
c53755fce7 smoll QOL 2025-10-30 09:55:49 +01:00
f1312dc91c fix llama tests ? 2025-10-30 09:46:13 +01:00
edeacc3867 move progress 2025-10-30 09:02:21 +01:00
de09779953 Merge branch 'refactor-weight-loading' of github.com:huggingface/transformers into refactor-weight-loading 2025-10-30 08:55:42 +01:00
e1eb5a4adb nit 2025-10-30 08:55:37 +01:00
aa0ebbec82 the way to make local tensor + Dtensor work 2025-10-29 21:53:51 +00:00
ac1af43293 TP + QUANTIZE now works 2025-10-29 21:43:34 +00:00
653933c293 fix fp8 2025-10-29 21:25:10 +00:00
a92cb1fe61 Youhou 2025-10-29 21:22:56 +00:00
ec49d7339d Merge branch 'refactor-weight-loading' of github.com:huggingface/transformers into refactor-weight-loading 2025-10-29 20:20:21 +01:00
965b006613 small update 2025-10-29 20:17:50 +01:00
9735c6e011 current updates 2025-10-29 17:12:26 +00:00
a8998de322 fix auto for mps 2025-10-29 15:21:50 +01:00
c3f5437233 fix tie weight embeddding? 2025-10-29 15:10:27 +01:00
a5859af437 local changes 2025-10-29 14:42:39 +01:00
8e74adc4d0 support tp dtensor 2025-10-28 16:14:41 +00:00
62ccfd9b7f nits 2025-10-28 15:36:55 +00:00
7efb487d31 fix-copies 2025-10-28 15:36:51 +00:00
0519e21dd3 fix fp8, it now works 2025-10-28 15:35:04 +00:00
6f6deb0f88 update 2025-10-28 15:20:54 +00:00
466df965f3 updates 2025-10-28 14:56:48 +00:00
2fe87ce1dd updates 2025-10-28 11:31:45 +00:00
c6bb839d21 fixes 2025-10-28 10:19:29 +00:00
fe220cf182 quantization works 2025-10-27 23:00:10 +00:00
667133317e fix-copies 2025-10-27 22:18:24 +00:00
7b64815cc5 fix modular 2025-10-27 22:06:55 +00:00
b01dd4fd98 ruff 2025-10-27 22:06:48 +00:00
58fc7b5799 nits 2025-10-27 22:06:06 +00:00
c9417f9872 kill poool asap 2025-10-27 21:57:29 +00:00
b82c4f256f i was just missing a "clone" :) 2025-10-27 21:55:55 +00:00
a693417568 small upstead 2025-10-27 19:26:31 +00:00
b6027426f2 fixes 2025-10-27 19:03:59 +00:00
b4ef14c23b cleanup 2025-10-27 18:43:06 +00:00
fb3422794d we have a forward pass "running" but out is gibberish for now! 2025-10-27 18:38:19 +00:00
fbea44e9e2 small updates 2025-10-27 17:08:58 +00:00
d7d922acd5 Merge branch 'refactor-weight-loading' of github.com:huggingface/transformers into refactor-weight-loading 2025-10-27 12:53:24 +00:00
30f41f2c9d fix device map 2025-10-27 12:53:20 +00:00
9c9669360c nit 2025-10-27 13:16:16 +01:00
b20b69373e latest changes 2025-10-27 12:41:07 +01:00
afdb59ddf4 up 2025-10-21 17:23:26 -07:00
4b2058be0e more updates and cleanup 2025-10-21 17:16:02 -07:00
b2e97bf570 current status 2025-10-21 16:02:33 -07:00
f74d41f18f fixup 2025-10-20 19:07:44 -07:00
36a4b5d5ac update 2025-10-20 19:07:30 -07:00
bde538dc0f fix 2025-10-20 18:35:43 -07:00
7728fda7c7 more small changes 2025-10-20 18:34:10 -07:00
d36e62c12d up 2025-10-20 18:31:03 -07:00
b8586194ce fixup 2025-10-20 18:10:19 -07:00
0e56676260 works a little bit 2025-10-20 18:03:14 -07:00
d1c47d0e02 up 2025-10-20 15:21:48 -07:00
e40427fc57 push what I have for now! 2025-10-20 11:39:24 -07:00
1aae8d97f2 update 2025-10-19 17:57:51 +02:00
0569ee8693 update, we are getting close to something "usable" 2025-10-19 17:33:58 +02:00
e0da883e85 updates 2025-10-19 16:51:18 +02:00
f62bc7e0dd nits and comments here andd there 2025-10-17 13:23:30 +02:00
bfb804756d update 2025-10-17 10:20:46 +02:00
a08b927826 cleanup 2025-10-17 10:17:26 +02:00
8ca058d64c cleanup 2025-10-17 10:14:19 +02:00
01f8a7e419 current status 2025-10-17 10:09:07 +02:00
e956317273 Merge branch 'main' of github.com:huggingface/transformers into refactor-weight-loading 2025-10-17 09:15:06 +02:00
213a64d4ae some updates 2025-10-16 18:31:38 +02:00
f8d1f98dc1 comment 2025-10-16 15:21:35 +02:00
9c07ead1fc deisng 2025-10-16 14:59:20 +02:00
86e48e242b style 2025-10-16 14:29:07 +02:00
8a3e3d43bb update 2025-10-16 14:24:24 +02:00
46b7632fbc update 2025-10-16 14:09:10 +02:00
0ff608d466 Merge branch 'refactor-weight-loading' of github.com:huggingface/transformers into refactor-weight-loading 2025-10-15 16:23:53 +02:00
15ec137a1d current state 2025-10-15 16:23:48 +02:00
993c2fbe74 Update src/transformers/conversion_mapping.py 2025-10-14 18:26:28 +02:00
7bb32d5f7f up 2025-10-14 15:40:45 +02:00
22734c5047 my draft 2025-10-14 15:12:29 +02:00
941738e5f3 Merge branch 'main' into refactor-weight-loading 2025-10-14 12:52:32 +02:00
d76ebe4195 ai draft 2025-10-13 15:08:15 +02:00
756 changed files with 10503 additions and 7926 deletions

View File

@ -20,4 +20,4 @@ jobs:
contents: read
with:
workflow_name: ${{ inputs.workflow_name }}
run_count: ${{ fromJSON(inputs.run_count) }}
run_count: ${{ fromJSON(inputs.run_count) }}

View File

@ -87,6 +87,9 @@ jobs:
PR_FILES: ${{ steps.pr_info.outputs.files }}
if: ${{ inputs.pr_number != '' }}
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Extract PR details
id: pr_info
uses: actions/github-script@v6

View File

@ -13,6 +13,9 @@ jobs:
outputs:
PR_NUMBER: ${{ steps.set_pr_number.outputs.PR_NUMBER }}
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Get PR number
shell: bash
env:

View File

@ -13,6 +13,9 @@ jobs:
name: Notify new model
runs-on: ubuntu-22.04
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- uses: actions/checkout@v4
with:
fetch-depth: 0

View File

@ -35,6 +35,9 @@ jobs:
PR_MERGE_COMMIT_DATE: ${{ needs.get-pr-info.outputs.PR_MERGE_COMMIT_DATE }}
PR_MERGE_COMMIT_TIMESTAMP: ${{ needs.get-pr-info.outputs.PR_MERGE_COMMIT_TIMESTAMP }}
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- run: |
COMMENT_TIMESTAMP=$(date -d "${COMMENT_DATE}" +"%s")
echo "COMMENT_DATE: $COMMENT_DATE"
@ -54,6 +57,9 @@ jobs:
statuses: write
runs-on: ubuntu-22.04
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Create Run
id: create_run
env:
@ -77,6 +83,9 @@ jobs:
pull-requests: write
runs-on: ubuntu-22.04
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Reply to the comment
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
@ -112,6 +121,9 @@ jobs:
GITHUB_RUN_URL: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}
STATUS_OK: ${{ contains(fromJSON('["skipped", "success"]'), needs.create_run.result) }}
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Get `build-doc` job status
run: |
echo "${{ needs.build-doc.result }}"

View File

@ -23,6 +23,10 @@ jobs:
outputs:
jobs: ${{ steps.get_jobs.outputs.jobs_to_run }}
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
# This checkout to the main branch
- uses: actions/checkout@v4
with:
@ -89,6 +93,10 @@ jobs:
pull-requests: write
runs-on: ubuntu-22.04
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Check and update comment if needed
uses: actions/github-script@v7
env:

View File

@ -11,6 +11,10 @@ jobs:
outputs:
matrix: ${{ steps.set-matrix.outputs.matrix }}
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Check out code
uses: actions/checkout@v4

View File

@ -18,6 +18,10 @@ jobs:
shell: bash -l {0}
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Checkout repository
uses: actions/checkout@v4

View File

@ -46,6 +46,10 @@ jobs:
PR_HEAD_SHA: ${{ needs.get-pr-info.outputs.PR_HEAD_SHA }}
PR_MERGE_SHA: ${{ needs.get-pr-info.outputs.PR_MERGE_COMMIT_SHA }}
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Verify `merge_commit` timestamp is older than the issue comment timestamp
env:
COMMENT_DATE: ${{ github.event.comment.created_at }}
@ -67,6 +71,10 @@ jobs:
models: ${{ steps.models_to_run.outputs.models }}
quantizations: ${{ steps.models_to_run.outputs.quantizations }}
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- uses: actions/checkout@v4
with:
fetch-depth: "0"
@ -109,6 +117,10 @@ jobs:
pull-requests: write
runs-on: ubuntu-22.04
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Reply to the comment
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
@ -131,6 +143,10 @@ jobs:
pull-requests: write
runs-on: ubuntu-22.04
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Reply to the comment
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
@ -152,6 +168,10 @@ jobs:
statuses: write
runs-on: ubuntu-22.04
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Create Run
id: create_run
env:
@ -210,6 +230,10 @@ jobs:
if: ${{ always() && needs.create_run.result == 'success' }}
runs-on: ubuntu-22.04
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Show reports from jobs
env:
MODEL_REPORT: ${{ needs.model-ci.outputs.report }}

View File

@ -30,6 +30,10 @@ jobs:
name: Setup
runs-on: ubuntu-22.04
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Setup
run: |
mkdir "setup_values"

View File

@ -14,6 +14,9 @@ jobs:
outputs:
run_number: ${{ steps.get_number.outputs.run_number }}
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Get number
id: get_number
run: |

View File

@ -10,5 +10,9 @@ jobs:
runs-on: ubuntu-22.04
if: ${{ always() }}
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Trigger scheduled AMD CI via workflow_run
run: echo "Trigger scheduled AMD CI via workflow_run"

View File

@ -32,6 +32,9 @@ jobs:
name: Setup
runs-on: ubuntu-22.04
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Setup
env:
prev_workflow_run_id: ${{ inputs.prev_workflow_run_id || env.prev_workflow_run_id }}

View File

@ -32,6 +32,10 @@ jobs:
name: Setup
runs-on: ubuntu-22.04
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Setup
run: |
mkdir "setup_values"

View File

@ -38,6 +38,10 @@ jobs:
folder_slices: ${{ steps.set-matrix.outputs.folder_slices }}
quantization_matrix: ${{ steps.set-matrix.outputs.quantization_matrix }}
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Checkout
uses: actions/checkout@v4
with:
@ -122,6 +126,10 @@ jobs:
--cap-add=sys_nice
--shm-size=64G
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Checkout
uses: actions/checkout@v4
with:
@ -191,6 +199,10 @@ jobs:
--cap-add=sys_nice
--shm-size=64G
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Checkout
uses: actions/checkout@v4
with:
@ -263,6 +275,10 @@ jobs:
--cap-add=sys_nice
--shm-size=64G
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Checkout
uses: actions/checkout@v4
with:

View File

@ -78,6 +78,9 @@ jobs:
slice_ids: ${{ steps.set-matrix.outputs.slice_ids }}
quantization_matrix: ${{ steps.set-matrix-quantization.outputs.quantization_matrix }}
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Update clone
working-directory: /transformers
env:
@ -184,6 +187,9 @@ jobs:
image: huggingface/transformers-all-latest-gpu
options: --gpus all --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Update clone
working-directory: /transformers
env:
@ -256,6 +262,9 @@ jobs:
image: huggingface/transformers-all-latest-gpu
options: --gpus all --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Update clone
working-directory: /transformers
env:
@ -329,6 +338,9 @@ jobs:
image: ${{ inputs.docker }}
options: --gpus all --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Update clone
working-directory: ${{ inputs.working-directory-prefix }}/transformers
env:
@ -434,6 +446,9 @@ jobs:
image: huggingface/transformers-quantization-latest-gpu
options: --gpus all --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Echo folder ${{ matrix.folders }}
shell: bash
env:
@ -518,6 +533,9 @@ jobs:
image: ${{ inputs.docker }}
options: --gpus all --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Update clone
working-directory: /transformers
env:
@ -588,6 +606,9 @@ jobs:
steps:
# Checkout in order to run `utils/extract_warnings.py`. Avoid **explicit** checkout (i.e. don't specify `ref`) for
# security reason.
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Checkout transformers
uses: actions/checkout@v4

View File

@ -38,6 +38,10 @@ jobs:
runs-on: ubuntu-22.04
if: always() && !cancelled()
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Preliminary job status
shell: bash
# For the meaning of these environment variables, see the job `Setup`

View File

@ -30,6 +30,10 @@ jobs:
outputs:
RUNNER: ${{ steps.set_runner.outputs.RUNNER }}
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Get runner to use
shell: bash
env:
@ -58,6 +62,10 @@ jobs:
container:
image: ${{ github.event.inputs.docker_image }}
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Update clone
working-directory: /transformers
env:

View File

@ -14,16 +14,21 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
steps:
- uses: actions/checkout@v4
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: 3.8
- name: Checkout
uses: actions/checkout@v4
- name: Install requirements
run: |
pip install PyGithub
- name: Close stale issues
run: |
python scripts/stale.py
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: 3.8
- name: Install requirements
run: |
pip install PyGithub
- name: Close stale issues
run: |
python scripts/stale.py

View File

@ -10,6 +10,10 @@ jobs:
trufflehog:
runs-on: ubuntu-latest
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- name: Checkout code
uses: actions/checkout@v4
with:

View File

@ -14,6 +14,10 @@ jobs:
shell: bash -l {0}
steps:
- uses: GitHubSecurityLab/actions-permissions/monitor@v1
with:
config: ${{ vars.PERMISSIONS_CONFIG }}
- uses: actions/checkout@v4
- name: Setup environment

View File

@ -45,6 +45,7 @@ repo-consistency:
python utils/check_modular_conversion.py
python utils/check_dummies.py
python utils/check_repo.py
python utils/check_init_weights_data.py
python utils/check_inits.py
python utils/check_pipeline_typing.py
python utils/check_config_docstrings.py
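
Presumably, the new `utils/check_init_weights_data.py` enforces that `_init_weights` implementations stop going through `.data` (the pattern removed throughout the doc and model diffs below). A rough, purely illustrative sketch of such a consistency check; the actual script added by this PR may work quite differently:

```python
# Hypothetical sketch only: flag `.data.<in-place op>` usage inside `_init_weights`.
# The real utils/check_init_weights_data.py may differ entirely.
import re
import sys
from pathlib import Path

DATA_CALL = re.compile(r"\.data\.(normal_|zero_|fill_|copy_|uniform_)\(")
INIT_WEIGHTS_BODY = re.compile(r"def _init_weights\(self.*?(?=\n    def |\Z)", re.DOTALL)

def find_offenders(root: str = "src/transformers/models") -> list[str]:
    offenders = []
    for path in Path(root).rglob("modeling_*.py"):
        body = INIT_WEIGHTS_BODY.search(path.read_text(encoding="utf-8"))
        if body and DATA_CALL.search(body.group(0)):
            offenders.append(str(path))
    return offenders

if __name__ == "__main__":
    bad = find_offenders()
    if bad:
        print("`.data` used inside `_init_weights`:", *bad, sep="\n  ")
        sys.exit(1)
```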

View File

@ -508,16 +508,16 @@ BERT `_init_weights` Methode:
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
```
Sie können weitere benutzerdefinierte Schemata verwenden, wenn Sie eine spezielle Initialisierung für einige Module benötigen. Zum Beispiel in
@ -533,9 +533,9 @@ def _init_weights(self, module):
module.project_hid._is_hf_initialized = True
module.project_q._is_hf_initialized = True
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
```
Das Flag `_is_hf_initialized` wird intern verwendet, um sicherzustellen, dass wir ein Submodul nur einmal initialisieren. Wenn Sie es auf

View File

@ -314,16 +314,16 @@ Random initialization occurs in the `_init_weights` method of `BrandNewLlamaPreT
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
```
The initialization scheme can look different if you need to adapt it to your model. For example, [`Wav2Vec2ForPreTraining`] initializes [nn.Linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) in its last two linear layers.
@ -339,9 +339,9 @@ def _init_weights(self, module):
module.project_hid._is_hf_initialized = True
module.project_q._is_hf_initialized = True
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
```
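
Dropping the `.data` indirection presumably assumes the initializers now run under a no-grad context, because an in-place op on a leaf `nn.Parameter` that requires grad raises a `RuntimeError` otherwise. A standalone PyTorch sketch of that constraint (not the transformers loading machinery itself):

```python
import torch
import torch.nn as nn

linear = nn.Linear(8, 8)

# The old `.data.normal_(...)` style sidestepped autograd bookkeeping; the new
# style works when initialization happens inside an explicit no-grad context.
with torch.no_grad():
    linear.weight.normal_(mean=0.0, std=0.02)
    linear.bias.zero_()

# Outside torch.no_grad(), `linear.weight.normal_(...)` raises:
# "a leaf Variable that requires grad is being used in an in-place operation"
```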
### Convert checkpoints to Transformers

View File

@ -159,7 +159,7 @@ conversation3 = [
conversations = [conversation1, conversation2, conversation3]
inputs = processor.apply_chat_template(
conversations,
conversation,
add_generation_prompt=True,
tokenize=True,
return_dict=True,

View File

@ -149,7 +149,7 @@ The example below packs `up_proj` and `gate_proj` into a single `gate_up_proj` m
```python
class Llama4TextExperts(nn.Module):
...
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
```
Batch matrix multiplication can be used in the `forward` pass to compute the output of the `gate_up_proj` module.
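
As a rough sketch of that forward computation, assuming the hidden states have already been routed and reshaped per expert (shapes and names are illustrative, not the exact Llama4 code):

```python
import torch

num_experts, tokens_per_expert, hidden_size, expert_dim = 4, 16, 32, 64
gate_up_proj = torch.zeros(num_experts, hidden_size, 2 * expert_dim)
hidden_states = torch.randn(num_experts, tokens_per_expert, hidden_size)

# One batched matmul covers every expert: (E, T, H) @ (E, H, 2D) -> (E, T, 2D)
gate_up = torch.bmm(hidden_states, gate_up_proj)
gate, up = gate_up.chunk(2, dim=-1)  # recover the two packed projections
```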

View File

@ -406,16 +406,16 @@ model = BrandNewBertModel(BrandNewBertConfig())
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
```
特定のモジュールに特別な初期化が必要な場合、カスタムスキームをさらに持つことができます。たとえば、
@ -431,9 +431,9 @@ def _init_weights(self, module):
module.project_hid._is_hf_initialized = True
module.project_q._is_hf_initialized = True
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
```
`_is_hf_initialized`フラグは、サブモジュールを一度だけ初期化することを確実にするために内部で使用されます。

View File

@ -348,16 +348,16 @@ model = BrandNewBertModel(BrandNewBertConfig())
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
```
몇 가지 모듈에 대해 특별한 초기화가 필요한 경우 사용자 정의 방식을 사용할 수도 있습니다. 예를 들어, `Wav2Vec2ForPreTraining`에서 마지막 두 개의 선형 레이어는 일반적인 PyTorch `nn.Linear`의 초기화를 가져야 하지만, 다른 모든 레이어는 위와 같은 초기화를 사용해야 합니다. 이는 다음과 같이 코드화됩니다:
@ -371,9 +371,9 @@ def _init_weights(self, module):
module.project_hid._is_hf_initialized = True
module.project_q._is_hf_initialized = True
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
```
`_is_hf_initialized` 플래그는 서브모듈을 한 번만 초기화하도록 내부적으로 사용됩니다. `module.project_q``module.project_hid`에 대해 `True`로 설정함으로써, 우리가 수행한 사용자 정의 초기화가 이후에 덮어쓰이지 않도록 합니다. 즉, `_init_weights` 함수가 이들에게 적용되지 않습니다.

View File

@ -152,7 +152,7 @@ class ParallelInterface(MutableMapping):
```python
class Llama4TextExperts(nn.Module):
...
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
self.gate_up_proj = nn.Parameter(torch.zeros(self.num_experts, self.hidden_size, 2 * self.expert_dim))
```
배치 행렬 곱셈을 `forward` 패스에서 사용하여 `gate_up_proj` 모듈의 출력을 계산할 수 있습니다.

View File

@ -502,16 +502,10 @@ class DummyBertLMPredictionHead(nn.Module):
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
@ -536,18 +530,18 @@ class DummyBertPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
elif isinstance(module, DummyBertLMPredictionHead):
module.bias.data.zero_()
module.bias.zero_()
@auto_docstring(

View File

@ -265,7 +265,7 @@ class MyNewModel2PreTrainedModel(PreTrainedModel):
# We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
if "RMSNorm" in module.__class__.__name__:
module.weight.data.zero_()
module.weight.zero_()
class MyNewModel2ForSequenceClassification(GenericForSequenceClassification, MyNewModel2PreTrainedModel):
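
Zero-initialization is the identity for an RMSNorm whose forward scales by `(1 + weight)`, which is what the comment above refers to. A generic sketch of such a norm (the model's actual class may differ in details):

```python
import torch
import torch.nn as nn

class OnePlusWeightRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(dim))  # zeros -> effective scale of 1
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return normed * (1.0 + self.weight)
```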

View File

@ -104,9 +104,9 @@ class NewTaskModelPreTrainedModel(PreTrainedModel):
std = getattr(self.config, "initializer_range", self.config.get_text_config().initializer_range)
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
module.weight.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
def token_type_ids_mask_function(
@ -428,7 +428,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related
def __init__(self, config):
@ -440,7 +440,15 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys]
prefix = "model.language_model."
prefixed_mapping = {
f"{prefix}{target}": f"{prefix}{source}"
for target, source in self.language_model._tied_weights_keys.items()
}
if isinstance(self._tied_weights_keys, dict):
self._tied_weights_keys.update(prefixed_mapping)
else:
self._tied_weights_keys = prefixed_mapping
self.post_init()
def get_input_embeddings(self):
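
With `_tied_weights_keys` now a `{target: source}` mapping rather than a flat list of names, tying amounts to pointing each target path at the source parameter. A rough illustration of applying such a mapping (not the actual transformers tying code):

```python
import torch.nn as nn

def apply_tied_weights(model: nn.Module, tied_keys: dict[str, str]) -> None:
    # e.g. tied_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
    params = dict(model.named_parameters())
    for target, source in tied_keys.items():
        module_path, _, attr = target.rpartition(".")
        setattr(model.get_submodule(module_path), attr, params[source])  # share one Parameter
```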

View File

@ -505,16 +505,10 @@ class RobertaLMPredictionHead(nn.Module):
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
@ -539,18 +533,18 @@ class RobertaPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
elif isinstance(module, RobertaLMPredictionHead):
module.bias.data.zero_()
module.bias.zero_()
@auto_docstring(

View File

@ -846,11 +846,11 @@ class TestDetrPreTrainedModel(PreTrainedModel):
nn.init.xavier_uniform_(module.output_proj.weight.data)
nn.init.constant_(module.output_proj.bias.data, 0.0)
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
module.weight.data.normal_(mean=0.0, std=std)
module.weight.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
module.weight.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
if hasattr(module, "reference_points") and not self.config.two_stage:

View File

@ -19,7 +19,15 @@ class NewTaskModelForNewTask(PaliGemmaForConditionalGeneration):
self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)
if self.language_model._tied_weights_keys is not None:
self._tied_weights_keys = [f"model.language_model.{k}" for k in self.language_model._tied_weights_keys]
prefix = "model.language_model."
prefixed_mapping = {
f"{prefix}{target}": f"{prefix}{source}"
for target, source in self.language_model._tied_weights_keys.items()
}
if isinstance(self._tied_weights_keys, dict):
self._tied_weights_keys.update(prefixed_mapping)
else:
self._tied_weights_keys = prefixed_mapping
self.post_init()

View File

@ -18,7 +18,7 @@ import platform
import re
import string
import time
from collections.abc import AsyncIterator, Callable
from collections.abc import AsyncIterator
from typing import Annotated, Optional
import click
@ -98,39 +98,55 @@ If you're a new user, check this basic flag guide: https://huggingface.co/docs/t
class RichInterface:
def __init__(
self,
model_id: str,
user_id: str,
token_processors: Optional[list[Callable[[str], str]]] = None,
sequence_processor: Optional[list[Callable[[str], str]]] = None,
):
def __init__(self, model_id: str, user_id: str):
self._console = Console()
self.model_id = model_id
self.user_id = user_id
token_processors = token_processors or []
sequence_processor = sequence_processor or []
self.token_processors = [self._special_token_processor, *token_processors]
self.sequence_processor = [self._code_blocks_sequence_processor, *sequence_processor]
async def stream_output(self, stream: AsyncIterator[ChatCompletionStreamOutput]) -> tuple[str, int]:
self._console.print(f"[bold blue]<{self.model_id}>:")
with Live(console=self._console, refresh_per_second=4) as live:
text = ""
async for token in await stream:
outputs = token.choices[0].delta.content
if not outputs:
continue
text += outputs
sequence = self._process_sequence(text)
# Escapes single words encased in <>, e.g. <think> -> \<think\>, for proper rendering in Markdown.
# It only escapes single words that may have `_`, optionally following a `/` (e.g. </think>)
outputs = re.sub(r"<(/*)(\w*)>", r"\<\1\2\>", outputs)
markdown = Markdown(sequence, code_theme="github-dark")
text += outputs
# Render the accumulated text as Markdown
# NOTE: this is a workaround for the rendering "unstandard markdown"
# in rich. The chatbots output treat "\n" as a new line for
# better compatibility with real-world text. However, rendering
# in markdown would break the format. It is because standard markdown
# treat a single "\n" in normal text as a space.
# Our workaround is adding two spaces at the end of each line.
# This is not a perfect solution, as it would
# introduce trailing spaces (only) in code block, but it works well
# especially for console output, because in general the console does not
# care about trailing spaces.
lines = []
for line in text.splitlines():
lines.append(line)
if line.startswith("```"):
# Code block marker - do not add trailing spaces, as it would
# break the syntax highlighting
lines.append("\n")
else:
lines.append(" \n")
markdown = Markdown("".join(lines).strip(), code_theme="github-dark")
# Update the Live console output
live.update(markdown, refresh=True)
self._console.print()
return text
def input(self) -> str:
@ -164,46 +180,6 @@ class RichInterface:
self._console.print(f"[bold blue]{config}")
self._console.print()
def _special_token_processor(self, token):
# Escapes single words encased in <>, e.g. <think> -> \<think\>, for proper rendering in Markdown.
# It only escapes single words that may have `_`, optionally following a `/` (e.g. </think>)
return re.sub(r"<(/*)(\w*)>", r"\<\1\2\>", token)
def _code_blocks_sequence_processor(self, sequence):
# Render the accumulated text as Markdown
# NOTE: this is a workaround for the rendering "unstandard markdown"
# in rich. The chatbots output treat "\n" as a new line for
# better compatibility with real-world text. However, rendering
# in markdown would break the format. It is because standard markdown
# treat a single "\n" in normal text as a space.
# Our workaround is adding two spaces at the end of each line.
# This is not a perfect solution, as it would
# introduce trailing spaces (only) in code block, but it works well
# especially for console output, because in general the console does not
# care about trailing spaces.
lines = []
for line in sequence.splitlines():
lines.append(line)
if line.startswith("```"):
# Code block marker - do not add trailing spaces, as it would
# break the syntax highlighting
lines.append("\n")
else:
lines.append(" \n")
return "".join(lines).strip()
def _process_token(self, token):
for token_processor in self.token_processors:
token = token_processor(token)
return token
def _process_sequence(self, sequence):
for sequence_processor in self.sequence_processor:
sequence = sequence_processor(sequence)
return sequence
class ChatCommand(typer.core.TyperCommand):
"""Custom Click command to override missing parameter error message.
@ -264,8 +240,6 @@ class Chat:
help="Path to a local generation config file or to a HuggingFace repo containing a `generation_config.json` file. Other generation settings passed as CLI arguments will be applied on top of this generation config."
),
] = None,
token_processors: Optional[list[Callable[[str], str]]] = None,
sequence_processors: Optional[list[Callable[[str], str]]] = None,
) -> None:
"""Chat with a model from the command line."""
self.base_url = base_url
@ -273,9 +247,6 @@ class Chat:
self.system_prompt = system_prompt
self.save_folder = save_folder
self.token_processors = token_processors or []
self.sequence_processors = sequence_processors or []
# Generation settings
config = load_generation_config(generation_config)
config.update(do_sample=True, max_new_tokens=256) # some default values
@ -381,12 +352,7 @@ class Chat:
# Main logic
async def _inner_run(self):
interface = RichInterface(
model_id=self.model_id,
user_id=self.user,
sequence_processor=self.sequence_processors,
token_processors=self.token_processors,
)
interface = RichInterface(model_id=self.model_id, user_id=self.user)
interface.clear()
chat = new_chat_history(self.system_prompt)

View File

@ -14,7 +14,9 @@
import asyncio
import base64
import copy
import datetime
import enum
import functools
import gc
import io
import json
@ -30,7 +32,8 @@ from threading import Thread
from typing import Annotated, Optional, TypedDict, Union
import typer
from huggingface_hub import scan_cache_dir
from huggingface_hub import model_info
from huggingface_hub.constants import HF_HUB_OFFLINE
from openai.types.chat.chat_completion import Choice
from tokenizers.decoders import DecodeStream
@ -717,45 +720,51 @@ class Serve:
"""
return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n"
@staticmethod
def get_gen_models(cache_dir: Optional[str] = None) -> list[dict[str, any]]:
@functools.cache
def get_gen_models(self) -> list[dict[str, any]]:
"""
List generative models in the cache.
This is by no means a limit to which models may be instantiated with `transformers serve`: any chat-based
model working with generate can work.
This is a limited list of models to ensure we have a discoverable /v1/models endpoint for third-party
integrations.
"""
generative_models = []
models = [
"Menlo/Jan-nano",
"Menlo/Jan-nano-128k",
"Qwen/Qwen2.5-0.5B-Instruct",
"Qwen/Qwen2.5-3B-Instruct",
"Qwen/Qwen2.5-7B-Instruct",
"Qwen/Qwen2.5-14B-Instruct",
"meta-llama/Llama-3.1-8B-Instruct",
"meta-llama/Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.3-70B-Instruct",
"HuggingFaceTB/SmolVLM-Instruct",
"ibm-granite/granite-vision-3.2-2b",
"Qwen/Qwen2.5-VL-7B-Instruct",
]
for repo in scan_cache_dir(cache_dir).repos:
if repo.repo_type != "model":
continue
refs = repo.refs
for ref, revision_info in refs.items():
files = revision_info.files
config_path = next((f.file_path for f in files if f.file_name == "config.json"), None)
if not config_path:
continue
config = json.loads(config_path.open().read())
if "architectures" not in config:
continue
architectures = config["architectures"]
if any(arch for arch in architectures if arch in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values()):
print(repo.repo_id, ref)
author = repo.repo_id.split("/") if "/" in repo.repo_id else ""
repo_id = repo.repo_id + (f"@{ref}" if ref != "main" else "")
generative_models.append(
{
"owned_by": author,
"id": repo_id,
"object": "model",
"created": repo.last_modified,
}
)
return generative_models
if HF_HUB_OFFLINE:
return [
{
"id": model,
"object": "model",
"created": datetime.datetime.now().timestamp(),
"owned_by": model.split("/")[0],
}
for model in models
]
else:
model_infos = [model_info(model) for model in models]
return [
{
"id": model.id,
"object": "model",
"created": model.created_at.timestamp(),
"owned_by": model.author,
}
for model in model_infos
]
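
Since `get_gen_models` is now an instance method memoized with `functools.cache`, the Hub metadata (or the offline fallback) is computed at most once per `Serve` instance. A toy, standard-library-only illustration of that caching behaviour; the class and names here are made up:

```python
import functools

class Catalog:
    def __init__(self):
        self.remote_calls = 0

    @functools.cache  # keyed on `self`, so each instance caches its own listing
    def get_models(self) -> tuple[str, ...]:
        self.remote_calls += 1  # stands in for the model_info() round-trips
        return ("Qwen/Qwen2.5-0.5B-Instruct", "meta-llama/Llama-3.2-1B-Instruct")

catalog = Catalog()
catalog.get_models()
catalog.get_models()
assert catalog.remote_calls == 1  # the second call is served from the cache
```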
def continuous_batching_chat_completion(self, req: dict, request_id: str) -> StreamingResponse | JSONResponse:
"""

View File

@ -876,7 +876,7 @@ class PreTrainedConfig(PushToHubMixin):
if hasattr(self, "quantization_config"):
serializable_config_dict["quantization_config"] = (
self.quantization_config.to_dict()
if not isinstance(self.quantization_config, dict)
if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
else self.quantization_config
)
self.dict_dtype_to_str(serializable_config_dict)
@ -910,7 +910,7 @@ class PreTrainedConfig(PushToHubMixin):
if hasattr(self, "quantization_config"):
output["quantization_config"] = (
self.quantization_config.to_dict()
if not isinstance(self.quantization_config, dict)
if not isinstance(self.quantization_config, dict) and self.quantization_config is not None
else self.quantization_config
)
self.dict_dtype_to_str(output)
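
The added `is not None` guard matters because `hasattr(self, "quantization_config")` can be true while the attribute is `None`; with the old condition, `None` fell into the `.to_dict()` branch and raised `AttributeError`. A minimal illustration of the fixed branch:

```python
class Cfg:
    """Stand-in for a config whose `quantization_config` may be None."""
    def __init__(self, quantization_config=None):
        self.quantization_config = quantization_config

value = Cfg().quantization_config  # None

# Old check: `not isinstance(None, dict)` is True, so None reached `.to_dict()`.
# The extra `is not None` short-circuits that path and passes None through.
if not isinstance(value, dict) and value is not None:
    value = value.to_dict()
print(value)  # -> None
```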

View File

@ -0,0 +1,141 @@
# coding=utf-8
# Copyright (C) 2025 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .core_model_loading import Concatenate, MergeModulelist, WeightConverter
from .utils import is_torch_available
if is_torch_available():
import torch
def _build_checkpoint_conversion_mapping():
mapping = {
"mixtral": [
WeightConverter(
source_keys=[
"block_sparse_moe.experts.*.w1.weight",
"block_sparse_moe.experts.*.w3.weight",
], # you give me a list of 2 keys, I collect a list of a list of tensors
target_keys="mlp.experts.gate_up_proj", # target key gets the list of two tensors
operations=[
MergeModulelist(
dim=0
), # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors
Concatenate(dim=1), # each process has 2 tensors, gate and up, we concat them into gate_up
], # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
),
WeightConverter(
source_keys=[
"block_sparse_moe.experts.*.w2.weight",
],
target_keys="mlp.experts.down_proj", # target key gets the list of two tensors
operations=[
MergeModulelist(
dim=0
), # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors
], # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
),
# WeightConverter(
# ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
# "self_attn.qkv_proj",
# operations=[Concatenate(dim=0)], # more like stack?
# ),
WeightConverter("*.block_sparse_moe.", "*.mlp."),
],
"qwen2_moe": [
WeightConverter(
source_keys=[
"mlp.experts.*.gate_proj.weight",
"mlp.experts.*.up_proj.weight",
],
target_keys="mlp.experts.gate_up_proj",
operations=[MergeModulelist(dim=0), Concatenate(dim=1)],
),
WeightConverter(
source_keys=["mlp.experts.*.down_proj.weight"],
target_keys="mlp.experts.down_proj",
operations=[MergeModulelist(dim=0)],
),
],
"legacy": [
WeightConverter(
source_keys="LayerNorm.gamma",
target_keys="LayerNorm.weight",
),
WeightConverter(
source_keys="LayerNorm.beta",
target_keys="LayerNorm.bias",
),
],
}
if hasattr(torch.nn.utils.parametrizations, "weight_norm"):
mapping["legacy"] += [
WeightConverter(
source_keys="weight_g",
target_keys="parametrizations.weight.original0",
),
WeightConverter(
source_keys="weight_v",
target_keys="parametrizations.weight.original1",
),
]
else:
mapping["legacy"] += [
WeightConverter(
source_keys="parametrizations.weight.original0",
target_keys="weight_g",
),
WeightConverter(
source_keys="parametrizations.weight.original1",
target_keys="weight_v",
),
]
mapping["phimoe"] = mapping["mixtral"].copy()
mapping["deepseek_v2"] = mapping["qwen2_moe"].copy()
mapping["deepseek_v3"] = mapping["qwen2_moe"].copy()
mapping["dot1"] = mapping["qwen2_moe"].copy()
mapping["ernie_4_5_moe"] = mapping["qwen2_moe"].copy()
mapping["glm4_moe"] = mapping["qwen2_moe"].copy()
mapping["glm4v_moe"] = mapping["qwen2_moe"].copy()
mapping["jamba"] = mapping["qwen2_moe"].copy()
mapping["lfm2_moe"] = mapping["mixtral"].copy()
mapping["long_cat_flash"] = mapping["qwen2_moe"].copy()
mapping["qwen3_moe"] = mapping["qwen2_moe"].copy()
mapping["qwen3_omni_moe"] = mapping["qwen2_moe"].copy()
mapping["qwen3_next"] = mapping["qwen2_moe"].copy()
mapping["qwen3_vl_moe"] = mapping["qwen2_moe"].copy()
mapping["hunyuan_v1_moe"] = mapping["qwen2_moe"].copy()
mapping["minimax"] = mapping["mixtral"].copy()
return mapping
_checkpoint_conversion_mapping_cache = None
def get_checkpoint_conversion_mapping():
global _checkpoint_conversion_mapping_cache
if _checkpoint_conversion_mapping_cache is None:
_checkpoint_conversion_mapping_cache = _build_checkpoint_conversion_mapping()
globals()["_checkpoint_conversion_mapping"] = _checkpoint_conversion_mapping_cache
return _checkpoint_conversion_mapping_cache
def __getattr__(name):
if name == "_checkpoint_conversion_mapping":
return get_checkpoint_conversion_mapping()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
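
The module-level `__getattr__` (PEP 562) makes `_checkpoint_conversion_mapping` a lazily built, cached attribute. A hedged usage sketch, assuming the module is importable as a `transformers` submodule and that `WeightConverter` exposes its `source_keys`/`target_keys` fields:

```python
# Assumed import path for illustration; the mapping is only built on first access.
from transformers import conversion_mapping

mixtral_rules = conversion_mapping._checkpoint_conversion_mapping["mixtral"]
for rule in mixtral_rules:
    print(rule.source_keys, "->", rule.target_keys)
```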

View File

@ -0,0 +1,761 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Core helpers for loading model checkpoints."""
from __future__ import annotations
import itertools
import os
import re
from abc import abstractmethod
from collections import defaultdict
from collections.abc import MutableMapping, MutableSet, Sequence
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial
from types import MethodType
from typing import Any, Optional, Union
import torch
from .integrations.tensor_parallel import ALL_PARALLEL_STYLES, TensorParallelLayer, DTensor, Replicate
from .quantizers import HfQuantizer
from .utils import is_torch_greater_or_equal, logging
from .utils.quantization_config import QuantizationMethod
_torch_distributed_available = torch.distributed.is_available()
_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
if _is_dtensor_available:
from torch.distributed.tensor import DTensor
logger = logging.get_logger(__name__)
str_to_torch_dtype = {
"BOOL": torch.bool,
"U8": torch.uint8,
"I8": torch.int8,
"I16": torch.int16,
"F16": torch.float16,
"BF16": torch.bfloat16,
"I32": torch.int32,
"F32": torch.float32,
"F64": torch.float64,
"I64": torch.int64,
"F8_E4M3": torch.float8_e4m3fn,
"F8_E5M2": torch.float8_e5m2,
}
def _glob_to_regex_src(glob: str, *, digits_only: bool = True) -> str:
"""
Convert a glob with '*' into a regex *source* string. We don't use `glob.translate`
'*' matches (\\d+) if digits_only else (.+). Inner groups are non-capturing.
"""
star = r"(\d+)" if digits_only else r"(.+)"
return re.escape(glob).replace(r"\*", star)
def build_glob_alt(
globs: list[str],
) -> tuple[re.Pattern, dict[str, str]]:
r"""
Build one compiled regex alternation with a named group per glob. This allows running a single
re.match and using the name of the matched group to recover which glob pattern matched.
Returns (compiled_regex, name->glob map).
Example:
```py
>>> reg, map_ = build_glob_alt(["mlp.*.w1", "mlp.*.w2"])
>>> print(reg)
re.compile(r'(?P<g0>.*mlp\.(\d+)\.w1)|(?P<g1>.*mlp\.(\d+)\.w2)', re.UNICODE)
>>> print(map_)
{'g0': 'mlp.*.w1', 'g1': 'mlp.*.w2'}
>>> match_ = reg.match("model.layers.0.mlp.0.w1.weight")
>>> print(match_.lastgroup)
'g0'
>>> print(map_[match_.lastgroup])
mlp.*.w1
```
"""
name_map: dict[str, str] = {}
parts: list[str] = []
prefix_src = r".*"
for i, g in enumerate(globs):
name = f"g{i}"
name_map[name] = g
pat_src = _glob_to_regex_src(g)
parts.append(f"(?P<{name}>{prefix_src}{pat_src})")
alt_src = "|".join(parts)
return re.compile(alt_src), name_map
def match_glob(key: str, alt: re.Pattern, name_map: dict[str, str]) -> Optional[str]:
"""
Match the key against the alternation; return the original glob string that matched.
"""
m = alt.match(key)
if not m:
return None
return name_map.get(m.lastgroup)
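# Illustration (doctest-style, not executed) of how the loader combines these helpers with the
# mixtral rule WeightConverter("*.block_sparse_moe.", "*.mlp."): the matched glob is turned back
# into a regex and the captured layer index is re-injected into the target key.
#   >>> key = "model.layers.3.block_sparse_moe.gate.weight"
#   >>> reg, names = build_glob_alt(["*.block_sparse_moe."])
#   >>> match_glob(key, reg, names)
#   '*.block_sparse_moe.'
#   >>> re.sub(_glob_to_regex_src("*.block_sparse_moe."), "*.mlp.".replace("*", "\\1"), key)
#   'model.layers.3.mlp.gate.weight'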
class ConversionOps:
"""Base class for weight conversion operations."""
# The inverse operation class, will be used when saving the checkpoint
reverse_op: type[ConversionOps]
@abstractmethod
def convert(
self, value: Union[dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], *args, **kwargs
) -> torch.Tensor:
raise NotImplementedError
class Chunk(ConversionOps):
"""Split a tensor along ``dim`` into equally sized chunks or using explicit ``sizes``."""
reverse_op: type[ConversionOps]
def __init__(self, dim: int = 0, chunks: Optional[int] = None, sizes: Optional[Sequence[int]] = None):
if chunks is None and sizes is None:
raise ValueError("`chunks` or `sizes` must be provided for Chunk operations.")
if chunks is not None and chunks <= 0:
raise ValueError("`chunks` must be a strictly positive integer.")
self.dim = dim
self.chunks = chunks
self.sizes = list(sizes) if sizes is not None else None
self.reverse_op = Concatenate
def convert(self, value: torch.Tensor, *args, **kwargs) -> list[torch.Tensor]:
if not isinstance(value, torch.Tensor):
raise TypeError("Chunk expects a torch.Tensor as input.")
if self.sizes is not None:
return list(torch.split(value, self.sizes, dim=self.dim))
return list(torch.chunk(value, self.chunks, dim=self.dim))
class Concatenate(ConversionOps):
"""Concatenate tensors along `dim` using a reusable buffer."""
reverse_op: type[ConversionOps]
def __init__(self, dim: int = 0):
self.dim = dim
self.reverse_op = Chunk
@torch.no_grad
def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> torch.Tensor:
if isinstance(value[0], list):
value = [v[0] for v in value]
tensors = value
if not tensors:
raise ValueError("Fuse requires at least one tensor to concatenate.")
return torch.cat(tuple(tensors), dim=self.dim)
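# Illustration (doctest-style): Chunk and Concatenate are inverses of each other along the same dim.
#   >>> halves = Chunk(dim=0, chunks=2).convert(torch.arange(4.0))
#   >>> torch.equal(Concatenate(dim=0).convert(halves), torch.arange(4.0))
#   True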
class MergeModulelist(Concatenate):
"""
Merge a list of tensors into a single tensor along the first dimension.
We explicitly define this because for EP or TP you want to make sure you know what you are doing!
"""
def __init__(self, dim: int = 0):
super().__init__(dim=dim)
self.reverse_op = SplitModulelist
@torch.no_grad
def convert(self, value: Sequence[torch.Tensor], *args, **kwargs) -> list[torch.Tensor]:
merged = []
for group in value:
if not isinstance(group, Sequence) or len(group) == 0:
raise ValueError("MergeModulelist requires non-empty sub-sequences.")
group = [k for k in group if k.ndim]
merged.append(torch.stack(group, dim=self.dim))
return merged
class SplitModulelist(ConversionOps):
"""Inverse of :class:`MergeModulelist` using explicit split sizes per group."""
def __init__(self, sizes: Sequence[Sequence[int]], dim: int = 0):
if not isinstance(sizes, Sequence) or not all(isinstance(sub, Sequence) and sub for sub in sizes):
raise ValueError("`sizes` must be a sequence of non-empty sequences of integers.")
self.sizes = [list(sub) for sub in sizes]
self.dim = dim
self.reverse_op = MergeModulelist
@torch.no_grad
def convert(self, value: Sequence[torch.Tensor], *, context: dict[str, Any]) -> list[list[torch.Tensor]]:
if not isinstance(value, Sequence):
raise TypeError("SplitModulelist expects a sequence of tensors.")
if len(value) != len(self.sizes):
raise ValueError("Number of tensors does not match the provided split specifications.")
result: list[list[torch.Tensor]] = []
for tensor, split_sizes in zip(value, self.sizes):
if not isinstance(tensor, torch.Tensor):
raise TypeError("SplitModulelist can only split torch.Tensor instances.")
splits = torch.split(tensor, split_sizes, dim=self.dim)
result.append(list(splits))
return result
class PermuteForRope(ConversionOps):
"""
Applies the permutation required to convert complex RoPE weights to the split sin/cos format.
"""
def __init__(self):
pass
def _apply(self, tensor: torch.Tensor) -> torch.Tensor:
dim1, dim2 = tensor.shape
n_heads = getattr(self.config, "num_attention_heads", 1)
tensor = tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2)
tensor = tensor.transpose(1, 2).reshape(dim1, dim2)
return tensor
@torch.no_grad
def convert(
self, value: Union[dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor], config
) -> Union[dict[str, torch.Tensor], list[torch.Tensor], torch.Tensor]:
self.config = config
out = [[self._apply(x) for x in inner] if isinstance(inner, list) else self._apply(inner) for inner in value]
return out
@dataclass(slots=True)
class WeightConverter:
r"""
A weight convert that acts on a pattern of source keys.
The keys need to be collected based on the target keys.
With wild card, glob patterns are matched, so you have to be detailed with what to match. If you match:
`model.layers.*.experts.*` -> it will act on all of them
{"model.layers.*.experts.*": []}
but
`experts.*.mlp` will be layer specific.
{"model.layers.1.experts.*": [], }
- source_keys: str | list[str] (wildcards '*' match digits)
- target_keys: str | list[str] | None
- distributed_operation / operations / quantization_operations are ALWAYS lists.
"""
source_keys: Union[str, list[str]]
target_keys: Optional[Union[str, list[str]]] = None
operations: list[ConversionOps] = field(default_factory=list, repr=False)
distributed_operation: Optional[TensorParallelLayer] = None
quantization_operation: Optional[ConversionOps] = None
def __post_init__(self):
if not isinstance(self.source_keys, list):
self.source_keys = [self.source_keys]
targets_were_none = False
if not isinstance(self.target_keys, list):
if self.target_keys is None:
self.target_keys = list(self.source_keys)
targets_were_none = True
else:
self.target_keys = [self.target_keys]
if not targets_were_none and bool(len(self.source_keys) - 1) + bool(len(self.target_keys) - 1) >= 2:
raise ValueError(
f"source keys={self.source_keys}, target_keys={self.target_keys} but you can only have one to many, one to one or many to one."
)
for pattern in (*self.source_keys, *self.target_keys):
    if any(ch in pattern for ch in set("^$+?{}[]|()")):
        raise AssertionError(f"'{pattern}' is not a glob pattern: regex special characters are not allowed")
@dataclass(slots=True)
class ConversionEntry:
weight_converter: WeightConverter
collected_tensors: dict = field(default_factory=lambda: defaultdict(dict))
GLOBAL_WORKERS = min(16, (os.cpu_count() or 8) * 2) # NVMe: 8-16; HDD/NFS: 2-4
# Factory function to create LoadedParameter subclasses dynamically
def get_loaded_parameter_class(base_cls):
"""
base_cls: an nn.Parameter subclass (or nn.Parameter) or a Tensor
Returns a new class that combines the base_cls with LoadedParameterMixin
"""
class LoadedParam(base_cls):
_inplace_methods = [
'add_', 'mul_', 'clamp_', 'zero_', 'fill_', 'normal_', 'uniform_',
'copy_', 'erfinv_', 'log_', "__getitem__", "neg_", "exp_", "sub_"
]
def __new__(cls, from_existing, **kwargs):
if isinstance(from_existing, torch.nn.Parameter):
inst = super().__new__(cls, from_existing.data, from_existing.requires_grad, **from_existing.__dict__)
else:
inst = super().__new__(cls, from_existing)
inst._original_type = from_existing
# Explicitly override all in-place methods per instance
for method_name in inst._inplace_methods:
setattr(inst, method_name, MethodType(inst._skip, inst))
return inst
def _skip(self, *args, **kwargs):
"""Helper to skip in-place operations."""
return self
def __repr__(self):
return f"LoadedParameter(data={self.data})"
@property
def data(self):
return super().data
@data.setter
def data(self, new):
pass
def __lt__(self, other): return torch.Tensor.__lt__(self, other)
def __le__(self, other): return torch.Tensor.__le__(self, other)
def __gt__(self, other): return torch.Tensor.__gt__(self, other)
def __ge__(self, other): return torch.Tensor.__ge__(self, other)
def __eq__(self, other): return torch.Tensor.__eq__(self, other)
def __ne__(self, other): return torch.Tensor.__ne__(self, other)
def __iadd__(self, *args, **kwargs): return self
def __isub__(self, *args, **kwargs): return self
def __imul__(self, *args, **kwargs): return self
def __imatmul__(self, *args, **kwargs): return self
def __itruediv__(self, *args, **kwargs): return self
def __ifloordiv__(self, *args, **kwargs): return self
def __imod__(self, *args, **kwargs): return self
def __ipow__(self, *args, **kwargs): return self
def __iand__(self, *args, **kwargs): return self
def __ior__(self, *args, **kwargs): return self
def __ixor__(self, *args, **kwargs): return self
def __ilshift__(self, *args, **kwargs): return self
def __irshift__(self, *args, **kwargs): return self
return LoadedParam
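# Illustration (doctest-style): in-place initialization on a LoadedParam is a no-op, which is what
# protects freshly loaded checkpoint values from being overwritten by `_init_weights`.
#   >>> p = torch.nn.Parameter(torch.ones(2))
#   >>> lp = get_loaded_parameter_class(type(p))(from_existing=p)
#   >>> _ = lp.normal_()   # skipped, returns lp unchanged
#   >>> torch.equal(lp.detach(), torch.ones(2))
#   True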
def _materialize_copy(tensor, dtype=None):
tensor = tensor[...]
if dtype is not None:
tensor = tensor.to(dtype)
return tensor
def spawn_materialize(thread_pool, tensor, dtype=None) -> Future:
def _job():
return _materialize_copy(tensor, dtype)
return thread_pool.submit(_job)
def spawn_tp_materialize(thread_pool, tensor, sharding_method, tensor_idx, dtype=None) -> Future:
def _job():
return sharding_method.shard_tensor(tensor, param_casting_dtype=dtype, tensor_idx=tensor_idx)[0]
return thread_pool.submit(_job)
def dot_natural_key(s: str):
parts = s.split(".")
for i, p in enumerate(parts):
# whole-segment digits -> int; otherwise leave as str
if p.isdigit():
parts[i] = int(p)
return parts
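# Example: natural ordering sorts "layers.2" before "layers.10", unlike a plain string sort.
#   >>> sorted(["model.layers.10.w", "model.layers.2.w"], key=dot_natural_key)
#   ['model.layers.2.w', 'model.layers.10.w']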
@contextmanager
def log_to_misc(
layer_name: str,
misc: MutableMapping[str, str],
extras: Any = None,
op: Union[list[ConversionOps], ConversionOps, None] = None,
):
# A simple helper to handle errors with contextual messages.
try:
yield
except Exception as e:
def _format_op_name(curr_op: Union[list[ConversionOps], ConversionOps, None]) -> Optional[str]:
if curr_op is None:
return None
if isinstance(curr_op, (list, tuple, set)):
names = [o.__class__.__name__ for o in curr_op if o is not None]
if not names:
return None
return ", ".join(names)
return curr_op.__class__.__name__
op_name = _format_op_name(op)
if isinstance(extras, tuple) and len(extras) == 2:
values, target_keys = extras
descriptor = f"{op_name} " if op_name else ""
misc[layer_name] = (
f"{e}\nError: {descriptor}on tensors destined for {target_keys}. Ckpt contains: {len(values[0])}"
)
elif isinstance(extras, str):
suffix = f" via {op_name}" if op_name else ""
misc[layer_name] = f"{e}\nError{suffix} when processing parameter {extras}"
elif extras is None and op_name:
misc[layer_name] = f"{op_name}: {e}"
else:
misc[layer_name] = f"{extras} |Error: {e}"
raise SkipLayer()
def set_param_for_module(
model: torch.nn.Module,
layer_name: str,
param_value: torch.Tensor,
meta_model_state_dict: MutableMapping[str, Any],
empty_param: torch.Tensor,
mismatch_keys: MutableSet[tuple[str, torch.Size, torch.Size]],
missing_keys: MutableSet[str],
misc: MutableMapping[str, Any],
distributed_operation: Optional[TensorParallelLayer],
hf_quantizer,
):
with log_to_misc(layer_name, misc, layer_name):
module_path, _, param_name = layer_name.rpartition(".")
module_obj = model.get_submodule(module_path) if module_path else model
if isinstance(param_value, list):
param_value = param_value[0]
elif not isinstance(param_value, torch.nn.Parameter):
param_value = param_value[...]
ref = meta_model_state_dict.get(layer_name, empty_param)
use_dtensor = hasattr(distributed_operation, "use_dtensor") and distributed_operation.use_dtensor
if not isinstance(param_value, torch.nn.Parameter):
if distributed_operation is not None:
param_value = DTensor.from_local(
param_value,
distributed_operation.device_mesh,
getattr(distributed_operation, "shard", Replicate()),
run_check=False,
shape=ref.size(),
stride=ref.stride(),
)
if not use_dtensor:
# we convert to local
param_value = param_value.to_local()
if param_name not in module_obj._buffers:
param_value = torch.nn.Parameter(param_value, requires_grad=param_value.is_floating_point())
# to skip any inplace method that modifies the param data
param_value = get_loaded_parameter_class(param_value.__class__)(from_existing=param_value)
# skip mismatch for hf_quantizer for now
if ref is not None and ref.shape != param_value.shape and hf_quantizer is None:
mismatch_keys.add((layer_name, param_value.shape, ref.shape))
setattr(module_obj._parameters[param_name], "_is_hf_initialized", False) # Needs to be initialized
missing_keys.discard(layer_name)
else:
missing_keys.discard(layer_name)
param_value._is_hf_initialized = True # super important, otherwise _init_weights re-initializes the param if e.g. the bias is missing
setattr(module_obj, param_name, param_value)
class SkipLayer(Exception):
"""Control-flow sentinel: abort processing of the current layer only."""
pass
def convert_and_load_state_dict_in_model(
model,
state_dict,
weight_mapping,
tp_plan,
hf_quantizer,
dtype=None,
device_map=None,
dtype_plan=None,
device_mesh=None,
loading_task_model_from_base_state_dict: bool = False,
loading_base_model_from_task_state_dict: bool = False,
):
"""
Convert a state dict according to a weight mapping (one WeightConverter per glob pattern),
collecting tensors per *layer instance* (the concrete indices captured from '*').
"""
prefix = model.base_model_prefix
tp_plan = tp_plan or {} # {glob_pattern: plan_obj_or_key}
device_map = device_map or {} # {exact_target_key: device}
dtype_plan = dtype_plan or {} # {glob_pattern: dtype}
weight_mapping = weight_mapping or {} # {glob_pattern: WeightConverter}
meta_model_state_dict = model.state_dict()
missing_keys = set(meta_model_state_dict.keys())
misc = {}
mismatch_keys = set()
unexpected_keys = set()
# Global thread_pool
thread_pool = ThreadPoolExecutor(max_workers=GLOBAL_WORKERS)
_patterns = list(itertools.chain.from_iterable([k.source_keys for k in weight_mapping]))
source_to_target = {sk: k for k in weight_mapping for sk in k.source_keys}
weight_pattern_alt, weight_pattern_by_group_name = build_glob_alt(_patterns)
tp_plan_alt, tp_plan_by_group_name = build_glob_alt(list(tp_plan.keys()))
dtype_policy_alt, dtype_policy_by_group_name = build_glob_alt(list(dtype_plan.keys()))
state_dict = sorted(state_dict.items(), key=lambda kv: dot_natural_key(kv[0]))
# 1. Create the conversion entries
by_conversion_pattern: dict[str, ConversionEntry] = {}
for original_key, tensor in state_dict:
matched_pattern = match_glob(original_key, weight_pattern_alt, weight_pattern_by_group_name)
if matched_pattern is not None:
converter = source_to_target[matched_pattern] # TODO make sure its the ref
sub_with_extractor = partial(re.sub, _glob_to_regex_src(matched_pattern), string=original_key)
entry_key = "|".join(converter.target_keys)
target_key = "|".join(map(sub_with_extractor, [k.replace("*", "\\1") for k in converter.target_keys]))
entry: ConversionEntry = by_conversion_pattern.setdefault(entry_key, ConversionEntry(converter))
converter_key = sub_with_extractor(matched_pattern)
else:
converter = WeightConverter(original_key)
converter_key = entry_key = target_key = original_key
entry = by_conversion_pattern.setdefault(converter_key, ConversionEntry(converter))
_dtype = dtype
new_target_key = [] # test_load_with_mismatched_shapes for AutoModel.from_pretrained(AutoForCausal, vocab=10)
for t in target_key.split("|"):
if t.startswith(prefix) and meta_model_state_dict.get(t.replace(f"{prefix}.", "")) is not None:
t = t.replace(f"{prefix}.", "")
elif meta_model_state_dict.get(f"{prefix}.{t}") is not None:
t = f"{prefix}.{t}"
new_target_key.append(t)
empty_param = meta_model_state_dict.get(t)
# If it does not exist, it's unexpected
if empty_param is None:
if hf_quantizer is not None and hf_quantizer.is_valid_unexpected_keys(t):
pass
else:
unexpected_keys.add(t)
continue
if hf_quantizer is not None and hf_quantizer.param_needs_quantization(model, t):
converter.quantization_operation = hf_quantizer.get_quantize_ops()
# TODO: to clean later. We need to use the empty_param from the checkpoint to decide if we upcast the param to a specific dtype
k_dtype = tensor.get_dtype()
dtype = str_to_torch_dtype[k_dtype]
empty_param_checkpoint = torch.empty(size=tensor.get_shape(), dtype=dtype, device="meta")
_, _dtype = _infer_parameter_dtype(model, t, empty_param_checkpoint, hf_quantizer)
else:
_dtype = dtype
matched_dtype_pattern = match_glob(t, dtype_policy_alt, dtype_policy_by_group_name)
if matched_dtype_pattern is not None:
_dtype = dtype_plan[matched_dtype_pattern]
elif empty_param.dtype != _dtype:
_dtype = empty_param.dtype
first_target_key = new_target_key[0]
target_key = "|".join(new_target_key)
future = None
if device_mesh:
if matched_tp_pattern := match_glob(first_target_key, tp_plan_alt, tp_plan_by_group_name):
empty_param = meta_model_state_dict.get(first_target_key)
if getattr(converter, "distributed_operation", {}) is None:
tp_layer = ALL_PARALLEL_STYLES[model.tp_plan[matched_tp_pattern]].__class__
converter.distributed_operation = tp_layer(
device_mesh=device_mesh, rank=device_map[""].index, empty_param=empty_param.clone()
)
# VERY IMPORTANT: this tells us whether we have already collected shards for this key or not.
shard_index = len(entry.collected_tensors[target_key].get(converter_key, []))
future = spawn_tp_materialize(
    thread_pool,
    tensor,
    converter.distributed_operation,
    shard_index,
    _dtype,
)
if future is None: # If not TP, async materialize the tensors. TODO handle disk offload?
future = spawn_materialize(thread_pool, tensor, _dtype)
entry.collected_tensors[target_key].setdefault(converter_key, []).append(future)
# 2. Actually convert the ckpt
inverse_converters = {}
keys = list(by_conversion_pattern.keys())
with logging.tqdm(total=len(keys), desc="Loading weights") as pbar:
for key in keys[::-1]: # reversed so that simple keys are processed first
group = by_conversion_pattern.pop(key)
converter = group.weight_converter
operations = converter.operations if isinstance(converter.operations, list) else [converter.operations]
for layer_name, tensors_for_this_layer in group.collected_tensors.items():
concrete_target_keys = layer_name.split("|")
try:
if bool(set(concrete_target_keys) - unexpected_keys):
with log_to_misc(layer_name, misc):
values = [[k.result() for k in inner] for inner in tensors_for_this_layer.values()]
for op in operations:
with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
values = op.convert(values, model.config)
values = [values] if not isinstance(values, list) else values
with log_to_misc(layer_name, misc, (values, concrete_target_keys), operations):
realized_value = {
k: t for k, t in zip(concrete_target_keys, values) if k not in unexpected_keys
}
for k in list(realized_value.keys()).copy():
if op := converter.quantization_operation:
with log_to_misc(layer_name, misc, op=op):
realized_value.update(
op.convert({k: realized_value.pop(k)}, model=model, missing_keys=missing_keys)
)
for k, output_value in realized_value.items():
for src in converter.source_keys: # what should happen to k when we meet k at saving
inverse_converters[k] = {src: converter}
set_param_for_module(
model,
k,
output_value,
meta_model_state_dict,
empty_param,
mismatch_keys,
missing_keys,
misc,
converter.distributed_operation,
hf_quantizer
)
except SkipLayer:
    # the error was already recorded in `misc` by `log_to_misc`; skip this layer only
    continue
del group
# Update progress bar
pbar.update()
pbar.refresh()
model.inverse_converters = inverse_converters
thread_pool.shutdown(wait=False)
return missing_keys, unexpected_keys, mismatch_keys, misc
# TODO this is not done yet!
def revert_weight_conversion(model, state_dict):
mapping = getattr(model, "", {}) # IDK why but setting this will fail all llava.
reverse_key_mapping = [(v, k) for k, v in mapping.items()]
original_state_dict = {}
for key, value in state_dict.items():
for pattern, inverse_converter in reverse_key_mapping:
# TODO FIXME you name it
replacement = inverse_converter.lstrip("^") # strip off un-needed chars and patterns
replacement = re.sub(r"\(.*\)", "", replacement)
key, n_replace = re.subn(pattern, replacement, key)
# Early exit of the loop
if n_replace > 0:
break
original_state_dict[key] = value
state_dict = original_state_dict
return state_dict
def _infer_parameter_dtype(
model: torch.nn.Module,
param_name: str,
empty_param: torch.Tensor,
hf_quantizer: Optional[HfQuantizer] = None,
) -> tuple[bool, Optional[torch.dtype]]:
try:
old_param = model.get_parameter_or_buffer(param_name)
except Exception as e:
if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method in {
QuantizationMethod.HQQ,
QuantizationMethod.QUARK,
QuantizationMethod.MXFP4,
QuantizationMethod.BITS_AND_BYTES,
}:
return True, None
else:
raise e
is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
# We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
# in int/uint/bool and not cast them.
casting_dtype = None
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
# dtype that was instantiated in the meta model -- note that this respects subconfigs dtypes
if hf_quantizer is not None and hf_quantizer.param_needs_quantization(model, param_name):
casting_dtype = model.config._pre_quantization_dtype
else:
casting_dtype = old_param.dtype
return old_param is not None and old_param.is_contiguous(), casting_dtype

View File

@ -19,7 +19,6 @@ from typing import Optional
import torch
from ...utils import is_torch_xpu_available
from ...utils.logging import logging
from ...utils.metrics import traced
@ -36,13 +35,6 @@ def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]:
total_memory = torch.cuda.get_device_properties(device).total_memory
reserved_memory = torch.cuda.memory_reserved(device)
allocated_memory = torch.cuda.memory_allocated(device)
elif is_torch_xpu_available():
device = torch.device("xpu")
torch.xpu.empty_cache()
torch.xpu.synchronize()
total_memory = torch.xpu.get_device_properties(device).total_memory
reserved_memory = torch.xpu.memory_reserved(device)
allocated_memory = torch.xpu.memory_allocated(device)
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
device = torch.device("mps")
# MPS memory reporting (PyTorch 2.0+)

View File

@ -1635,7 +1635,12 @@ class GenerationMixin(ContinuousMixin):
# TransformersKwargs are model-agnostic attention and generation arguments such as 'output_attentions'
for key, value in model_kwargs.items():
if value is not None and key not in model_args and key not in TransformersKwargs.__optional_keys__:
if (
value is not None
and key not in model_args
and key not in TransformersKwargs.__optional_keys__
and key != "debug_io"
):
unused_model_args.append(key)
if unused_model_args:

View File

@ -383,10 +383,11 @@ class BayesianDetectorModel(PreTrainedModel):
)
self.prior = torch.nn.Parameter(torch.tensor([self.base_rate]))
@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, nn.Parameter):
module.weight.data.normal_(mean=0.0, std=0.02)
module.weight.normal_(mean=0.0, std=0.02)
def _compute_posterior(
self,

View File

@ -821,26 +821,14 @@ def split_to_tiles(images: "torch.Tensor", num_tiles_height: int, num_tiles_widt
return image
def _cast_tensor_to_float(x):
if x.is_floating_point():
return x
return x.float()
def _group_images_by_shape(nested_images, *paired_inputs, is_nested: bool = False):
"""
Helper function to flatten a single level of nested image and batch structures and group by shape.
Args:
nested_images (list):
A list of images or a single tensor
paired_inputs (Any, *optional*):
Zero or more lists that mirror the structure of `nested_images` (flat list, or list of lists when
`is_nested=True`). Each element is paired 1:1 with the corresponding image so it can be grouped by the
same shape key. These paired values are grouped alongside `nested_images` but are not stacked in the output, so
they do not need to be tensors.
is_nested (bool, *optional*, defaults to False):
Whether the images are nested.
Returns:
tuple[dict, ...]:
- A dictionary with shape as key and list of images with that shape as value
- A dictionary with shape as key and list of paired values with that shape as value
- A dictionary mapping original indices to (shape, index) tuples
- A dictionary mapping original indices to (shape, index) tuples for each paired input
"""
"""Helper function to flatten a single level of nested image and batch structures and group by shape."""
grouped_images = defaultdict(list)
grouped_images_index = {}
paired_grouped_values = [defaultdict(list) for _ in paired_inputs]
@ -892,20 +880,27 @@ def _reconstruct_nested_structure(indices, processed_images):
return result
def _iterate_items(items, is_nested: bool):
    """
    Helper function to iterate over items yielding (key, item) pairs.
    For nested structures, yields ((row_index, col_index), item).
    For flat structures, yields (index, item).
    """
    if is_nested:
        for i, row in enumerate(items):
            for j, item in enumerate(row):
                yield (i, j), item
    else:
        for i, item in enumerate(items):
            yield i, item
def _disable_grouping_output_nested(images, *paired_inputs):
    """Build the disable_grouping output tuple for a single-level nested structure."""
    outer_range = range(len(images))
    inner_ranges = [range(len(images[i])) for i in outer_range]
    # Precompute all (i, j) pairs
    ij_pairs = [(i, j) for i in outer_range for j in inner_ranges[i]]
    images_dict = {(i, j): images[i][j].unsqueeze(0) for (i, j) in ij_pairs}
    paired_dicts = [{(i, j): paired_list[i][j].unsqueeze(0) for (i, j) in ij_pairs} for paired_list in paired_inputs]
    index_map = {(i, j): ((i, j), 0) for (i, j) in ij_pairs}
    return images_dict, *paired_dicts, index_map
def _disable_grouping_output_flat(images, *paired_inputs):
"""Build the disable_grouping output tuple for a flat list structure."""
idx_range = range(len(images))
images_dict = {i: images[i].unsqueeze(0) for i in idx_range}
paired_dicts = [{i: paired_list[i].unsqueeze(0) for i in idx_range} for paired_list in paired_inputs]
index_map = {i: (i, 0) for i in idx_range}
return images_dict, *paired_dicts, index_map
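# Usage sketch (illustrative) for the flat helper above, with no paired inputs:
#   >>> imgs = [torch.zeros(3, 2, 2), torch.ones(3, 4, 4)]
#   >>> images_dict, index_map = _disable_grouping_output_flat(imgs)
#   >>> images_dict[1].shape, index_map
#   (torch.Size([1, 3, 4, 4]), {0: (0, 0), 1: (1, 0)})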
def group_images_by_shape(
@ -925,7 +920,7 @@ def group_images_by_shape(
Args:
images (Union[list["torch.Tensor"], "torch.Tensor"]):
A list of images or a single tensor
paired_inputs (Any, *optional*):
*paired_inputs (Any):
Zero or more lists that mirror the structure of `images` (flat list, or list of lists when
`is_nested=True`). Each element is paired 1:1 with the corresponding image so it can be grouped by the
same shape key. These paired values are grouped alongside `images` but are not stacked in the output, so
@ -949,14 +944,10 @@ def group_images_by_shape(
disable_grouping = device == "cpu"
if disable_grouping:
return (
{key: img.unsqueeze(0) for key, img in _iterate_items(images, is_nested)},
*[
{key: item.unsqueeze(0) for key, item in _iterate_items(paired_list, is_nested)}
for paired_list in paired_inputs
],
{key: (key, 0) for key, _ in _iterate_items(images, is_nested)},
)
if is_nested:
return _disable_grouping_output_nested(images, *paired_inputs)
else:
return _disable_grouping_output_flat(images, *paired_inputs)
# Handle single level nested structure
grouped_images, *paired_grouped_values, grouped_images_index = _group_images_by_shape(
@ -999,3 +990,14 @@ def reorder_images(
]
return _reconstruct_nested_structure(grouped_images_index, processed_images)
class NumpyToTensor:
"""
Convert a numpy array to a PyTorch tensor.
"""
def __call__(self, image: np.ndarray):
# Same as in PyTorch, we assume incoming numpy images are in HWC format
# c.f. https://github.com/pytorch/vision/blob/61d97f41bc209e1407dcfbd685d2ee2da9c1cdad/torchvision/transforms/functional.py#L154
return torch.from_numpy(image.transpose(2, 0, 1)).contiguous()
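A minimal usage sketch of the NumpyToTensor transform added above, assuming the class is in scope (the image shape is an illustrative assumption):

import numpy as np

image = np.zeros((224, 224, 3), dtype=np.uint8)   # HWC, as returned by e.g. np.asarray(PIL image)
tensor = NumpyToTensor()(image)                   # CHW torch.Tensor
assert tensor.shape == (3, 224, 224) and tensor.is_contiguous()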

View File

@ -36,6 +36,7 @@ _import_structure = {
"get_keys_to_not_convert",
"replace_with_bnb_linear",
"validate_bnb_backend_availability",
"Bnb4bitQuantize",
],
"deepspeed": [
"HfDeepSpeedConfig",
@ -177,6 +178,7 @@ if TYPE_CHECKING:
unpack_weights,
)
from .bitsandbytes import (
Bnb4bitQuantize,
dequantize_and_replace,
get_keys_to_not_convert,
replace_with_bnb_linear,

View File

@ -435,6 +435,7 @@ def _get_device_map(
if max_memory is not None and device_name in max_memory:
inferred_max_memory[device_name] = min(inferred_max_memory[device_name], max_memory[device_name])
model.tie_weights()
device_map = infer_auto_device_map(
model,
max_memory=inferred_max_memory,
@ -512,10 +513,8 @@ def accelerate_disk_offload(
checkpoint_files,
device_map,
checkpoint_keys,
key_renaming_mapping,
sharded_metadata,
dtype,
reverse_key_renaming_mapping,
):
disk_only_shard_files = []
if disk_offload_folder is not None:
@ -534,19 +533,13 @@ def accelerate_disk_offload(
weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
else:
folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
# Fix the weight map keys according to the key mapping
weight_map = {
key_renaming_mapping[k]: v
for k, v in sharded_metadata["weight_map"].items()
if k in key_renaming_mapping
}
weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
# Find potential checkpoints containing only offloaded weights
disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map)
disk_offload_index = {
name: {
"safetensors_file": file,
"weight_name": reverse_key_renaming_mapping[name],
"weight_name": name,
"dtype": str_dtype,
}
for name, file in weight_map.items()

View File

@ -1,7 +1,9 @@
import inspect
from copy import deepcopy
from collections import defaultdict
from inspect import signature
from typing import Optional
from ..quantizers.quantizers_utils import get_module_from_name
from ..utils import (
get_available_devices,
is_accelerate_available,
@ -24,10 +26,52 @@ if is_accelerate_available():
import accelerate
from accelerate import init_empty_weights
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
from accelerate.utils import find_tied_parameters
logger = logging.get_logger(__name__)
from ..core_model_loading import ConversionOps
class Bnb4bitQuantize(ConversionOps):
def __init__(self, hf_quantizer):
self.hf_quantizer = hf_quantizer
def convert(self, input_dict: torch.Tensor, model: Optional[torch.nn.Module] = None, **kwargs) -> dict[str, torch.Tensor]:
target_key, value = tuple(input_dict.items())[0]
value = value[0] if isinstance(value, list) else value
full_name = target_key
# update param name to get the weights instead of the quantized stats
target_key = self.hf_quantizer.get_param_name(target_key)
module, _ = get_module_from_name(model, target_key)
if not self.hf_quantizer.pre_quantized:
# Support models using `Conv1D` in place of `nn.Linear` (e.g. openai-community/gpt2) by transposing the weight matrix prior to quantization.
# Since weights are saved in the correct "orientation", we skip transposing when loading.
if issubclass(module.source_cls, Conv1D):
value = value.T
old_value = model.get_parameter_or_buffer(target_key)
new_value = bnb.nn.Params4bit(value, requires_grad=False, **old_value.__dict__).to(value.device)
return {target_key : new_value}
else:
module_name = target_key.rsplit(".", 1)[0]
# Save the states for later quantization when they are all gathered
if not hasattr(self.hf_quantizer, "param_quant_stats"):
self.hf_quantizer.param_quant_stats = defaultdict(dict)
self.hf_quantizer.param_quant_stats[module_name].update({full_name: value})
# We are ready for quantization in this case (note, the +1 is for the weight itself)
if len(self.hf_quantizer.param_quant_stats[module_name]) == len(self.hf_quantizer.bnb_keys) + 1:
weight = self.hf_quantizer.param_quant_stats[module_name].pop(f"{module_name}.weight")
new_value = bnb.nn.Params4bit.from_prequantized(
data=weight,
quantized_stats=self.hf_quantizer.param_quant_stats[module_name],
requires_grad=False,
device=value.device,
module=module
)
del self.hf_quantizer.param_quant_stats[module_name]
return {target_key : new_value}
return {}
def _replace_with_bnb_linear(
model,
@ -151,52 +195,6 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
return model
def get_keys_to_not_convert(model):
r"""
A utility function to get the keys of the modules to keep in full precision, if any. For example, for CausalLM modules
we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert to
int8.
Parameters:
model (`torch.nn.Module`):
Input model
"""
# Create a copy of the model and tie the weights, then
# check if it contains tied weights
tied_model = deepcopy(model) # this has 0 cost since it is done inside the `init_empty_weights` context manager
tied_model.tie_weights()
tied_params = find_tied_parameters(tied_model)
tied_keys = sum(tied_params, [])
has_tied_params = len(tied_keys) > 0
# If there are no tied weights, we want to keep the lm_head (output_embedding) in full precision
if not has_tied_params:
output_emb = model.get_output_embeddings()
if output_emb is not None:
list_last_module = [name for name, module in model.named_modules() if id(module) == id(output_emb)]
return list_last_module
# otherwise, no tied weights, no output embedding defined, simply keep the last module in full precision
list_modules = list(model.named_parameters())
list_last_module = [list_modules[-1][0]]
# add last module together with tied weights
intersection = set(list_last_module) - set(tied_keys)
list_untouched = list(set(tied_keys)) + list(intersection)
# remove ".weight" from the keys
names_to_remove = [".weight", ".bias"]
filtered_module_names = []
for name in list_untouched:
for name_to_remove in names_to_remove:
if name_to_remove in name:
name = name.replace(name_to_remove, "")
filtered_module_names.append(name)
return filtered_module_names
# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", state=None):
"""

View File

@ -11,6 +11,7 @@
# specific language governing permissions and limitations under the License.
import logging
from collections.abc import Callable
from typing import Optional
import torch
@ -23,7 +24,13 @@ from ..cache_utils import (
StaticCache,
)
from ..generation.configuration_utils import GenerationConfig
from ..modeling_utils import PreTrainedModel
from ..masking_utils import (
ALL_MASK_ATTENTION_FUNCTIONS,
_ignore_causal_mask_sdpa,
_is_torch_greater_or_equal_than_2_5,
prepare_padding_mask,
)
from ..modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ..pytorch_utils import (
is_torch_greater_or_equal,
is_torch_greater_or_equal_than_2_3,
@ -222,6 +229,10 @@ class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
"Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
)
self.model = TorchExportableModuleWithStaticCache(model, batch_size, max_cache_len, device)
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
self.model.model.config._attn_implementation = "sdpa_without_vmap"
def forward(
self,
@ -757,6 +768,11 @@ def convert_and_export_with_cache(
import torch.export._trace
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
model.config._attn_implementation = "sdpa_without_vmap"
with torch.no_grad():
# TODO: The default inputs only work for text models. We need to add support for vision/audio models.
example_input_ids = (
@ -1020,6 +1036,11 @@ def export_with_dynamic_cache(
if not is_torch_greater_or_equal_than_2_3:
raise ImportError("torch >= 2.3 is required.")
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
model.config._attn_implementation = "sdpa_without_vmap"
register_dynamic_cache_export_support()
with torch.no_grad():
@ -1088,3 +1109,92 @@ def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context):
value = value_list[idx] if idx < len(value_list) else None
cache.update(key, value, idx)
return cache
def sdpa_mask_without_vmap(
batch_size: int,
cache_position: torch.Tensor,
kv_length: int,
kv_offset: int = 0,
mask_function: Optional[Callable] = None,
attention_mask: Optional[torch.Tensor] = None,
local_size: Optional[int] = None,
allow_is_causal_skip: bool = True,
allow_torch_fix: bool = True,
**kwargs,
) -> Optional[torch.Tensor]:
"""
Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
the element should take part in the attention computation, and False that it should not.
This is similar to `masking_utils.sdpa_mask` but does not use `vmap` which is incompatible with export.
Args:
batch_size (`int`):
The batch size of the input sequence.
cache_position (`torch.Tensor`):
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
kv_length (`int`):
The size that the key and value states will have during the attention computation.
kv_offset (`int`, optional):
An optional offset to indicate at which first position the key and values states will refer to.
mask_function (`Callable`):
The mask factory function describing the mask pattern.
attention_mask (`torch.Tensor`, optional):
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
local_size (`int`, optional):
The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
to try to skip mask creation if possible.
allow_is_causal_skip (`bool`, optional):
Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
`torch.sdpa` instead. Default to `True`.
allow_torch_fix (`bool`, optional):
Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
versions. We need an arg to skip it when using eager. By default `True`.
"""
q_length = cache_position.shape[0]
# Potentially pad the 2D mask, and slice it correctly
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
# Under specific conditions, we can avoid materializing the mask, instead relying on the `is_causal` argument
if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, local_size):
return None
# Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
# but without data-dependent slicing (i.e. torch.compile friendly)
kv_arange = torch.arange(kv_length, device=cache_position.device)
kv_arange += kv_offset
reshaped_cache_position = cache_position.view(-1, 1)
# This is a bit hacky as a way to know which pattern we are using, but all mask creation functions actually
# forward the config through kwargs anyway, so it allows us to rely on it.
# Usually, the `mask_function` is the only entry-point defining the pattern - we could loop over it,
# but this is more efficient.
sliding_window = getattr(kwargs["config"], "sliding_window", None)
chunk_size = getattr(kwargs["config"], "attention_chunk_size", None)
if sliding_window is not None and chunk_size is not None:
raise ValueError("Cannot use both `sliding_window` and `attention_chunk_size`")
# Simplest and most efficient way to obtain a causal mask
causal_mask = kv_arange <= reshaped_cache_position
# If using sliding window, add the sliding mask
if sliding_window is not None:
sliding_mask_overlay = kv_arange > reshaped_cache_position - sliding_window
causal_mask *= sliding_mask_overlay
# If using chunk attention, add the chunked mask
elif chunk_size is not None:
chunked_mask_overlay = kv_arange // chunk_size == reshaped_cache_position // chunk_size
causal_mask *= chunked_mask_overlay
causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)
if padding_mask is not None:
causal_mask = causal_mask * padding_mask[:, None, None, :]
# Due to a bug in some older torch version, we need to update the mask in case a query is not attending to any
# tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213
if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
return causal_mask
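For intuition, here is a plain-torch sketch (toy sizes, no padding or chunked attention) of the boolean mask the function above builds for a 4-token prefill with sliding_window=2:

import torch

cache_position = torch.arange(4)
kv_arange = torch.arange(4)
causal_mask = kv_arange <= cache_position.view(-1, 1)              # causal part
sliding_mask_overlay = kv_arange > cache_position.view(-1, 1) - 2  # sliding_window = 2
mask = causal_mask & sliding_mask_overlay
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [False,  True,  True, False],
#         [False, False,  True,  True]])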

View File

@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import re
from collections.abc import Sequence
from typing import Any, Optional, Union
from ..core_model_loading import ConversionOps
from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging
@ -30,6 +33,18 @@ if is_accelerate_available():
logger = logging.get_logger(__name__)
try:
_FP8_DTYPE = torch.float8_e4m3fn
_FP8_MIN = torch.finfo(_FP8_DTYPE).min
_FP8_MAX = torch.finfo(_FP8_DTYPE).max
_FP8_IS_INT = False
except AttributeError:
_FP8_DTYPE = torch.int8
_FP8_MIN, _FP8_MAX = -127, 127
_FP8_IS_INT = True
logger.warning_once(
"torch.float8_e4m3fn not available; falling back to int8 emulation for Fp8Quantize operations."
)
# Copied from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
@ -332,6 +347,12 @@ class FP8Linear(nn.Linear):
if self.weight.element_size() > 1:
return F.linear(input, self.weight, self.bias)
else:
if isinstance(self.weight, torch.distributed.tensor.DTensor):
weight = self.weight._local_tensor.contiguous()
scale_inv = self.weight_scale_inv._local_tensor.contiguous()
else:
weight = self.weight.contiguous()
scale_inv = self.weight_scale_inv.contiguous()
# Context manager used to switch among the available accelerators
device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
torch_accelerator_module = getattr(torch, device_type, torch.cuda)
@ -339,9 +360,9 @@ class FP8Linear(nn.Linear):
qinput, scale = act_quant(input, self.block_size[1])
output = w8a8_block_fp8_matmul_triton(
qinput,
self.weight,
weight,
scale,
self.weight_scale_inv,
scale_inv,
self.block_size,
output_dtype=input.dtype,
)
@ -350,9 +371,124 @@ class FP8Linear(nn.Linear):
torch_accelerator_module.synchronize()
if self.bias is not None:
output = output + self.bias
output = torch.nan_to_num(output, nan=0.0)
return output.to(dtype=input.dtype)
def _ceil_div(a, b):
return (a + b - 1) // b
class FP8Expert(nn.Module):
dtype = torch.float8_e4m3fn
def __init__(self, config, block_size, device):
super().__init__()
from ..activations import ACT2FN
self.block_size = block_size
self.num_experts = config.num_local_experts
self.hidden_dim = config.hidden_size
self.intermediate_dim = config.intermediate_size
Wg_out, Wg_in = 2 * self.intermediate_dim, self.hidden_dim
Wd_out, Wd_in = self.hidden_dim, self.intermediate_dim
self.gate_up_proj = nn.Parameter(
torch.zeros(self.num_experts, Wg_out, Wg_in, dtype=FP8Expert.dtype, device=device)
)
self.down_proj = nn.Parameter(
torch.zeros(self.num_experts, Wd_out, Wd_in, dtype=FP8Expert.dtype, device=device)
)
# Create inverse scale tiles only when using 1-byte types (fp8)
if self.gate_up_proj.element_size() == 1:
bo, bi = self.block_size
# gate_up tiles: ceil(Wg_out/bo) x ceil(Wg_in/bi)
gu_scale_o = _ceil_div(Wg_out, bo)
gu_scale_i = _ceil_div(Wg_in, bi)
self.gate_up_proj_scales_inv = nn.Parameter(
torch.zeros(self.num_experts, gu_scale_o, gu_scale_i, dtype=torch.float32, device=device)
)
# down tiles: ceil(Wd_out/bo) x ceil(Wd_in/bi)
dp_scale_o = _ceil_div(Wd_out, bo)
dp_scale_i = _ceil_div(Wd_in, bi)
self.down_proj_scales_inv = nn.Parameter(
torch.zeros(self.num_experts, dp_scale_o, dp_scale_i, dtype=torch.float32, device=device)
)
else:
# Match FP8Linear behavior when not using 1-byte weights
self.register_parameter("gate_up_proj_scale_inv", None)
self.register_parameter("down_proj_scale_inv", None)
# (Optional) bias per projection: many MoEs omit bias; keep None to match the FP8Linear default
self.register_parameter("gate_up_bias", None)
self.register_parameter("down_bias", None)
# Activation used in the MLP (same as config / ACT2FN)
# Keep a handle here; the actual usage happens in the forward of the MoE block
self.act_fn = ACT2FN[config.hidden_act]
def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
continue
_, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states.index_select(0, token_idx)
gate, up = self.linear(
current_state, self.gate_up_proj[expert_idx], self.gate_up_proj_scales_inv[expert_idx]
).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = self.linear(
current_hidden_states, self.down_proj[expert_idx], self.down_proj_scales_inv[expert_idx]
)
routing_weights = top_k_weights[token_idx, expert_idx].unsqueeze(-1)
current_hidden_states = current_hidden_states * routing_weights.to(current_hidden_states.dtype)
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
return final_hidden_states
def linear(self, input: torch.Tensor, weight: torch.Tensor, weight_scale_inv: torch.Tensor) -> torch.Tensor:
if weight.element_size() > 1:
return F.linear(input, weight, None)
else:
# Context manager used to switch among the available accelerators
device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
torch_accelerator_module = getattr(torch, device_type, torch.cuda)
with torch_accelerator_module.device(input.device):
qinput, scale = act_quant(input, self.block_size[1])
output = w8a8_block_fp8_matmul_triton(
qinput,
weight,
scale,
weight_scale_inv,
self.block_size,
output_dtype=input.dtype,
)
# Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the
# preceding operations are ready before proceeding
torch_accelerator_module.synchronize()
return output.to(dtype=input.dtype)
# TODO: we do need this.... but not recursive...
def _replace_with_fp8_linear(
model,
tp_plan=None,
@ -361,40 +497,48 @@ def _replace_with_fp8_linear(
quantization_config=None,
has_been_replaced=False,
):
"""Replace Linear layers with FP8Linear."""
if current_key_name is None:
current_key_name = []
iterator = list(model.named_parameters()).copy()
for name, empty_tensor in iterator:
current_key_name = name
name = name.rsplit(".", 1)[0] if "." in name else name
module = model.get_submodule(name)
for name, module in model.named_children():
current_key_name.append(name)
if isinstance(module, nn.Linear) and name not in (modules_to_not_convert or []):
current_key_name_str = ".".join(current_key_name)
if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
with init_empty_weights():
model._modules[name] = FP8Linear(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
device=module.weight.device,
dtype=module.weight.dtype,
activation_scheme=quantization_config.activation_scheme,
block_size=quantization_config.weight_block_size,
current_key_name_str = re.sub(r"\d+", "*", current_key_name)
if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
with init_empty_weights():
if (
    "gate_up_proj" in current_key_name or "down_proj" in current_key_name
) and "experts" in current_key_name:  # Experts!
in_features = empty_tensor.size(-2)
out_features = empty_tensor.size(-1)
model.set_submodule(
name,
FP8Expert(
config=model.config,
block_size=quantization_config.weight_block_size,
device=empty_tensor.device,
),
)
has_been_replaced = True
# when changing a layer the TP PLAN for that layer should be updated. TODO
if len(list(module.children())) > 0:
_, has_been_replaced = _replace_with_fp8_linear(
module,
tp_plan,
modules_to_not_convert,
current_key_name,
quantization_config,
has_been_replaced=has_been_replaced,
)
current_key_name.pop(-1)
elif isinstance(module, nn.Linear):
in_features = module.in_features
out_features = module.out_features
model.set_submodule(
name,
FP8Linear(
in_features=in_features,
out_features=out_features,
bias=module.bias is not None,
device=module.weight.device,
dtype=module.weight.dtype,
activation_scheme=quantization_config.activation_scheme,
block_size=quantization_config.weight_block_size,
),
)
has_been_replaced = True
# when changing a layer the TP PLAN for that layer should be updated. TODO
return model, has_been_replaced
@ -405,7 +549,7 @@ def replace_with_fp8_linear(
quantization_config=None,
):
"""Helper function to replace model layers with FP8 versions."""
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
modules_to_not_convert += ["lm_head"]
if quantization_config.modules_to_not_convert is not None:
modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
@ -424,3 +568,124 @@ def replace_with_fp8_linear(
)
return model
class Fp8Quantize(ConversionOps):
"""
A quantization operation that creates two tensors, the quantized weight and its inverse scale, out of a single weight.
"""
reverse_op: type[ConversionOps]
def __init__(self, hf_quantizer):
self.hf_quantizer = hf_quantizer
self.reverse_op = Fp8Dequantize
def convert(self, input_dict: torch.Tensor, **kwargs) -> dict[str, torch.Tensor]:
# Unpack single key/value (value may be wrapped in a list)
target_keys, value = tuple(input_dict.items())[0]
value = value[0] if isinstance(value, list) else value
# Resolve block size (support dict-like or attr-like quant_config)
block_size = None
if self.hf_quantizer.quantization_config is not None:
if isinstance(self.hf_quantizer.quantization_config, dict):
block_size = self.hf_quantizer.quantization_config.get("weight_block_size")
else:
block_size = getattr(self.hf_quantizer.quantization_config, "weight_block_size", None)
if block_size is None:
block_size = (value.shape[-2], value.shape[-1])
block_m, block_n = block_size
rows, cols = value.shape[-2], value.shape[-1]
# Enforce exact tiling: dimensions must be an exact multiple of the block sizes
if rows % block_m != 0 or cols % block_n != 0:
raise ValueError(
    f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n}) for {target_keys}."
)
# Leading dims can be empty (2D) or include num_experts/... (3D+)
leading_shape = value.shape[:-2]
rows_tiles = rows // block_m
cols_tiles = cols // block_n
original_shape = value.shape
value_fp32 = value.to(torch.float32)
# Reshape to (..., rows_tiles, block_m, cols_tiles, block_n)
reshaped = value_fp32.reshape(*leading_shape, rows_tiles, block_m, cols_tiles, block_n)
# Per-tile max-abs over the block dims
# dims: block_m is at -3, block_n is at -1 after the reshape
max_abs = reshaped.abs().amax(dim=(-3, -1))
safe_max_abs = torch.where(max_abs > 0, max_abs, torch.ones_like(max_abs))
# Tile scale (we store the inverse scale, matching FP8Linear's `weight_scale_inv`)
scales = _FP8_MAX / safe_max_abs
scales = torch.where(max_abs > 0, scales, torch.ones_like(scales)) # keep zeros stable
# Broadcast scales back over the block dims and quantize
# max_abs/scales shape: (..., rows_tiles, cols_tiles)
scales_broadcast = scales.unsqueeze(-1).unsqueeze(-3) # -> (..., rows_tiles, 1, cols_tiles, 1)
scaled = reshaped * scales_broadcast
if _FP8_IS_INT:
quantized = torch.clamp(scaled.round(), min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
else:
quantized = torch.clamp(scaled, min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE)
quantized = quantized.reshape(original_shape)
inv_scales = (1.0 / scales).to(torch.float32) # shape: (*leading, rows_tiles, cols_tiles)
if target_keys.endswith("weight"):
scale_key = target_keys.rsplit(".", 1)[0] + ".weight_scale_inv"
else:
scale_key = target_keys + "_scales_inv"
# Return both quantized weights and per-tile inverse scales (keeps leading dims, e.g., num_experts)
return {
target_keys: quantized,
scale_key: inv_scales,
}
class Fp8Dequantize(ConversionOps):
"""Inverse operation of :class:`Fp8Quantize`. Takes a pair (weight, scale) and reconstructs the fp32 tensor."""
def __init__(self, block_size: Optional[tuple[int, int]] = None):
self.block_size = block_size
self.reverse_op = Fp8Quantize
def convert(
self,
value: Union[Sequence[torch.Tensor], dict[str, torch.Tensor]],
*,
context: dict[str, Any],
) -> torch.Tensor:
if isinstance(value, dict):
tensors = list(value.values())
else:
tensors = list(value) if isinstance(value, Sequence) else [value]
if len(tensors) != 2:
raise ValueError("Fp8Dequantize expects exactly two tensors: quantized weights and scales.")
quantized, scales = tensors
if not isinstance(quantized, torch.Tensor) or not isinstance(scales, torch.Tensor):
raise TypeError("Fp8Dequantize expects tensors as inputs.")
quantized_fp32 = quantized.to(torch.float32)
rows, cols = quantized_fp32.shape[-2:]
block_size = self.block_size
if block_size is None:
quant_config = context.get("quantization_config")
block_size = getattr(quant_config, "weight_block_size", None)
if block_size is None:
block_size = (rows, cols)
block_m, block_n = block_size
if rows % block_m != 0 or cols % block_n != 0:
raise ValueError(
f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_m}, {block_n})."
)
reshaped = quantized_fp32.reshape(-1, rows // block_m, block_m, cols // block_n, block_n)
expanded_scales = scales.to(torch.float32).reshape(-1, rows // block_m, cols // block_n)
expanded_scales = expanded_scales.unsqueeze(-1).unsqueeze(2)
dequantized = reshaped * expanded_scales
return dequantized.reshape(quantized_fp32.shape)
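For reference, here is a minimal, self-contained sketch of the block-wise scaling math that `Fp8Quantize`/`Fp8Dequantize` implement above, assuming a 2D weight, an exact tiling, and torch>=2.1 for the `float8_e4m3fn` dtype. The toy block size and helper names are made up for the demo and are not part of the PR.
# Illustration only: block-wise FP8 quantize/dequantize round-trip on a toy 2D weight.
import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max  # 448.0 for e4m3

def quantize_blockwise(weight, block=(4, 4)):
    bm, bn = block
    rows, cols = weight.shape
    assert rows % bm == 0 and cols % bn == 0, "weight must tile exactly into blocks"
    tiles = weight.float().reshape(rows // bm, bm, cols // bn, bn)
    max_abs = tiles.abs().amax(dim=(-3, -1)).clamp(min=1e-12)        # per-tile max |w|
    scale = FP8_MAX / max_abs                                        # scale each tile into the FP8 range
    quantized = (tiles * scale.unsqueeze(-1).unsqueeze(-3)).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    return quantized.reshape(rows, cols), (1.0 / scale)              # store the inverse scale, as above

def dequantize_blockwise(quantized, scale_inv, block=(4, 4)):
    bm, bn = block
    rows, cols = quantized.shape
    tiles = quantized.float().reshape(rows // bm, bm, cols // bn, bn)
    return (tiles * scale_inv.unsqueeze(-1).unsqueeze(-3)).reshape(rows, cols)

w = torch.randn(8, 8)
q, s_inv = quantize_blockwise(w)
print("max reconstruction error:", (w - dequantize_blockwise(q, s_inv)).abs().max().item())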

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_accelerate_available, is_torch_available, is_torch_xpu_available, logging
from ..utils import is_accelerate_available, is_torch_available, logging
if is_torch_available():
@ -114,9 +114,6 @@ def convert_moe_packed_tensors(
if not blocks.is_cuda and torch.cuda.is_available():
blocks = blocks.cuda()
scales = scales.cuda()
elif (blocks.device.type != "xpu") and is_torch_xpu_available():
blocks = blocks.to("xpu")
scales = scales.to("xpu")
scales = scales.to(torch.int32) - 127 # TODO that's because 128=2**7
@ -354,8 +351,6 @@ def dequantize(module, param_name, param_value, target_device, dq_param_name, **
dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr))
if target_device == "cpu" and torch.cuda.is_available():
torch.cuda.empty_cache()
elif target_device == "cpu" and is_torch_xpu_available():
torch.xpu.empty_cache()
setattr(module, proj, torch.nn.Parameter(dequantized.to(target_device)))
delattr(module, blocks_attr)
delattr(module, scales_attr)
@ -400,7 +395,7 @@ def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, trito
else:
blocks = blocks.reshape(local_experts, -1, module.intermediate_size // 2)
if getattr(target_device, "type", target_device) == "cpu":
target_device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
target_device = "cuda"
blocks = blocks.to(target_device).contiguous()
scales = scales.to(target_device).contiguous()
with on_device(target_device):

View File

@ -236,7 +236,7 @@ class PeftAdapterMixin:
**adapter_kwargs,
)
peft_config.inference_mode = not is_trainable
# TODO: WE NEED TOO APPLY OUR DYNAMIC WEIGHT CONVERSION AT SOME POINT HERE!
# Create and add fresh new adapters into the model.
inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)

View File

@ -18,6 +18,7 @@ import operator
import os
import re
from functools import partial, reduce
from typing import Optional
import torch
import torch.distributed as dist
@ -306,7 +307,7 @@ def repack_weights(
return final_ordered_tensor
def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
def get_tensor_shard(param, empty_param, device_mesh, rank, dim, tensor_idx: Optional[int] = None):
"""
Generalized tensor sharding across a multi-dimensional device mesh.
Extract only the fraction of the parameter owned by the given `rank` when the parameter is sharded along the provided `dim`.
@ -358,32 +359,57 @@ def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
rank (int): Global rank of the current process/device.
dim (int): Dimension along which to shard the tensor.
"""
param_dim = empty_param.dim()
if dim < 0:
dim = param_dim + dim
if dim >= param_dim:
raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
param_dim = empty_param.ndim
# Flatten the mesh to get the total number of devices
mesh_shape = device_mesh.shape
world_size = reduce(operator.mul, mesh_shape)
if dim < 0:
dim = param_dim + dim
if empty_param.dim() == 3 and dim == 1 and len(param.get_shape()) == 2:
dim = 0
elif empty_param.dim() == 3 and dim == 2 and len(param.get_shape()) == 2:
dim = 0
shard_size = math.ceil(empty_param.size(dim) / world_size)
start = rank * shard_size
end = min(start + shard_size, empty_param.size(dim))
if dim >= param_dim:
raise ValueError(f"dim {dim} is out of bounds for tensor of dimension {param_dim}")
if rank >= world_size:
raise ValueError(f"Rank {rank} is out of bounds for mesh size {world_size}")
shard_size = math.ceil(empty_param.shape[dim] / world_size)
start = rank * shard_size
# We have the full tensor here, not a shard of it.
# In that case we assume the weight was saved unpacked: if the layer is colwise under TP, it should not go
# through this path but be declared packed_colwise instead, which signals that it reads from a packed tensor
# and also handles the ModuleList case. Here we take care of potential chunking / layer splitting.
# The only "hard" case is collecting q, k, v and merging them into qkv; even then we still shard along dim=0,
# so nothing changes. The remaining special case is an empty param of dim 3 with shard dim 0: we then place
# the whole tensor on a single rank, selected via the input tensor_idx.
dimensions = param.get_shape()
# Construct slicing index dynamically
end = min(start + shard_size, empty_param.shape[dim])
slice_indices = [slice(None)] * param_dim
if start < empty_param.shape[dim]:
if empty_param.dim() == 3 and dim == 0 and len(param.get_shape()) == 2:
# special case: we don't "shard" here, we just send this entire tensor to the correct rank.
if start <= tensor_idx < end:
# this tensor does need to be materialized on this device:
return param[:]
else:
return torch.empty([], dtype=torch.int64, device=rank)
slice_indices = [slice(None)] * len(param.get_shape())
if start < param.get_shape()[dim]:
slice_indices[dim] = slice(start, end)
return param[tuple(slice_indices)]
dimensions = list(param.shape)
param = param[tuple(slice_indices)]
if isinstance(param, list): # TODO handle the modulelist case!
param = [p[:] for p in param]
return param
dimensions[dim] = 0
return torch.empty(tuple(dimensions), dtype=torch.int64)
return torch.empty(tuple(dimensions), dtype=torch.int64)  # note: torch.empty still allocates memory
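As a rough illustration of the slice arithmetic above (plain 2D case only, ignoring the device mesh and the packed/3D special paths), the per-rank shard boils down to the following toy helper, which is not part of the PR:
# Toy version of the per-rank slice computed by get_tensor_shard (2D tensors, flat world size).
import math
import torch

def toy_shard(param, rank, world_size, dim):
    shard_size = math.ceil(param.shape[dim] / world_size)
    start = rank * shard_size
    end = min(start + shard_size, param.shape[dim])
    if start >= param.shape[dim]:                     # this rank owns nothing along `dim`
        empty_shape = list(param.shape)
        empty_shape[dim] = 0
        return param.new_empty(empty_shape)
    index = [slice(None)] * param.ndim
    index[dim] = slice(start, end)
    return param[tuple(index)]

weight = torch.arange(10 * 4).reshape(10, 4)
print([tuple(toy_shard(weight, r, world_size=4, dim=0).shape) for r in range(4)])
# [(3, 4), (3, 4), (3, 4), (1, 4)] -> the last rank gets the remainder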
def distribute_module(
@ -410,6 +436,19 @@ class TensorParallelLayer:
"""
use_dtensor = True
device_mesh = None
rank = None
# Used to compare the shape of the original tensor
empty_param = None
# Used to init the corresponding DTensor
shard = None
def __init__(self, device_mesh=None, rank=None, empty_param=None):
self.rank = rank
self.device_mesh = device_mesh
self.empty_param = empty_param
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh): ...
@ -439,12 +478,12 @@ class GatherParallel(TensorParallelLayer):
def __init__(
self,
*,
input_layouts: Placement | None = None,
output_layouts: Placement | None = None,
use_local_output: bool = True,
**kwargs,
):
super().__init__()
super().__init__(**kwargs)
self.input_layouts = (input_layouts or Replicate(),)
self.output_layouts = output_layouts
self.desired_input_layouts = (Replicate(),)
@ -465,6 +504,21 @@ class GatherParallel(TensorParallelLayer):
dist.all_reduce(outputs[0], op=dist.ReduceOp.SUM, async_op=False)
return outputs
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
shard = [Replicate()]
parameter = param[...].to(param_casting_dtype)
self.shard = shard
return parameter, shard
def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
distribute_module(
module,
@ -493,6 +547,23 @@ class IsolatedParallel(TensorParallelLayer):
# TODO: figure out dynamo support for instance method and switch this to instance method
return outputs
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
mesh = device_mesh or self.device_mesh
parameter = param[...].to(param_casting_dtype)
if mesh is not None:
parameter = parameter / mesh.size()
self.shard = None
return parameter, None
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
param = param[...].to(param_casting_dtype)
if to_contiguous:
@ -515,8 +586,8 @@ class ReplicateParallel(TensorParallelLayer):
This class is used to replicate computation in a TP layer (used in SP regions when we don't use sequence parallelism for example)
"""
def __init__(self, *, use_dtensor=True, use_local_output=True):
super().__init__()
def __init__(self, use_dtensor=True, use_local_output=True, **kwargs):
super().__init__(**kwargs)
self.input_layouts = (Replicate(),)
self.output_layouts = (Replicate(),)
self.desired_input_layouts = (Replicate(),)
@ -537,12 +608,33 @@ class ReplicateParallel(TensorParallelLayer):
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
return outputs.to_local() if use_local_output and isinstance(outputs, DTensor) else outputs
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
parameter = param[...].to(param_casting_dtype)
shard = [Replicate()]
self.shard = shard
return parameter, shard
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
param = param[...].to(param_casting_dtype)
if to_contiguous:
param = param.contiguous()
param = DTensor.from_local(param, device_mesh, [Replicate()], run_check=False)
return param
parameter, shard = self.shard_tensor(
param,
param_type=param_type,
param_casting_dtype=param_casting_dtype,
to_contiguous=to_contiguous,
rank=rank,
device_mesh=device_mesh,
)
if self.use_dtensor:
parameter = DTensor.from_local(parameter, device_mesh, shard, run_check=False)
return parameter
class ColwiseParallel(TensorParallelLayer):
@ -552,13 +644,13 @@ class ColwiseParallel(TensorParallelLayer):
def __init__(
self,
*,
input_layouts: Placement | None = None,
output_layouts: Placement | None = None,
use_local_output: bool = True,
use_dtensor=True,
**kwargs,
):
super().__init__()
super().__init__(**kwargs)
self.input_layouts = (input_layouts or Replicate(),)
self.output_layouts = (output_layouts or Shard(-1),)
self.desired_input_layouts = (Replicate(),)
@ -578,18 +670,34 @@ class ColwiseParallel(TensorParallelLayer):
input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
return input_tensor
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
device_mesh = self.device_mesh
empty_param = self.empty_param
rank = self.rank
if param_type == "bias":
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx)
shard = [Shard(-1)]
else:
shard = [Shard(-2)]
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2, tensor_idx)
parameter = parameter.to(param_casting_dtype)
self.shard = shard
return parameter, shard
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# Colwise shards the weight along Shard(-2) (i.e. Shard(0) for a 2D weight) and the bias along Shard(-1).
# "Colwise" refers to the output: since Linear computes input @ weight^T + bias, sharding the weight along
# its first (output) dimension splits the output features column-wise.
if param_type == "bias":
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1)
shard = [Shard(-1)]
else:
shard = [Shard(-2)]
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -2)
parameter = parameter.to(param_casting_dtype)
parameter, shard = self.shard_tensor(param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh)
if to_contiguous:
parameter = parameter.contiguous()
if self.use_dtensor:
@ -608,6 +716,26 @@ class ColwiseParallel(TensorParallelLayer):
class PackedColwiseParallel(ColwiseParallel):
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
device_mesh = device_mesh or self.device_mesh
empty_param = self.empty_param
rank = rank if rank is not None else self.rank
return get_packed_weights(param, empty_param, device_mesh, rank, -2).to(param_casting_dtype), [Shard(-2)]
def create_nn_parameter(
self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh
):
return nn.Parameter(param, requires_grad=param.is_floating_point())
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
# means Colwise as Linear is input * weight^T + bias, where
@ -642,18 +770,41 @@ class RowwiseParallel(TensorParallelLayer):
def __init__(
self,
*,
input_layouts: Placement | None = None,
output_layouts: Placement | None = None,
use_local_output: bool = True,
use_dtensor=True,
**kwargs,
):
super().__init__()
super().__init__(**kwargs)
self.input_layouts = (input_layouts or Shard(-1),)
self.output_layouts = (output_layouts or Replicate(),)
self.use_local_output = use_local_output
self.use_dtensor = use_dtensor
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
device_mesh = device_mesh or self.device_mesh
empty_param = self.empty_param
rank = rank if rank is not None else self.rank
if param_type == "bias":
shard = [Replicate()]
parameter = param[...]
else:
parameter = get_tensor_shard(param, empty_param, device_mesh, rank, -1, tensor_idx=tensor_idx)
shard = [Shard(-1)]
parameter = parameter.to(param_casting_dtype)
self.shard = shard
return parameter, shard
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
# means Rowwise as nn.Linear is input * weight^T + bias, where
@ -725,6 +876,21 @@ class RowwiseParallel(TensorParallelLayer):
class PackedRowwiseParallel(RowwiseParallel):
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
device_mesh = device_mesh or self.device_mesh
empty_param = self.empty_param
rank = rank if rank is not None else self.rank
return get_packed_weights(param, empty_param, device_mesh, rank, -1), [Shard(-1)]
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# colwise shard weight/bias to Shard(0), weight be Shard(-2) (0 if you have 1 dim only)
# means Colwise as Linear is input * weight^T + bias, where
@ -783,8 +949,8 @@ class SequenceParallel(TensorParallelLayer):
to ensure that they are replicated.
"""
def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False):
super().__init__()
def __init__(self, sequence_dim: int = 1, use_local_output: bool = False, use_dtensor=False, **kwargs):
super().__init__(**kwargs)
self.input_layouts = (Replicate(),)
self.desired_input_layouts = (Shard(1),)
self.output_layouts = (Replicate(),)
@ -793,6 +959,21 @@ class SequenceParallel(TensorParallelLayer):
self.sequence_sharding = (Shard(sequence_dim),)
self.use_local_output = use_local_output
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
parameter = param[...].to(param_casting_dtype)
shard = [Replicate()]
self.shard = shard
return parameter, shard
@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
input_tensor = inputs[0]
@ -827,10 +1008,34 @@ class GroupedGemmParallel(TensorParallelLayer):
Applies Expert Parallelism to MoE experts by loading the correct experts on each device.
"""
def __init__(self):
super().__init__()
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.use_dtensor = False
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
empty_param = self.empty_param
ep_rank = self.rank
device_mesh = self.device_mesh
global_num_experts = empty_param.shape[0]
if global_num_experts % device_mesh.size() != 0:
raise ValueError(
f"Global number of experts must be divisible by number of devices: {global_num_experts} % {device_mesh.size()} != 0"
)
local_num_experts = global_num_experts // device_mesh.size()
parameter = param[ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts].to(param_casting_dtype)
self.shard = None
return parameter, None
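A small sketch of the expert slice taken by `shard_tensor` above, assuming the experts dimension is dim 0 and the expert count divides evenly across the EP group (the toy helper below is not part of the PR):
# Toy example: each expert-parallel rank keeps a contiguous slice of the experts dimension.
import torch

def local_experts(packed, ep_rank, ep_size):
    global_num_experts = packed.shape[0]
    assert global_num_experts % ep_size == 0, "experts must divide evenly across EP ranks"
    n_local = global_num_experts // ep_size
    return packed[ep_rank * n_local : (ep_rank + 1) * n_local]

experts = torch.randn(8, 16, 32)                                  # (num_experts, in, out) packed weight
print(local_experts(experts, ep_rank=1, ep_size=4).shape)         # torch.Size([2, 16, 32])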
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
ep_rank = rank
global_num_experts = empty_param.shape[0]
@ -851,8 +1056,8 @@ class RouterParallel(TensorParallelLayer):
"""
def __init__(self, *args, **kwargs):
super().__init__(**kwargs)
self.args = args
self.kwargs = kwargs
self.use_dtensor = False
@staticmethod
@ -917,6 +1122,20 @@ class RouterParallel(TensorParallelLayer):
) # masking class for one hot
return router_scores, router_indices
def shard_tensor(
self,
param,
param_type=None,
param_casting_dtype=None,
to_contiguous=None,
rank=None,
device_mesh=None,
tensor_idx=None,
):
parameter = param[...].to(param_casting_dtype)
self.shard = None
return parameter, None
def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
# TODO: i'd like for this to be the default
param = param[...].to(param_casting_dtype)
@ -1059,6 +1278,9 @@ def shard_and_distribute_module(
if current_shard_plan is not None:
try:
tp_layer = ALL_PARALLEL_STYLES[current_shard_plan]
tp_layer.empty_param = empty_param
tp_layer.device_mesh = device_mesh
tp_layer.rank = rank
param = tp_layer.partition_tensor(
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
)
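A minimal sketch of the attribute-injection pattern introduced above: the loader sets `empty_param`, `device_mesh` and `rank` on the TP layer before asking it to partition the tensor. The stand-in layer below uses a simplified signature (a plain `world_size` instead of a device mesh) purely for illustration and is not the real API.
# Stand-in for a TP layer, showing the injection + partition flow (no torch.distributed needed).
import math
import torch

class ToyColwise:
    device_mesh = None
    rank = None
    empty_param = None                                   # shape-only reference, like the real layers

    def partition_tensor(self, param, empty_param, param_type, dtype, to_contiguous, rank, world_size):
        shard = math.ceil(empty_param.shape[0] / world_size)
        out = param[rank * shard : (rank + 1) * shard].to(dtype)
        return out.contiguous() if to_contiguous else out

empty_param = torch.empty(8, 4, device="meta")           # only the shape matters
param = torch.randn(8, 4)

layer = ToyColwise()
layer.empty_param, layer.device_mesh, layer.rank = empty_param, None, 1
print(layer.partition_tensor(param, empty_param, "weight", torch.float32, True, rank=1, world_size=2).shape)
# torch.Size([4, 4])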

View File

@ -0,0 +1,127 @@
import importlib.metadata
import re
import types
from collections import defaultdict
from typing import Optional, Any
import torch
from packaging import version
from transformers.utils.import_utils import is_torchao_available
from transformers.utils import logging
from ..core_model_loading import ConversionOps
from ..quantizers.quantizers_utils import get_module_from_name
if is_torchao_available():
TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))
if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"):
from torchao.prototype.safetensors.safetensors_support import (
unflatten_tensor_state_dict,
)
from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao
logger = logging.get_logger(__name__)
class TorchAoQuantize(ConversionOps):
def __init__(self, hf_quantizer):
self.hf_quantizer = hf_quantizer
def convert(
self, input_dict: dict[str, torch.Tensor], model: Optional[torch.nn.Module] = None, missing_keys=None, **kwargs
) -> dict[str, torch.Tensor]:
target_key, value = tuple(input_dict.items())[0]
value = value[0] if isinstance(value, list) else value
full_name = target_key
# update param name to get the weights instead of the quantized stats
target_key = self.hf_quantizer.get_param_name(target_key)
module, _ = get_module_from_name(model, target_key)
from torchao.quantization import quantize_
# Those are the pre quantized weights
if ":" in target_key:
target_key = target_key.rsplit(":", 1)[0]
module, tensor_name = get_module_from_name(model, target_key)
if self.hf_quantizer.pre_quantized:
# If it's a bias, no need to do anything special (except removing the ":_data" part of the key, but that was
# already done) - if it's unsafe-serialized (i.e. not safetensors), no need for anything either
is_unsafe_serialization = ":" not in full_name
if tensor_name == "bias" or is_unsafe_serialization:
return {target_key: value}
# Sanity check for the new serialization format
elif not (TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(self.hf_quantizer.metadata)):
raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")
# Save the states for later quantization when they are all gathered
if not hasattr(self.hf_quantizer, "ao_params"):
self.hf_quantizer.ao_params = defaultdict(dict)
self.hf_quantizer.ao_params[target_key].update({full_name: value})
# We are ready for quantization in this case (we retrieved all the needed keys)
if len(self.hf_quantizer.ao_params[target_key]) == len(self.hf_quantizer.weight_ao_keys):
new_param = unflatten_tensor_state_dict(
self.hf_quantizer.ao_params[target_key], self.hf_quantizer.metadata
)[target_key]
del self.hf_quantizer.ao_params[target_key]
return {target_key: new_param}
# Add repr to the module
if isinstance(module, torch.nn.Linear):
module.extra_repr = types.MethodType(self.hf_quantizer._linear_extra_repr, module)
return {}
else:
module._parameters[tensor_name] = torch.nn.Parameter(value, requires_grad=value.requires_grad).to(
value.device
)
# if we are quantizing tied parameters, to avoid tying the quantized weights
# the correct order to do it is
# 1. load the weight to model
# 2. run tie_weights to populate the weights
# 3. quantize
mm: Any = model
input_embed = mm.get_input_embeddings() if hasattr(mm, "get_input_embeddings") else None
if self.hf_quantizer.quantization_config.untie_embedding_weights and id(module) == id(input_embed):
if hasattr(mm, "tie_weights"):
mm.tie_weights()
if hasattr(mm, "config") and hasattr(mm.config, "get_text_config"):
setattr(mm.config.get_text_config(decoder=True), "tie_word_embeddings", False)
# handle ModuleFqnToConfig, introduced in torchao 0.12.0+
if self.hf_quantizer.quantization_config._get_ao_version() >= version.Version("0.12.0"):
from torchao.quantization import ModuleFqnToConfig
config = self.hf_quantizer.quantization_config.get_apply_tensor_subclass()
if isinstance(config, ModuleFqnToConfig):
module_fqn, _ = target_key.rsplit(".", 1)
c = None
if module_fqn in config.module_fqn_to_config:
assert not module_fqn.startswith("re:"), (
"module fqn should not start with`re:`, which is used for specifying regex"
)
c = config.module_fqn_to_config[module_fqn]
else:
for maybe_module_fqn_pattern in config.module_fqn_to_config:
if not maybe_module_fqn_pattern.startswith("re:"):
continue
elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
# we'll apply the config for first fully matched pattern
c = config.module_fqn_to_config[maybe_module_fqn_pattern]
break
else:
c = config.module_fqn_to_config.get("_default", None)
if c is not None:
# filter_fn: not filtering out any modules
quantize_(module, c, filter_fn=lambda x, fqn: True)
module._is_hf_initialized = True
missing_keys.discard(target_key)
return {}
quantize_(module, self.hf_quantizer.quantization_config.get_apply_tensor_subclass())
module._is_hf_initialized = True
missing_keys.discard(target_key)
return {}
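For clarity, the FQN-to-config lookup order implemented above (exact module name first, then `re:`-prefixed patterns, then `_default`) can be sketched with a plain dict standing in for torchao's `ModuleFqnToConfig`; the table below is invented for the example.
# Toy version of the lookup: exact FQN first, then "re:"-prefixed patterns, then "_default".
import re

def resolve_config(module_fqn, fqn_to_config):
    if module_fqn in fqn_to_config:
        return fqn_to_config[module_fqn]
    for pattern in fqn_to_config:
        if pattern.startswith("re:") and re.fullmatch(pattern[3:], module_fqn):
            return fqn_to_config[pattern]                 # first fully-matched pattern wins
    return fqn_to_config.get("_default")

table = {
    "model.layers.0.mlp.up_proj": "int8-config",
    "re:model\\.layers\\.\\d+\\.self_attn\\..*": "fp8-config",
    "_default": "no-quant",
}
print(resolve_config("model.layers.3.self_attn.q_proj", table))   # fp8-config
print(resolve_config("model.layers.0.mlp.up_proj", table))        # int8-config
print(resolve_config("lm_head", table))                           # no-quant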

View File

@ -82,10 +82,8 @@ def causal_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int)
def bidirectional_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
"""
This creates a full bidirectional mask.
NOTE: It is important to keep an index-based version for non-vmap expansion.
"""
return q_idx >= 0
return q_idx.new_ones((), dtype=torch.bool)
def sliding_window_overlay(sliding_window: int) -> Callable:
@ -112,6 +110,18 @@ def chunked_overlay(chunk_size: int, left_padding: torch.Tensor) -> Callable:
return inner_mask
def _legacy_chunked_overlay(chunk_size: int) -> Callable:
"""
Same as the above function, but does not correctly account for left-padding tokens.
Only kept for compatibility with older torch versions (< 2.6).
"""
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
return kv_idx // chunk_size == q_idx // chunk_size
return inner_mask
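The chunked predicate above combines with the causal one, so a position only attends to earlier positions inside its own chunk. A tiny standalone illustration with toy sizes, using broadcasting instead of vmap:
# Toy chunked-causal mask: same chunk AND causal.
import torch

chunk_size = 3
q_idx = torch.arange(6).unsqueeze(-1)        # (q, 1)
kv_idx = torch.arange(6).unsqueeze(0)        # (1, kv)

same_chunk = (kv_idx // chunk_size) == (q_idx // chunk_size)
causal = kv_idx <= q_idx
print((same_chunk & causal).int())           # block-diagonal lower-triangular pattern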
def sliding_window_causal_mask_function(sliding_window: int) -> Callable:
"""
This returns the mask function used to create a sliding window mask.
@ -123,6 +133,8 @@ def chunked_causal_mask_function(chunk_size: int, left_padding: torch.Tensor) ->
"""
This returns the mask function used to create a chunked attention mask.
"""
if not _is_torch_greater_or_equal_than_2_6:
return and_masks(_legacy_chunked_overlay(chunk_size), causal_mask_function)
return and_masks(chunked_overlay(chunk_size, left_padding), causal_mask_function)
@ -163,17 +175,52 @@ def add_offsets_to_mask_function(mask_function: Callable, q_offset: int, kv_offs
return inner_mask
def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
"""
Used to vmap our mask_functions over the q_idx and kv_idx dimensions of the inputs. Optionally, vmap over
the batch and head indices as well if `bh_indices=True`.
Using vmap here allows us to keep the performance of vectorized ops, while having a single set of primitive
functions between attention interfaces (i.e. between flex and sdpa/eager, FA2 being a bit different).
Args:
mask_function (`Callable`):
The mask_function to vmap.
bh_indices (`bool`, optional):
Whether to vmap over the batch and head indices as well, or only q and kv indices.
Returns:
Callable: The vmapped function.
"""
# We vmap the function 2 times, broadcasting the [q_idx, kv_idx] dimensions
dimensions = [(None, None, None, 0), (None, None, 0, None)]
if bh_indices:
# We extend broadcasting over the [batch_idx, head_idx] dimensions
dimensions.extend([(None, 0, None, None), (0, None, None, None)])
for dims in dimensions:
mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
return mask_function
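A quick demo of what the nested `torch.vmap` calls above do: a scalar predicate over `(batch_idx, head_idx, q_idx, kv_idx)` is expanded into a dense `(B, H, Q, KV)` boolean mask. The causal predicate below is a toy example, not taken from the PR.
# Minimal nested-vmap expansion, mirroring _vmap_for_bhqkv with bh_indices=True.
import torch

def causal(batch_idx, head_idx, q_idx, kv_idx):
    return kv_idx <= q_idx

expanded = causal
for dims in [(None, None, None, 0), (None, None, 0, None), (None, 0, None, None), (0, None, None, None)]:
    expanded = torch.vmap(expanded, in_dims=dims, out_dims=0)

b, h = torch.arange(2), torch.arange(1)
q, kv = torch.arange(4), torch.arange(5)
print(expanded(b, h, q, kv).shape)           # torch.Size([2, 1, 4, 5])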
def prepare_padding_mask(
attention_mask: Optional[torch.Tensor], kv_length: int, kv_offset: int
attention_mask: Optional[torch.Tensor], kv_length: int, kv_offset: int, _slice: bool = True
) -> Optional[torch.Tensor]:
"""
From the 2D attention mask, prepare the correct padding mask to use by potentially padding it.
From the 2D attention mask, prepare the correct padding mask to use by potentially padding it, and slicing
according to the `kv_offset` if `_slice` is `True`.
"""
local_padding_mask = attention_mask
if attention_mask is not None:
# Pad it if necessary
if (padding_length := kv_length + kv_offset - attention_mask.shape[-1]) > 0:
local_padding_mask = torch.nn.functional.pad(attention_mask, (0, padding_length))
# For flex, we should not slice them, only use an offset
if _slice:
# Equivalent to: `local_padding_mask = attention_mask[:, kv_offset : kv_offset + kv_length]`,
# but without data-dependent slicing (i.e. torch.compile friendly)
mask_indices = torch.arange(kv_length, device=local_padding_mask.device)
mask_indices += kv_offset
local_padding_mask = local_padding_mask[:, mask_indices]
return local_padding_mask
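The index-based slicing above is equivalent to a direct slice but avoids data-dependent slicing, which keeps the graph compile-friendly when `kv_offset` is traced. A quick equivalence check with toy values:
# Gathering with an index tensor gives the same result as slicing, without a data-dependent slice.
import torch

attention_mask = torch.rand(2, 10) > 0.3     # toy 2D padding mask
kv_offset, kv_length = 3, 4

sliced = attention_mask[:, kv_offset : kv_offset + kv_length]
indices = torch.arange(kv_length) + kv_offset
gathered = attention_mask[:, indices]
print(torch.equal(sliced, gathered))         # True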
@ -235,39 +282,7 @@ def _ignore_bidirectional_mask_sdpa(padding_mask: Optional[torch.Tensor]) -> boo
return False
def _vmap_expansion_sdpa(mask_function: Callable) -> Callable:
"""
Used to vmap our mask_functions over all 4 dimensions (b_idx, h_idx, q_idx, kv_idx) of the inputs.
Using vmap here allows us to keep the performance of vectorized ops, while having a single set of primitive
functions between attention interfaces (i.e. between flex and sdpa/eager, FA2 being a bit different).
"""
# We vmap the function over all 4 dimensions, broadcasting [b_idx, h_idx, q_idx, kv_idx]
dimensions = [(None, None, None, 0), (None, None, 0, None), (None, 0, None, None), (0, None, None, None)]
for dims in dimensions:
mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
return mask_function
def _non_vmap_expansion_sdpa(
batch_indices: torch.Tensor, head_indices: torch.Tensor, q_indices: torch.Tensor, kv_indices: torch.Tensor
):
"""
Used to broadcast our mask_functions over all 4 dimensions (b_idx, h_idx, q_idx, kv_idx) of the inputs.
Allows the usage of any index-based mask function without relying on vmap.
NOTE: This is limited to index based functions only and is not guaranteed to work otherwise.
Reference:
- https://github.com/huggingface/optimum-onnx/blob/c123e8f4fab61b54a8e0e31ce74462bcacca576e/optimum/exporters/onnx/model_patcher.py#L362-L365
"""
batch_indices = batch_indices[:, None, None, None]
head_indices = head_indices[None, :, None, None]
q_indices = q_indices[None, None, :, None]
kv_indices = kv_indices[None, None, None, :]
return batch_indices, head_indices, q_indices, kv_indices
def sdpa_mask(
def sdpa_mask_recent_torch(
batch_size: int,
cache_position: torch.Tensor,
kv_length: int,
@ -277,8 +292,6 @@ def sdpa_mask(
local_size: Optional[int] = None,
allow_is_causal_skip: bool = True,
allow_is_bidirectional_skip: bool = False,
allow_torch_fix: bool = True,
use_vmap: bool = False,
**kwargs,
) -> Optional[torch.Tensor]:
"""
@ -311,12 +324,6 @@ def sdpa_mask(
allow_is_bidirectional_skip (`bool`, optional):
Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
i.e. full attention without any padding. Default to `False`.
allow_torch_fix (`bool`, optional):
Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
versions. We need an arg to skip it when using eager. By default `True`.
use_vmap (`bool`, optional):
Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be
index-based (for the cost of speed performance). By default `False`.
## Creating a simple causal mask:
@ -384,8 +391,97 @@ def sdpa_mask(
"""
q_length = cache_position.shape[0]
# Potentially pad the 2D mask, and slice it correctly
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
# Potentially pad the 2D mask
# Under specific conditions, we can avoid materializing the mask
# 1. Causal masks can rely on the `is_causal` argument
# 2. Bidirectional do not need any further processing (no bias)
if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, kv_offset, local_size):
return None
if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask):
return None
# vmap can incur performance issues as reported in #41566 for bidirectional mask as we only need to expand the
# padding mask. Thus, we allow early exit here if we do not detect any modification to the base mask function
if mask_function is bidirectional_mask_function:
if padding_mask is not None:
# used for slicing without data-dependent slicing
mask_indices = torch.arange(kv_length, device=cache_position.device) + kv_offset
return padding_mask[:, None, None, mask_indices].expand(-1, -1, q_length, -1)
else:
return torch.ones(batch_size, 1, q_length, kv_length, dtype=torch.bool, device=cache_position.device)
# Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
# but without data-dependent slicing (i.e. torch.compile friendly)
kv_arange = torch.arange(kv_length, device=cache_position.device)
kv_arange += kv_offset
# Potentially add the padding 2D mask
if padding_mask is not None:
mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
batch_arange = torch.arange(batch_size, device=cache_position.device)
head_arange = torch.arange(1, device=cache_position.device)
# This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from
# a scalar tensor (it internally calls `.item()`, which vmap does not allow, but this context works around it).
# We don't need to add an offset to the mask_function either, as we vmap directly over the correct q and kv indices
with TransformGetItemToIndex():
causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange)
return causal_mask
def sdpa_mask_older_torch(
batch_size: int,
cache_position: torch.Tensor,
kv_length: int,
kv_offset: int = 0,
mask_function: Callable = causal_mask_function,
attention_mask: Optional[torch.Tensor] = None,
local_size: Optional[int] = None,
allow_is_causal_skip: bool = True,
allow_torch_fix: bool = True,
allow_is_bidirectional_skip: bool = False,
**kwargs,
) -> Optional[torch.Tensor]:
"""
NOTE: This function is only used when the torch version is torch<2.6 - see `sdpa_mask_recent_torch` otherwise.
Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
the element should take part in the attention computation, and False that it should not.
If `allow_torch_fix=True` (the default), rows corresponding to query tokens that do not attend
to any other tokens (due to padding) are switched to attend to all tokens instead, in order to avoid `nan`
propagation (this does not change the final result).
Args:
batch_size (`int`):
The batch size of the input sequence.
cache_position (`torch.Tensor`):
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
kv_length (`int`):
The size that the key and value states will have during the attention computation.
kv_offset (`int`, optional):
An optional offset to indicate at which first position the key and values states will refer to.
mask_function (`Callable`):
The mask factory function describing the mask pattern.
attention_mask (`torch.Tensor`, optional):
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
local_size (`int`, optional):
The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
to try to skip mask creation if possible.
allow_is_causal_skip (`bool`, optional):
Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in
`torch.sdpa` instead. Default to `True`.
allow_torch_fix (`bool`, optional):
Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
versions. We need an arg to skip it when using eager. By default `True`.
allow_is_bidirectional_skip (`bool`, optional):
Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
i.e. full attention without any padding. Default to `False`.
"""
q_length = cache_position.shape[0]
# Potentially pad the 2D mask, and slice it correctly
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
# Under specific conditions, we can avoid materializing the mask
@ -396,45 +492,38 @@ def sdpa_mask(
if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask):
return None
# Potentially add the padding 2D mask
if padding_mask is not None:
mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
# vmap can incur performance issues as reported in #41566 for bidirectional mask as we only need to expand the
# padding mask. Thus, we allow early exit here if we do not detect any modification to the base mask function
if mask_function is bidirectional_mask_function:
if padding_mask is not None:
return padding_mask[:, None, None, :].expand(-1, -1, q_length, -1)
else:
return torch.ones(batch_size, 1, q_length, kv_length, dtype=torch.bool, device=cache_position.device)
batch_arange = torch.arange(batch_size, device=cache_position.device)
head_arange = torch.arange(1, device=cache_position.device)
# Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
# but without data-dependent slicing (i.e. torch.compile friendly)
kv_arange = torch.arange(kv_length, device=cache_position.device) + kv_offset
kv_arange = torch.arange(kv_length, device=cache_position.device)
kv_arange += kv_offset
# Actual mask creation
# Option 1: Fast non-vmap mask creation (default)
if not use_vmap:
# Apply mask function element-wise through broadcasting
attention_mask = mask_function(*_non_vmap_expansion_sdpa(batch_arange, head_arange, cache_position, kv_arange))
# Expand the mask to match batch size and query length if they weren't used in the mask function
attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length)
# Option 2: Vmap mask creation (torch>=2.6 and custom patterns)
elif _is_torch_greater_or_equal_than_2_6:
# This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from
# a scalar tensor (it internally calls `.item()`, which vmap does not allow, but this context works around it).
# We don't need to add an offset to the mask_function either, as we vmap directly over the correct q and kv indices
with TransformGetItemToIndex():
attention_mask = _vmap_expansion_sdpa(mask_function)(batch_arange, head_arange, cache_position, kv_arange)
# Option 3: Error out since it indicates that the user did something custom, which they shouldn't have (torch<2.6)
else:
raise ValueError(
"The vmap functionality for mask creation is only supported from torch>=2.6. "
"Please update your torch version or use `use_vmap=False` with index-based masks."
)
# This creates the 4D mask easily. Note that we do not vmap over the batch_idx dimension here,
# as vmap cannot handle slicing a tensor from a scalar tensor (it internally calls `.item()`, which vmap does not allow).
# However, more recent versions of PyTorch introduced a trick to handle it - which is the reason we have
# `sdpa_mask_recent_torch`, as it allows a more general `mask_function`
causal_mask = _vmap_for_bhqkv(mask_function, bh_indices=False)(None, None, cache_position, kv_arange)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)
if padding_mask is not None:
causal_mask = causal_mask * padding_mask[:, None, None, :]
# Due to a bug in versions of torch<2.5, we need to update the mask in case a query is not attending to any
# tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213
if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
attention_mask = attention_mask | torch.all(~attention_mask, dim=-1, keepdim=True)
causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
return causal_mask
return attention_mask
# We use the version with newer torch whenever possible, as it is more general and can handle arbitrary mask functions
# (especially mask_function indexing a tensor, such as the padding mask function)
sdpa_mask = sdpa_mask_recent_torch if _is_torch_greater_or_equal_than_2_6 else sdpa_mask_older_torch
def eager_mask(
@ -445,7 +534,6 @@ def eager_mask(
mask_function: Callable = causal_mask_function,
attention_mask: Optional[torch.Tensor] = None,
dtype: torch.dtype = torch.float32,
use_vmap: bool = False,
**kwargs,
) -> torch.Tensor:
"""
@ -468,14 +556,10 @@ def eager_mask(
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
dtype (`torch.dtype`, optional):
The dtype to use for the mask. By default, `torch.float32`.
use_vmap (`bool`, optional):
Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be
index-based (for the cost of speed performance). By default `False`.
"""
# The masks for eager attention are simply the boolean masks from sdpa, cast to 0 and -inf
_ = kwargs.pop("allow_is_causal_skip", None)
_ = kwargs.pop("allow_is_bidirectional_skip", None)
_ = kwargs.pop("allow_torch_fix", None)
mask = sdpa_mask(
batch_size=batch_size,
cache_position=cache_position,
@ -486,7 +570,6 @@ def eager_mask(
allow_is_causal_skip=False,
allow_is_bidirectional_skip=False,
allow_torch_fix=False,
use_vmap=use_vmap,
**kwargs,
)
min_dtype = torch.finfo(dtype).min
@ -572,7 +655,7 @@ def flex_attention_mask(
if not _is_torch_greater_or_equal_than_2_6 and pad_len > 0:
attention_mask = torch.nn.functional.pad(attention_mask, value=0, pad=(0, pad_len))
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
# Add the offsets on top (because flex interface only allows length, not start and end indices)
@ -768,11 +851,6 @@ def create_causal_mask(
mask_factory_function = causal_mask_function
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
# Defaulting to using non-vmap based mask creations except when detecting
# users passing custom mask functions (as we cannot guarantee that they
# are properly index-based as required by our implementation).
use_vmap = False
# Do not allow skip if we are compiling (this is to match BC)
# TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
if _is_torch_xpu_available:
@ -789,16 +867,14 @@ def create_causal_mask(
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
allow_is_causal_skip = False
use_vmap = True
if and_mask_function is not None:
if not _is_torch_greater_or_equal_than_2_6:
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
allow_is_causal_skip = False
use_vmap = True
# If we detected packing format
if packed_sequence_mask is not None:
if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
allow_is_causal_skip = False
@ -813,7 +889,6 @@ def create_causal_mask(
allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa
dtype=dtype, # Additional kwarg for eager
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask
)
return causal_mask
@ -867,10 +942,6 @@ def create_bidirectional_mask(
# Allow skipping the mask creation except we have additional masking operators (and/or masks)
allow_is_bidirectional_skip = True
# Defaulting to using non-vmap based mask creations except when detecting
# users passing custom mask functions (as we cannot guarantee that they
# are properly index-based as required by our implementation).
use_vmap = False
# Allow slight deviations from the base mask
# Note that it is very important to apply this before any other deviations of the mask (such as packed sequence mask,
@ -880,13 +951,11 @@ def create_bidirectional_mask(
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
allow_is_bidirectional_skip = False
use_vmap = True
if and_mask_function is not None:
if not _is_torch_greater_or_equal_than_2_6:
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
allow_is_bidirectional_skip = False
use_vmap = True
# We now create the mask
attention_mask = mask_interface(
@ -901,7 +970,6 @@ def create_bidirectional_mask(
allow_is_bidirectional_skip=allow_is_bidirectional_skip,
dtype=dtype, # Additional kwarg for eager
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask
)
return attention_mask
@ -964,10 +1032,6 @@ def create_sliding_window_causal_mask(
mask_factory_function = sliding_window_causal_mask_function(sliding_window)
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
# Defaulting to using non-vmap based mask creations except when detecting
# users passing custom mask functions (as we cannot guarantee that they
# are properly index-based as required by our implementation).
use_vmap = False
# Do not allow skip if we are compiling (this is to match BC)
# TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
allow_is_causal_skip = not getattr(past_key_values, "is_compileable", False)
@ -980,16 +1044,14 @@ def create_sliding_window_causal_mask(
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
allow_is_causal_skip = False
use_vmap = True
if and_mask_function is not None:
if not _is_torch_greater_or_equal_than_2_6:
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
allow_is_causal_skip = False
use_vmap = True
# If we detected packing format
if packed_sequence_mask is not None:
if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
allow_is_causal_skip = False
@ -1005,7 +1067,6 @@ def create_sliding_window_causal_mask(
local_size=sliding_window, # Additional kwarg for sdpa
dtype=dtype, # Additional kwarg for eager
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask
)
return causal_mask
@ -1079,13 +1140,20 @@ def create_chunked_causal_mask(
left_padding_tokens = (attention_mask.cumsum(dim=-1) == torch.zeros_like(attention_mask)).sum(dim=-1)
else:
left_padding_tokens = torch.zeros(batch_size, device=cache_position.device, dtype=int)
# Raise a warning for older versions if the problematic left-padding situation arises
if (
not _is_torch_greater_or_equal_than_2_6
and kv_length + kv_offset > chunk_size
and (left_padding_tokens > 0).any()
):
logger.warning_once(
"Due to limitations of your current torch version, we cannot correctly account for the left-padding "
"when computing the chunked attention pattern. This will lead to a wrong attention mask for the padded "
"sequences. Behavior will be undefined. Please upgrade to `torch>=2.6` to solve this issue."
)
mask_factory_function = chunked_causal_mask_function(chunk_size, left_padding_tokens)
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
# Defaulting to using non-vmap based mask creations except when detecting
# users passing custom mask functions (as we cannot guarantee that they
# are properly index-based as required by our implementation).
use_vmap = False
# Do not allow skip if we are compiling (this is to match BC)
# TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
allow_is_causal_skip = not getattr(past_key_values, "is_compileable", False)
@ -1098,16 +1166,14 @@ def create_chunked_causal_mask(
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
mask_factory_function = or_masks(mask_factory_function, or_mask_function)
allow_is_causal_skip = False
use_vmap = True
if and_mask_function is not None:
if not _is_torch_greater_or_equal_than_2_6:
raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
mask_factory_function = and_masks(mask_factory_function, and_mask_function)
allow_is_causal_skip = False
use_vmap = True
# If we detected packing format
if packed_sequence_mask is not None:
if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
allow_is_causal_skip = False
@ -1123,7 +1189,6 @@ def create_chunked_causal_mask(
local_size=chunk_size, # Additional kwarg for sdpa
dtype=dtype, # Additional kwarg for eager
config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface
use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask
)
return causal_mask

File diff suppressed because it is too large

View File

@ -406,13 +406,14 @@ class Aimv2PreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_flex_attn = True
@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
if hasattr(module, "logit_scale"):
if isinstance(module.logit_scale, nn.Parameter):
module.logit_scale.data.fill_(math.log(1 / 0.07))
module.logit_scale.fill_(math.log(1 / 0.07))
elif isinstance(module, Aimv2AttentionPoolingHead):
module.cls_token.data.normal_(mean=0.0, std=self.config.initializer_range)
module.cls_token.normal_(mean=0.0, std=self.config.initializer_range)
@auto_docstring(

View File

@ -449,13 +449,14 @@ class Aimv2PreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_flex_attn = True
@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
if hasattr(module, "logit_scale"):
if isinstance(module.logit_scale, nn.Parameter):
module.logit_scale.data.fill_(math.log(1 / 0.07))
module.logit_scale.fill_(math.log(1 / 0.07))
elif isinstance(module, Aimv2AttentionPoolingHead):
module.cls_token.data.normal_(mean=0.0, std=self.config.initializer_range)
module.cls_token.normal_(mean=0.0, std=self.config.initializer_range)
@auto_docstring(

View File

@ -302,21 +302,22 @@ class AlbertPreTrainedModel(PreTrainedModel):
"attentions": AlbertAttention,
}
@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
module.weight[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
elif isinstance(module, AlbertMLMHead):
module.bias.data.zero_()
module.bias.zero_()
@dataclass
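The `_init_weights` diffs above consistently drop the `.data` indirection and rely on the new `@torch.no_grad()` decorator instead; a minimal sketch of that pattern on a throwaway module (`ToyHead` and the `std` value are invented for the example):
# In-place init on the parameters themselves, guarded by torch.no_grad().
import torch
from torch import nn

class ToyHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(4, 4)

@torch.no_grad()
def init_weights(module, std=0.02):
    if isinstance(module, nn.Linear):
        module.weight.normal_(mean=0.0, std=std)   # no .data needed: autograd is disabled here
        if module.bias is not None:
            module.bias.zero_()

model = ToyHead()
model.apply(init_weights)
print(model.dense.weight.std().item())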
@ -425,7 +426,10 @@ class AlbertModel(AlbertPreTrainedModel):
"""
)
class AlbertForPreTraining(AlbertPreTrainedModel):
_tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
_tied_weights_keys = {
"predictions.decoder.weight": "albert.embeddings.word_embeddings.weight",
"predictions.decoder.bias": "predictions.bias",
}
def __init__(self, config: AlbertConfig):
super().__init__(config)
@ -525,7 +529,6 @@ class AlbertMLMHead(nn.Module):
self.dense = nn.Linear(config.hidden_size, config.embedding_size)
self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
self.activation = ACT2FN[config.hidden_act]
self.decoder.bias = self.bias
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
@ -537,14 +540,6 @@ class AlbertMLMHead(nn.Module):
return prediction_scores
def _tie_weights(self) -> None:
# For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device.type == "meta":
self.decoder.bias = self.bias
else:
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias
class AlbertSOPHead(nn.Module):
def __init__(self, config: AlbertConfig):
@ -561,7 +556,10 @@ class AlbertSOPHead(nn.Module):
@auto_docstring
class AlbertForMaskedLM(AlbertPreTrainedModel):
_tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
_tied_weights_keys = {
"predictions.decoder.weight": "albert.embeddings.word_embeddings.weight",
"predictions.decoder.bias": "predictions.bias",
}
def __init__(self, config):
super().__init__(config)

View File

@ -823,24 +823,25 @@ class AlignPreTrainedModel(PreTrainedModel):
input_modalities = ["image", "text"]
supports_gradient_checkpointing = True
@torch.no_grad()
def _init_weights(self, module: nn.Module):
"""Initialize the weights"""
std = self.config.initializer_range
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
module.weight.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, AlignModel):
nn.init.xavier_uniform_(module.text_projection.weight)
module.text_projection.bias.data.zero_()
module.temperature.data.fill_(self.config.temperature_init_value)
module.text_projection.bias.zero_()
module.temperature.fill_(self.config.temperature_init_value)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
module.weight.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
module.weight[module.padding_idx].zero_()
if isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
@auto_docstring(

View File

@ -59,6 +59,9 @@ class AlignProcessor(ProcessorMixin):
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "EfficientNetImageProcessor"
tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
valid_processor_kwargs = AlignProcessorKwargs
def __init__(self, image_processor, tokenizer):

View File

@ -770,6 +770,7 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_module = []
@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights"""
factor = self.config.initializer_factor
@ -797,23 +798,21 @@ class AltCLIPPreTrainedModel(PreTrainedModel):
module.text_projection.weight,
std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
)
module.text_projection._is_hf_initialized = True
nn.init.normal_(
module.visual_projection.weight,
std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
)
module.visual_projection._is_hf_initialized = True
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
elif isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_factor)
module.weight.normal_(mean=0.0, std=self.config.initializer_factor)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_factor)
module.weight.normal_(mean=0.0, std=self.config.initializer_factor)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
module.weight[module.padding_idx].zero_()
class AltCLIPVisionTransformer(nn.Module):

View File

@ -35,6 +35,10 @@ class AltCLIPProcessor(ProcessorMixin):
The tokenizer is a required input.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = ("CLIPImageProcessor", "CLIPImageProcessorFast")
tokenizer_class = ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast")
@deprecate_kwarg(old_name="feature_extractor", version="5.0.0", new_name="image_processor")
def __init__(self, image_processor=None, tokenizer=None):
super().__init__(image_processor, tokenizer)

View File

@ -429,7 +429,7 @@ class ApertusModel(ApertusPreTrainedModel):
@auto_docstring
class ApertusForCausalLM(ApertusPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

View File

@ -434,7 +434,7 @@ class ArceeModel(ArceePreTrainedModel):
@auto_docstring(checkpoint="arcee-ai/AFM-4.5B")
class ArceeForCausalLM(ArceePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

View File

@ -585,10 +585,11 @@ class AriaTextPreTrainedModel(PreTrainedModel):
"attentions": AriaTextAttention,
}
@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
if isinstance(module, AriaGroupedExpertsGemm):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
@auto_docstring
@ -608,6 +609,7 @@ class AriaPreTrainedModel(PreTrainedModel):
"attentions": AriaTextAttention,
}
@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
if isinstance(module, AriaProjector):
@ -760,7 +762,7 @@ class AriaTextModel(AriaTextPreTrainedModel):
@auto_docstring
class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
@ -1053,7 +1055,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
def __init__(self, config: AriaConfig):
super().__init__(config)

View File

@ -906,6 +906,10 @@ class AriaProcessor(ProcessorMixin):
A dictionary indicating size conversions for images.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "AriaImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor=None,
@ -1187,10 +1191,11 @@ class AriaTextPreTrainedModel(PreTrainedModel):
"attentions": AriaTextAttention,
}
@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
if isinstance(module, AriaGroupedExpertsGemm):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
class AriaPreTrainedModel(LlamaPreTrainedModel):
@ -1199,6 +1204,7 @@ class AriaPreTrainedModel(LlamaPreTrainedModel):
_can_compile_fullgraph = False # MoE models don't work with torch.compile (dynamic slicing)
_supports_attention_backend = True
@torch.no_grad()
def _init_weights(self, module):
PreTrainedModel._init_weights(self, module)
if isinstance(module, AriaProjector):
@ -1216,7 +1222,7 @@ class AriaTextModel(LlamaModel):
class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM):
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
def __init__(self, config: AriaTextConfig):
super().__init__(config)
@ -1355,6 +1361,8 @@ class AriaModel(LlavaModel):
"""
)
class AriaForConditionalGeneration(LlavaForConditionalGeneration):
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
def get_image_features(
self,
pixel_values: torch.FloatTensor,

View File

@ -67,6 +67,10 @@ class AriaProcessor(ProcessorMixin):
A dictionary indicating size conversions for images.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "AriaImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor=None,

View File

@ -300,23 +300,26 @@ class ASTPreTrainedModel(PreTrainedModel):
"attentions": ASTSelfAttention,
}
@torch.no_grad()
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
# `trunc_normal_cpu` not implemented in `half` issues
module.weight.data = nn.init.trunc_normal_(
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
).to(module.weight.dtype)
module.weight.copy_(
nn.init.trunc_normal_(module.weight.to(torch.float32), mean=0.0, std=self.config.initializer_range).to(
module.weight.dtype
)
)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
elif isinstance(module, ASTEmbeddings):
module.cls_token.data.zero_()
module.position_embeddings.data.zero_()
module.distillation_token.data.zero_()
module.cls_token.zero_()
module.position_embeddings.zero_()
module.distillation_token.zero_()
@auto_docstring
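
The AST change keeps the fp32 upcast (per the comment above, `trunc_normal_` has no half-precision CPU kernel) but writes the result back with an in-place `copy_` instead of rebinding `.data`. A standalone sketch of the same dance, assuming an fp16 weight:

import torch
from torch import nn

with torch.no_grad():
    linear = nn.Linear(4, 4).to(torch.float16)
    # Initialise a float32 copy, then copy the values back into the existing
    # (possibly fp16) parameter without replacing the tensor object.
    init = nn.init.trunc_normal_(linear.weight.to(torch.float32), mean=0.0, std=0.02)
    linear.weight.copy_(init.to(linear.weight.dtype))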

View File

@ -223,7 +223,6 @@ CONFIG_MAPPING_NAMES = OrderedDict[str, str](
("layoutlm", "LayoutLMConfig"),
("layoutlmv2", "LayoutLMv2Config"),
("layoutlmv3", "LayoutLMv3Config"),
("layoutxlm", "LayoutLMv2Config"),
("led", "LEDConfig"),
("levit", "LevitConfig"),
("lfm2", "Lfm2Config"),

View File

@ -41,7 +41,6 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
("audio-spectrogram-transformer", "ASTFeatureExtractor"),
("clap", "ClapFeatureExtractor"),
("clvp", "ClvpFeatureExtractor"),
("csm", "EncodecFeatureExtractor"),
("dac", "DacFeatureExtractor"),
("data2vec-audio", "Wav2Vec2FeatureExtractor"),
("dia", "DiaFeatureExtractor"),
@ -50,20 +49,14 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
("granite_speech", "GraniteSpeechFeatureExtractor"),
("hubert", "Wav2Vec2FeatureExtractor"),
("kyutai_speech_to_text", "KyutaiSpeechToTextFeatureExtractor"),
("markuplm", "MarkupLMFeatureExtractor"),
("mctct", "MCTCTFeatureExtractor"),
("mimi", "EncodecFeatureExtractor"),
("moonshine", "Wav2Vec2FeatureExtractor"),
("moshi", "EncodecFeatureExtractor"),
("musicgen", "EncodecFeatureExtractor"),
("musicgen_melody", "MusicgenMelodyFeatureExtractor"),
("parakeet_ctc", "ParakeetFeatureExtractor"),
("parakeet_encoder", "ParakeetFeatureExtractor"),
("phi4_multimodal", "Phi4MultimodalFeatureExtractor"),
("pop2piano", "Pop2PianoFeatureExtractor"),
("qwen2_5_omni", "WhisperFeatureExtractor"),
("qwen2_audio", "WhisperFeatureExtractor"),
("qwen3_omni_moe", "WhisperFeatureExtractor"),
("seamless_m4t", "SeamlessM4TFeatureExtractor"),
("seamless_m4t_v2", "SeamlessM4TFeatureExtractor"),
("sew", "Wav2Vec2FeatureExtractor"),
@ -73,7 +66,6 @@ FEATURE_EXTRACTOR_MAPPING_NAMES = OrderedDict(
("unispeech", "Wav2Vec2FeatureExtractor"),
("unispeech-sat", "Wav2Vec2FeatureExtractor"),
("univnet", "UnivNetFeatureExtractor"),
("voxtral", "WhisperFeatureExtractor"),
("wav2vec2", "Wav2Vec2FeatureExtractor"),
("wav2vec2-bert", "Wav2Vec2FeatureExtractor"),
("wav2vec2-conformer", "Wav2Vec2FeatureExtractor"),

View File

@ -62,9 +62,7 @@ else:
("aimv2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("aimv2_vision_model", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("align", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
("altclip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("aria", ("AriaImageProcessor", None)),
("aya_vision", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")),
("beit", ("BeitImageProcessor", "BeitImageProcessorFast")),
("bit", ("BitImageProcessor", "BitImageProcessorFast")),
("blip", ("BlipImageProcessor", "BlipImageProcessorFast")),
@ -75,8 +73,6 @@ else:
("clip", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("clipseg", ("ViTImageProcessor", "ViTImageProcessorFast")),
("cohere2_vision", (None, "Cohere2VisionImageProcessorFast")),
("colpali", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
("colqwen2", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
("conditional_detr", ("ConditionalDetrImageProcessor", "ConditionalDetrImageProcessorFast")),
("convnext", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("convnextv2", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
@ -99,10 +95,8 @@ else:
("efficientformer", ("EfficientFormerImageProcessor", None)),
("efficientloftr", ("EfficientLoFTRImageProcessor", "EfficientLoFTRImageProcessorFast")),
("efficientnet", ("EfficientNetImageProcessor", "EfficientNetImageProcessorFast")),
("emu3", ("Emu3ImageProcessor", None)),
("eomt", ("EomtImageProcessor", "EomtImageProcessorFast")),
("flava", ("FlavaImageProcessor", "FlavaImageProcessorFast")),
("florence2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("focalnet", ("BitImageProcessor", "BitImageProcessorFast")),
("fuyu", ("FuyuImageProcessor", "FuyuImageProcessorFast")),
("gemma3", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
@ -120,13 +114,11 @@ else:
("ijepa", ("ViTImageProcessor", "ViTImageProcessorFast")),
("imagegpt", ("ImageGPTImageProcessor", "ImageGPTImageProcessorFast")),
("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")),
("internvl", ("GotOcr2ImageProcessor", "GotOcr2ImageProcessorFast")),
("janus", ("JanusImageProcessor", "JanusImageProcessorFast")),
("kosmos-2", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("kosmos-2.5", ("Kosmos2_5ImageProcessor", "Kosmos2_5ImageProcessorFast")),
("layoutlmv2", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessorFast")),
("layoutlmv3", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
("layoutxlm", ("LayoutLMv2ImageProcessor", "LayoutLMv2ImageProcessor")),
("levit", ("LevitImageProcessor", "LevitImageProcessorFast")),
("lfm2_vl", (None, "Lfm2VlImageProcessorFast")),
("lightglue", ("LightGlueImageProcessor", "LightGlueImageProcessorFast")),
@ -149,7 +141,6 @@ else:
("mobilevitv2", ("MobileViTImageProcessor", "MobileViTImageProcessorFast")),
("nat", ("ViTImageProcessor", "ViTImageProcessorFast")),
("nougat", ("NougatImageProcessor", "NougatImageProcessorFast")),
("omdet-turbo", ("DetrImageProcessor", "DetrImageProcessorFast")),
("oneformer", ("OneFormerImageProcessor", "OneFormerImageProcessorFast")),
("ovis2", ("Ovis2ImageProcessor", "Ovis2ImageProcessorFast")),
("owlv2", ("Owlv2ImageProcessor", "Owlv2ImageProcessorFast")),
@ -164,17 +155,14 @@ else:
("prompt_depth_anything", ("PromptDepthAnythingImageProcessor", "PromptDepthAnythingImageProcessorFast")),
("pvt", ("PvtImageProcessor", "PvtImageProcessorFast")),
("pvt_v2", ("PvtImageProcessor", "PvtImageProcessorFast")),
("qwen2_5_omni", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
("qwen2_5_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
("qwen2_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
("qwen3_omni_moe", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
("qwen3_vl", ("Qwen2VLImageProcessor", "Qwen2VLImageProcessorFast")),
("regnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("resnet", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("rt_detr", ("RTDetrImageProcessor", "RTDetrImageProcessorFast")),
("sam", ("SamImageProcessor", "SamImageProcessorFast")),
("sam2", (None, "Sam2ImageProcessorFast")),
("sam2_video", (None, "Sam2ImageProcessorFast")),
("sam_hq", ("SamImageProcessor", "SamImageProcessorFast")),
("segformer", ("SegformerImageProcessor", "SegformerImageProcessorFast")),
("seggpt", ("SegGptImageProcessor", None)),
@ -192,14 +180,12 @@ else:
("textnet", ("TextNetImageProcessor", "TextNetImageProcessorFast")),
("timesformer", ("VideoMAEImageProcessor", None)),
("timm_wrapper", ("TimmWrapperImageProcessor", None)),
("trocr", ("ViTImageProcessor", "ViTImageProcessorFast")),
("tvlt", ("TvltImageProcessor", None)),
("tvp", ("TvpImageProcessor", "TvpImageProcessorFast")),
("udop", ("LayoutLMv3ImageProcessor", "LayoutLMv3ImageProcessorFast")),
("upernet", ("SegformerImageProcessor", "SegformerImageProcessorFast")),
("van", ("ConvNextImageProcessor", "ConvNextImageProcessorFast")),
("video_llama_3", ("VideoLlama3ImageProcessor", "VideoLlama3ImageProcessorFast")),
("video_llava", ("VideoLlavaImageProcessor", None)),
("videomae", ("VideoMAEImageProcessor", None)),
("vilt", ("ViltImageProcessor", "ViltImageProcessorFast")),
("vipllava", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
@ -538,9 +524,10 @@ class AutoImageProcessor:
)
use_fast = False
if use_fast:
# Check if the fast image processor class exists
image_processor_class_fast = get_image_processor_class_from_name(image_processor_type)
if image_processor_class_fast is None:
for image_processors in IMAGE_PROCESSOR_MAPPING_NAMES.values():
if image_processor_type in image_processors:
break
else:
image_processor_type = image_processor_type[:-4]
use_fast = False
logger.warning_once(
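
The fallback logic in this hunk leans on Python's `for`/`else`: the `else` branch only runs when the loop finishes without a `break`, i.e. when no registered entry lists the requested fast class. A standalone sketch of that idiom with toy data:

requested = "FooImageProcessorFast"
registered = [("FooImageProcessor", None), ("BarImageProcessor", "BarImageProcessorFast")]
use_fast = True

for image_processors in registered:
    if requested in image_processors:
        break
else:
    # No entry listed the fast class: strip the "Fast" suffix and fall back.
    requested = requested[: -len("Fast")]
    use_fast = False

print(requested, use_fast)  # FooImageProcessor False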

View File

@ -107,7 +107,6 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("mllama", "MllamaProcessor"),
("mm-grounding-dino", "GroundingDinoProcessor"),
("moonshine", "Wav2Vec2Processor"),
("omdet-turbo", "OmDetTurboProcessor"),
("oneformer", "OneFormerProcessor"),
("ovis2", "Ovis2Processor"),
("owlv2", "Owlv2Processor"),

View File

@ -72,7 +72,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
),
),
("align", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("altclip", ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast" if is_tokenizers_available() else None)),
("arcee", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("aria", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("aya_vision", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
@ -157,7 +156,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("codegen", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)),
("cohere", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
("cohere2", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
("cohere2_vision", (None, "CohereTokenizerFast" if is_tokenizers_available() else None)),
("colpali", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("colqwen2", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
("convbert", ("ConvBertTokenizer", "ConvBertTokenizerFast" if is_tokenizers_available() else None)),
@ -226,7 +224,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
),
),
("distilbert", ("DistilBertTokenizer", "DistilBertTokenizerFast" if is_tokenizers_available() else None)),
("donut", ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast" if is_tokenizers_available() else None)),
(
"dpr",
(
@ -241,7 +238,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("ernie4_5_moe", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("ernie_m", ("ErnieMTokenizer" if is_sentencepiece_available() else None, None)),
("esm", ("EsmTokenizer", None)),
("evolla", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
(
"exaone4",
(
@ -256,13 +252,10 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("FastSpeech2ConformerTokenizer" if is_g2p_en_available() else None, None),
),
("flaubert", ("FlaubertTokenizer", None)),
("flava", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("flex_olmo", (None, "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("florence2", ("BartTokenizer", "BartTokenizerFast" if is_tokenizers_available() else None)),
("fnet", ("FNetTokenizer", "FNetTokenizerFast" if is_tokenizers_available() else None)),
("fsmt", ("FSMTTokenizer", None)),
("funnel", ("FunnelTokenizer", "FunnelTokenizerFast" if is_tokenizers_available() else None)),
("fuyu", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
(
"gemma",
(
@ -311,7 +304,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("glm4_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("glm4v", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("glm4v_moe", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("got_ocr2", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("gpt-sw3", ("GPTSw3Tokenizer" if is_sentencepiece_available() else None, None)),
("gpt2", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("gpt_bigcode", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
@ -322,7 +314,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("gptj", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("gptsan-japanese", ("GPTSanJapaneseTokenizer", None)),
("granite", ("GPT2Tokenizer", None)),
("granite_speech", ("GPT2Tokenizer", None)),
("granitemoe", ("GPT2Tokenizer", None)),
("granitemoehybrid", ("GPT2Tokenizer", None)),
("granitemoeshared", ("GPT2Tokenizer", None)),
@ -362,14 +353,11 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
),
),
("kosmos-2.5", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("kyutai_speech_to_text", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("layoutlm", ("LayoutLMTokenizer", "LayoutLMTokenizerFast" if is_tokenizers_available() else None)),
("layoutlmv2", ("LayoutLMv2Tokenizer", "LayoutLMv2TokenizerFast" if is_tokenizers_available() else None)),
("layoutlmv3", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
("layoutxlm", ("LayoutXLMTokenizer", "LayoutXLMTokenizerFast" if is_tokenizers_available() else None)),
("led", ("LEDTokenizer", "LEDTokenizerFast" if is_tokenizers_available() else None)),
("lfm2", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("lfm2_vl", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("lilt", ("LayoutLMv3Tokenizer", "LayoutLMv3TokenizerFast" if is_tokenizers_available() else None)),
(
"llama",
@ -410,7 +398,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
("mamba2", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)),
("markuplm", ("MarkupLMTokenizer", "MarkupLMTokenizerFast" if is_tokenizers_available() else None)),
(
"mbart",
(
@ -497,7 +484,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
"NllbTokenizerFast" if is_tokenizers_available() else None,
),
),
("nougat", (None, "NougatTokenizerFast" if is_tokenizers_available() else None)),
(
"nystromformer",
(
@ -519,7 +505,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("OpenAIGPTTokenizer", "OpenAIGPTTokenizerFast" if is_tokenizers_available() else None),
),
("opt", ("GPT2Tokenizer", "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("ovis2", (None, "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
("owlv2", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
("owlvit", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
("paligemma", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
@ -545,7 +530,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
None,
),
),
("perception_lm", (None, "LlamaTokenizerFast" if is_tokenizers_available() else None)),
(
"persimmon",
(
@ -555,7 +539,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
),
("phi", ("CodeGenTokenizer", "CodeGenTokenizerFast" if is_tokenizers_available() else None)),
("phi3", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("phi4_multimodal", (None, "GPT2TokenizerFast" if is_tokenizers_available() else None)),
("phimoe", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("phobert", ("PhobertTokenizer", None)),
("pix2struct", ("T5Tokenizer", "T5TokenizerFast" if is_tokenizers_available() else None)),
@ -569,7 +552,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
),
),
("plbart", ("PLBartTokenizer" if is_sentencepiece_available() else None, None)),
("pop2piano", ("Pop2PianoTokenizer", None)),
("prophetnet", ("ProphetNetTokenizer", None)),
("qdqbert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
(
@ -676,7 +658,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
),
),
("smollm3", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("smolvlm", (None, "PreTrainedTokenizerFast" if is_tokenizers_available() else None)),
("speech_to_text", ("Speech2TextTokenizer" if is_sentencepiece_available() else None, None)),
("speech_to_text_2", ("Speech2Text2Tokenizer", None)),
("speecht5", ("SpeechT5Tokenizer" if is_sentencepiece_available() else None, None)),
@ -711,7 +692,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("tapas", ("TapasTokenizer", None)),
("tapex", ("TapexTokenizer", None)),
("transfo-xl", ("TransfoXLTokenizer", None)),
("trocr", ("XLMRobertaTokenizer", "XLMRobertaTokenizerFast" if is_tokenizers_available() else None)),
("tvp", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
(
"udop",
@ -727,14 +707,9 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
"T5TokenizerFast" if is_tokenizers_available() else None,
),
),
("video_llama_3", ("Qwen2Tokenizer", "Qwen2TokenizerFast" if is_tokenizers_available() else None)),
("video_llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("vilt", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("vipllava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
(
"vision_text_dual_encoder",
("PreTrainedTokenizer", "PreTrainedTokenizerFast" if is_tokenizers_available() else None),
),
("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("vits", ("VitsTokenizer", None)),
(
@ -750,7 +725,6 @@ TOKENIZER_MAPPING_NAMES = OrderedDict[str, tuple[Optional[str], Optional[str]]](
("wav2vec2-bert", ("Wav2Vec2CTCTokenizer", None)),
("wav2vec2-conformer", ("Wav2Vec2CTCTokenizer", None)),
("wav2vec2_phoneme", ("Wav2Vec2PhonemeCTCTokenizer", None)),
("wav2vec2_with_lm", ("Wav2Vec2CTCTokenizer", None)),
("whisper", ("WhisperTokenizer", "WhisperTokenizerFast" if is_tokenizers_available() else None)),
("xclip", ("CLIPTokenizer", "CLIPTokenizerFast" if is_tokenizers_available() else None)),
(
@ -1186,7 +1160,7 @@ class AutoTokenizer:
The configuration corresponding to the model to register.
slow_tokenizer_class ([`PretrainedTokenizer`], *optional*):
The slow tokenizer to register.
fast_tokenizer_class ([`PreTrainedTokenizerFast`], *optional*):
fast_tokenizer_class ([`PretrainedTokenizerFast`], *optional*):
The fast tokenizer to register.
"""
if slow_tokenizer_class is None and fast_tokenizer_class is None:

View File

@ -60,7 +60,6 @@ else:
("qwen3_vl_moe", "Qwen3VLVideoProcessor"),
("sam2_video", "Sam2VideoVideoProcessor"),
("smolvlm", "SmolVLMVideoProcessor"),
("video_llama_3", "VideoLlama3VideoProcessor"),
("video_llava", "VideoLlavaVideoProcessor"),
("videomae", "VideoMAEVideoProcessor"),
("vjepa2", "VJEPA2VideoProcessor"),
@ -292,7 +291,7 @@ class AutoVideoProcessor:
# Some models have different image processors, e.g. InternVL uses GotOCRImageProcessor
# We cannot use GotOCRVideoProcessor when falling back for BC and should try to infer from config later on
if video_processor_class_from_name(video_processor_class_inferred) is not None:
if video_processor_class_inferred in VIDEO_PROCESSOR_MAPPING_NAMES.values():
video_processor_class = video_processor_class_inferred
if "AutoImageProcessor" in config_dict.get("auto_map", {}):
image_processor_auto_map = config_dict["auto_map"]["AutoImageProcessor"]

View File

@ -826,21 +826,22 @@ class AutoformerPreTrainedModel(PreTrainedModel):
main_input_name = "past_values"
supports_gradient_checkpointing = True
@torch.no_grad()
def _init_weights(self, module: nn.Module):
std = self.config.init_std
if isinstance(module, (nn.Linear, nn.Conv1d)):
module.weight.data.normal_(mean=0.0, std=std)
module.weight.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, AutoformerSinusoidalPositionalEmbedding):
module._init_weight()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
module.weight.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
module.weight[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
module.weight.fill_(1.0)
module.bias.zero_()
# copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
def _update_full_mask(

View File

@ -338,7 +338,7 @@ class AyaVisionForConditionalGeneration(AyaVisionPreTrainedModel, GenerationMixi
"^multi_modal_projector": "model.multi_modal_projector",
"^language_model.lm_head": "lm_head",
}
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
def __init__(self, config: AyaVisionConfig):
super().__init__(config)

View File

@ -70,6 +70,10 @@ class AyaVisionProcessor(ProcessorMixin):
in a chat into a tokenizable string.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"
def __init__(
self,
image_processor=None,

View File

@ -1126,12 +1126,13 @@ class BambaPreTrainedModel(PreTrainedModel):
# Note: only supports HybridMambaAttentionDynamicCache
_is_stateful = True
@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
if isinstance(module, BambaMixer):
module.dt_bias.data.fill_(1.0)
module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1))
module.D.data.fill_(1.0)
module.dt_bias.fill_(1.0)
module.A_log.copy_(torch.log(torch.arange(1, module.num_heads + 1)))
module.D.fill_(1.0)
@auto_docstring
@ -1383,7 +1384,7 @@ class BambaModel(BambaPreTrainedModel):
@auto_docstring
class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
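
For the Bamba mixer the special parameters are now filled in place (`fill_`, `copy_`) rather than by rebinding `.data`, which keeps the original Parameter object with its dtype, device and any hooks. A tiny sketch of the `A_log` line:

import torch
from torch import nn

num_heads = 4  # toy value
A_log = nn.Parameter(torch.zeros(num_heads))
with torch.no_grad():
    # torch.log promotes the integer arange to float, and copy_ writes the
    # values into the existing parameter instead of swapping the tensor out.
    A_log.copy_(torch.log(torch.arange(1, num_heads + 1)))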

View File

@ -800,12 +800,13 @@ class BambaPreTrainedModel(PreTrainedModel):
# Note: only supports HybridMambaAttentionDynamicCache
_is_stateful = True
@torch.no_grad()
def _init_weights(self, module):
super()._init_weights(module)
if isinstance(module, BambaMixer):
module.dt_bias.data.fill_(1.0)
module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1))
module.D.data.fill_(1.0)
module.dt_bias.fill_(1.0)
module.A_log.copy_(torch.log(torch.arange(1, module.num_heads + 1)))
module.D.fill_(1.0)
@auto_docstring

View File

@ -329,19 +329,20 @@ class BarkPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = False
_supports_flash_attn = True
@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, (nn.Linear,)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
module.weight[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@ -910,6 +911,9 @@ class BarkFineModel(BarkPreTrainedModel):
# non-causal gpt-like model with one embedding layer and one lm_head for each codebook of Encodec
super().__init__(config)
self.config = config
self._tied_weights_keys = {}
for i in range(self.config.n_codes_total - self.config.n_codes_given):
self._tied_weights_keys[f"lm_heads.{i}.weight"] = f"input_embeds_layers.{i + 1}.weight"
# initialize a modified non causal GPT-like model
# note that for there is one embedding layer and one lm_head for each codebook of Encodec
@ -1025,25 +1029,6 @@ class BarkFineModel(BarkPreTrainedModel):
return model_embeds
def _tie_weights(self):
if getattr(self.config, "tie_word_embeddings", True):
self._tied_weights_keys = []
output_embeddings = self.get_output_embeddings()
input_embeddings = self.get_input_embeddings()
for i in range(self.config.n_codes_total - self.config.n_codes_given):
# self.input_embeds_layers[i + 1].weight = self.lm_heads[i].weight
self._tie_embedding_weights(output_embeddings[i], input_embeddings[i + 1])
self._tied_weights_keys.append(f"lm_heads.{i}.weight")
def tie_weights(self):
"""
Tie the weights between the input embeddings list and the output embeddings list.
"""
for module in self.modules():
if hasattr(module, "_tie_weights"):
module._tie_weights()
@auto_docstring
def forward(
self,
@ -1580,14 +1565,6 @@ class BarkModel(BarkPreTrainedModel, GenerationMixin):
return audio
def tie_weights(self):
"""
Tie the weights between the input embeddings list and the output embeddings list.
"""
for module in self.modules():
if hasattr(module, "_tie_weights"):
module._tie_weights()
__all__ = [
"BarkFineModel",

View File

@ -49,6 +49,9 @@ class BarkProcessor(ProcessorMixin):
"""
tokenizer_class = "AutoTokenizer"
attributes = ["tokenizer"]
preset_shape = {
"semantic_prompt": 1, # 1D array of shape (X,)
"coarse_prompt": 2, # 2D array of shape (2,X)

View File

@ -476,19 +476,20 @@ class BartPreTrainedModel(PreTrainedModel):
_can_compile_fullgraph = True
@torch.no_grad()
def _init_weights(self, module):
std = self.config.init_std
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
module.weight.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
module.weight.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
module.weight[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
module.weight.fill_(1.0)
module.bias.zero_()
@property
def dummy_inputs(self):
@ -527,7 +528,7 @@ class BartEncoder(BartPreTrainedModel):
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
def __init__(self, config: BartConfig):
super().__init__(config)
self.dropout = config.dropout
@ -538,12 +539,9 @@ class BartEncoder(BartPreTrainedModel):
self.max_source_positions = config.max_position_embeddings
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
if embed_tokens is not None:
self.embed_tokens = embed_tokens
else:
self.embed_tokens = BartScaledWordEmbedding(
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
)
self.embed_tokens = BartScaledWordEmbedding(
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
)
self.embed_positions = BartLearnedPositionalEmbedding(
config.max_position_embeddings,
@ -674,7 +672,7 @@ class BartDecoder(BartPreTrainedModel):
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
def __init__(self, config: BartConfig):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.decoder_layerdrop
@ -682,12 +680,9 @@ class BartDecoder(BartPreTrainedModel):
self.max_target_positions = config.max_position_embeddings
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
if embed_tokens is not None:
self.embed_tokens = embed_tokens
else:
self.embed_tokens = BartScaledWordEmbedding(
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
)
self.embed_tokens = BartScaledWordEmbedding(
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
)
self.embed_positions = BartLearnedPositionalEmbedding(
config.max_position_embeddings,
@ -899,7 +894,10 @@ class BartDecoder(BartPreTrainedModel):
@auto_docstring
class BartModel(BartPreTrainedModel):
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = {
"decoder.embed_tokens.weight": "shared.weight",
"encoder.embed_tokens.weight": "shared.weight",
}
def __init__(self, config: BartConfig):
super().__init__(config)
@ -908,24 +906,12 @@ class BartModel(BartPreTrainedModel):
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.shared = BartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
self.encoder = BartEncoder(config, self.shared)
self.decoder = BartDecoder(config, self.shared)
self.encoder = BartEncoder(config)
self.decoder = BartDecoder(config)
# Initialize weights and apply final processing
self.post_init()
def _tie_weights(self):
if self.config.tie_word_embeddings:
# Some model checkpoints like "facebook/bart-large-cnn"'s embedding weight is in decoder.embed_tokens, need check here, see issue #36247
if self.shared.weight.device == torch.device(
"meta"
) and self.decoder.embed_tokens.weight.device != torch.device("meta"):
self._tie_embedding_weights(self.encoder.embed_tokens, self.decoder.embed_tokens)
self._tie_embedding_weights(self.shared, self.decoder.embed_tokens)
else:
self._tie_embedding_weights(self.encoder.embed_tokens, self.shared)
self._tie_embedding_weights(self.decoder.embed_tokens, self.shared)
def get_input_embeddings(self):
return self.shared
@ -1052,7 +1038,9 @@ class BartModel(BartPreTrainedModel):
)
class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
base_model_prefix = "model"
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
_tied_weights_keys = {
"lm_head.weight": "model.shared.weight",
}
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
def __init__(self, config: BartConfig):
@ -1086,11 +1074,6 @@ class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
self.register_buffer("final_logits_bias", new_bias)
def _tie_weights(self):
if self.config.tie_word_embeddings:
self.model._tie_weights()
self._tie_embedding_weights(self.lm_head, self.model.shared)
@auto_docstring
def forward(
self,
@ -1240,8 +1223,6 @@ class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
"""
)
class BartForSequenceClassification(BartPreTrainedModel):
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: BartConfig, **kwargs):
super().__init__(config, **kwargs)
self.model = BartModel(config)
@ -1374,8 +1355,6 @@ class BartForSequenceClassification(BartPreTrainedModel):
@auto_docstring
class BartForQuestionAnswering(BartPreTrainedModel):
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config):
super().__init__(config)
@ -1513,7 +1492,9 @@ class BartDecoderWrapper(BartPreTrainedModel):
"""
)
class BartForCausalLM(BartPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {
"lm_head.weight": "model.decoder.embed_tokens.weight",
}
def __init__(self, config):
config.is_decoder = True
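
The Bart changes follow from the same mapping idea: `BartEncoder` and `BartDecoder` no longer accept an `embed_tokens` argument and always build their own embedding, while the parent declares that both resolve to `shared.weight`, making the hand-written `_tie_weights` overrides unnecessary. A toy version of the wiring (a resolver like the hypothetical `tie_from_mapping` above would then unify the three tensors):

from torch import nn


class ToyEncoder(nn.Module):
    def __init__(self, vocab, dim):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)  # built locally, no argument


class ToyDecoder(nn.Module):
    def __init__(self, vocab, dim):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)


class ToySeq2Seq(nn.Module):
    # Two targets pointing at one source: encoder, decoder and `shared`
    # all end up on the same tensor once the mapping is applied.
    _tied_weights_keys = {
        "encoder.embed_tokens.weight": "shared.weight",
        "decoder.embed_tokens.weight": "shared.weight",
    }

    def __init__(self, vocab=32, dim=8):
        super().__init__()
        self.shared = nn.Embedding(vocab, dim)
        self.encoder = ToyEncoder(vocab, dim)
        self.decoder = ToyDecoder(vocab, dim)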

View File

@ -692,31 +692,32 @@ class BeitPreTrainedModel(PreTrainedModel):
_keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"]
_supports_sdpa = True
@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
module.weight[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
elif isinstance(module, BeitEmbeddings):
module.cls_token.data.zero_()
module.cls_token.zero_()
if module.mask_token is not None:
module.mask_token.data.zero_()
module.mask_token.zero_()
if module.position_embeddings is not None:
module.position_embeddings.data.zero_()
module.position_embeddings.zero_()
elif isinstance(module, BeitRelativePositionBias):
module.relative_position_bias_table.data.zero_()
module.relative_position_bias_table.zero_()
elif isinstance(module, BeitLayer):
if module.lambda_1 is not None:
module.lambda_1.data.fill_(self.config.layer_scale_init_value)
module.lambda_2.data.fill_(self.config.layer_scale_init_value)
module.lambda_1.fill_(self.config.layer_scale_init_value)
module.lambda_2.fill_(self.config.layer_scale_init_value)
@auto_docstring

View File

@ -506,16 +506,9 @@ class BertLMPredictionHead(nn.Module):
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
@ -569,21 +562,22 @@ class BertPreTrainedModel(PreTrainedModel):
"cross_attentions": BertCrossAttention,
}
@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
module.weight[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
elif isinstance(module, BertLMPredictionHead):
module.bias.data.zero_()
module.bias.zero_()
@dataclass
@ -770,7 +764,10 @@ class BertModel(BertPreTrainedModel):
"""
)
class BertForPreTraining(BertPreTrainedModel):
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
_tied_weights_keys = {
"cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
"cls.predictions.decoder.bias": "cls.predictions.bias",
}
def __init__(self, config):
super().__init__(config)
@ -864,7 +861,10 @@ class BertForPreTraining(BertPreTrainedModel):
"""
)
class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
_tied_weights_keys = {
"cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
"cls.predictions.decoder.bias": "cls.predictions.bias",
}
def __init__(self, config):
super().__init__(config)
@ -948,7 +948,10 @@ class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
@auto_docstring
class BertForMaskedLM(BertPreTrainedModel):
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
_tied_weights_keys = {
"cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
"cls.predictions.decoder.bias": "cls.predictions.bias",
}
def __init__(self, config):
super().__init__(config)
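
For the BERT-style heads, the decoder `nn.Linear` now owns a real bias (`bias=True`) and the separate `cls.predictions.bias` parameter is kept in sync through the mapping rather than through `self.decoder.bias = self.bias` in a `_tie_weights` hook. A toy illustration of the end state once the two names alias one tensor:

import torch
from torch import nn


class ToyLMPredictionHead(nn.Module):
    def __init__(self, hidden_size=8, vocab_size=32):
        super().__init__()
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=True)
        # Aliased here for illustration; in the refactor the loader enforces
        # the alias from the "decoder.bias" -> "bias" mapping entry instead.
        self.bias = self.decoder.bias

    def forward(self, hidden_states):
        return self.decoder(hidden_states)


head = ToyLMPredictionHead()
assert head.bias is head.decoder.bias
print(head(torch.zeros(2, 8)).shape)  # torch.Size([2, 32])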

View File

@ -456,21 +456,22 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
"cross_attentions": BertGenerationCrossAttention,
}
@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
module.weight[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
elif isinstance(module, BertGenerationOnlyLMHead):
module.bias.data.zero_()
module.bias.zero_()
@auto_docstring(
@ -629,20 +630,11 @@ class BertGenerationOnlyLMHead(nn.Module):
super().__init__()
self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
self.decoder.bias = self.bias
def forward(self, hidden_states):
logits = self.decoder(hidden_states)
return logits
def _tie_weights(self):
# For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device.type == "meta":
self.decoder.bias = self.bias
else:
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias
@auto_docstring(
custom_intro="""
@ -650,7 +642,10 @@ class BertGenerationOnlyLMHead(nn.Module):
"""
)
class BertGenerationDecoder(BertGenerationPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
_tied_weights_keys = {
"lm_head.decoder.weight": "bert.embeddings.word_embeddings.weight",
"lm_head.decoder.bias": "lm_head.bias",
}
def __init__(self, config):
super().__init__(config)

View File

@ -1464,16 +1464,9 @@ class BigBirdLMPredictionHead(nn.Module):
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias
def _tie_weights(self):
self.decoder.bias = self.bias
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
@ -1521,21 +1514,22 @@ class BigBirdPreTrainedModel(PreTrainedModel):
base_model_prefix = "bert"
supports_gradient_checkpointing = True
@torch.no_grad()
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
module.weight.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
module.weight[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
module.bias.zero_()
module.weight.fill_(1.0)
elif isinstance(module, BigBirdLMPredictionHead):
module.bias.data.zero_()
module.bias.zero_()
@dataclass
@ -1899,7 +1893,10 @@ class BigBirdModel(BigBirdPreTrainedModel):
class BigBirdForPreTraining(BigBirdPreTrainedModel):
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
_tied_weights_keys = {
"cls.predictions.decoder.bias": "cls.predictions.bias",
"cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
}
def __init__(self, config):
super().__init__(config)
@ -1999,7 +1996,10 @@ class BigBirdForPreTraining(BigBirdPreTrainedModel):
@auto_docstring
class BigBirdForMaskedLM(BigBirdPreTrainedModel):
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
_tied_weights_keys = {
"cls.predictions.decoder.bias": "cls.predictions.bias",
"cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
}
def __init__(self, config):
super().__init__(config)
@ -2141,7 +2141,10 @@ class BigBirdForMaskedLM(BigBirdPreTrainedModel):
"""
)
class BigBirdForCausalLM(BigBirdPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
_tied_weights_keys = {
"cls.predictions.decoder.bias": "cls.predictions.bias",
"cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
}
def __init__(self, config):
super().__init__(config)

View File

@ -1539,19 +1539,20 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
_can_compile_fullgraph = True
@torch.no_grad()
def _init_weights(self, module):
std = self.config.init_std
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
module.weight.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
module.weight.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
module.weight[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
module.weight.fill_(1.0)
module.bias.zero_()
@property
def dummy_inputs(self):
@ -1574,7 +1575,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
def __init__(self, config: BigBirdPegasusConfig):
super().__init__(config)
self.attention_type = config.attention_type
@ -1592,9 +1593,6 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding(
config.max_position_embeddings,
embed_dim,
@ -1849,7 +1847,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embedding] = None):
def __init__(self, config: BigBirdPegasusConfig):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.decoder_layerdrop
@ -1861,9 +1859,6 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
@ -2075,7 +2070,10 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
@auto_docstring
class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
_tied_weights_keys = {
"encoder.embed_tokens.weight": "shared.weight",
"decoder.embed_tokens.weight": "shared.weight",
}
def __init__(self, config: BigBirdPegasusConfig):
super().__init__(config)
@ -2086,8 +2084,8 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
vocab_size, config.d_model, padding_idx, embed_scale=embed_scale
)
self.encoder = BigBirdPegasusEncoder(config, self.shared)
self.decoder = BigBirdPegasusDecoder(config, self.shared)
self.encoder = BigBirdPegasusEncoder(config)
self.decoder = BigBirdPegasusDecoder(config)
# Initialize weights and apply final processing
self.post_init()
@ -2100,11 +2098,6 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_embedding_weights(self.encoder.embed_tokens, self.shared)
self._tie_embedding_weights(self.decoder.embed_tokens, self.shared)
def get_encoder(self):
return self.encoder
@ -2213,7 +2206,9 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, GenerationMixin):
base_model_prefix = "model"
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
_tied_weights_keys = {
"lm_head.weight": "model.shared.weight",
}
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
def __init__(self, config: BigBirdPegasusConfig):
@ -2247,11 +2242,6 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, Gene
new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
self.register_buffer("final_logits_bias", new_bias)
def _tie_weights(self):
if self.config.tie_word_embeddings:
self.model._tie_weights()
self._tie_embedding_weights(self.lm_head, self.model.shared)
@auto_docstring
# Ignore copy
def forward(
@ -2374,8 +2364,6 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, Gene
"""
)
class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config: BigBirdPegasusConfig, **kwargs):
super().__init__(config, **kwargs)
self.model = BigBirdPegasusModel(config)
@ -2497,8 +2485,6 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
@auto_docstring
class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
def __init__(self, config):
super().__init__(config)
@ -2621,8 +2607,6 @@ class BigBirdPegasusDecoderWrapper(BigBirdPegasusPreTrainedModel):
class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
config.is_decoder = True
config.is_encoder_decoder = False

View File

@ -510,7 +510,7 @@ class BioGptModel(BioGptPreTrainedModel):
"""
)
class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["output_projection.weight"]
_tied_weights_keys = {"output_projection.weight": "biogpt.embed_tokens.weight"}
def __init__(self, config):
super().__init__(config)
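
One detail worth noting across these causal-LM hunks (an observation from the diffs, not something they state): the source side of the mapping is written from the root of the task model, so it carries the base-model prefix, `model.embed_tokens.weight` for Llama-style models and `biogpt.embed_tokens.weight` here. A small hedged helper to sanity-check such an entry on any instantiated model:

from torch import nn


def tie_entry_resolves(model: nn.Module, target: str, source: str) -> bool:
    # Both names must exist from the model root and have matching shapes;
    # remove_duplicate=False keeps already-aliased parameters visible.
    params = dict(model.named_parameters(remove_duplicate=False))
    return target in params and source in params and params[target].shape == params[source].shape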

View File

@ -332,7 +332,7 @@ class BioGptModel(BioGptPreTrainedModel):
"""
)
class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["output_projection.weight"]
_tied_weights_keys = {"output_projection.weight": "biogpt.embed_tokens.weight"}
def __init__(self, config):
super().__init__(config)

View File

@ -628,6 +628,7 @@ class BitPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
_no_split_modules = ["BitEmbeddings"]
@torch.no_grad()
def _init_weights(self, module):
if isinstance(module, nn.Conv2d):
nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")

View File

@ -433,7 +433,7 @@ class BitNetModel(BitNetPreTrainedModel):
@auto_docstring
class BitNetForCausalLM(BitNetPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tp_plan = None
_pp_plan = None

View File

@ -114,7 +114,7 @@ class BitNetModel(LlamaModel):
class BitNetForCausalLM(LlamaForCausalLM):
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tp_plan = None
_pp_plan = None

View File

@ -438,19 +438,20 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
_can_compile_fullgraph = True
@torch.no_grad()
def _init_weights(self, module):
std = self.config.init_std
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
module.weight.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
module.bias.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
module.weight.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
module.weight[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
module.weight.fill_(1.0)
module.bias.zero_()
@property
def dummy_inputs(self):
@ -474,7 +475,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding] = None):
def __init__(self, config: BlenderbotConfig):
super().__init__(config)
self.dropout = config.dropout
@ -485,12 +486,9 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
self.max_source_positions = config.max_position_embeddings
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
if embed_tokens is not None:
self.embed_tokens = embed_tokens
else:
self.embed_tokens = BlenderbotScaledWordEmbedding(
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
)
self.embed_tokens = BlenderbotScaledWordEmbedding(
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
)
self.embed_positions = BlenderbotLearnedPositionalEmbedding(
config.max_position_embeddings,
@ -623,7 +621,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding] = None):
def __init__(self, config: BlenderbotConfig):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.decoder_layerdrop
@ -631,12 +629,9 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
self.max_target_positions = config.max_position_embeddings
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
if embed_tokens is not None:
self.embed_tokens = embed_tokens
else:
self.embed_tokens = BlenderbotScaledWordEmbedding(
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
)
self.embed_tokens = BlenderbotScaledWordEmbedding(
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
)
self.embed_positions = BlenderbotLearnedPositionalEmbedding(
config.max_position_embeddings,
@ -852,7 +847,10 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
@auto_docstring
class BlenderbotModel(BlenderbotPreTrainedModel):
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
_tied_weights_keys = {
"encoder.embed_tokens.weight": "shared.weight",
"decoder.embed_tokens.weight": "shared.weight",
}
def __init__(self, config: BlenderbotConfig):
super().__init__(config)
@ -860,8 +858,8 @@ class BlenderbotModel(BlenderbotPreTrainedModel):
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.shared = BlenderbotScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
self.encoder = BlenderbotEncoder(config, self.shared)
self.decoder = BlenderbotDecoder(config, self.shared)
self.encoder = BlenderbotEncoder(config)
self.decoder = BlenderbotDecoder(config)
# Initialize weights and apply final processing
self.post_init()
@ -1001,7 +999,9 @@ class BlenderbotModel(BlenderbotPreTrainedModel):
class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel, GenerationMixin):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]
_tied_weights_keys = {
"lm_head.weight": "model.shared.weight",
}
def __init__(self, config: BlenderbotConfig):
super().__init__(config)
@ -1184,7 +1184,9 @@ class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel):
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot, facebook/bart-base->facebook/blenderbot-400M-distill
class BlenderbotForCausalLM(BlenderbotPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {
"lm_head.weight": "model.decoder.embed_tokens.weight",
}
def __init__(self, config):
config.is_decoder = True

View File

@@ -431,19 +431,20 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
     _can_compile_fullgraph = True
+    @torch.no_grad()
     def _init_weights(self, module):
         std = self.config.init_std
         if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
+            module.weight.normal_(mean=0.0, std=std)
             if module.bias is not None:
-                module.bias.data.zero_()
+                module.bias.zero_()
         elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
+            module.weight.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
+                module.weight[module.padding_idx].zero_()
         elif isinstance(module, nn.LayerNorm):
-            module.weight.data.fill_(1.0)
-            module.bias.data.zero_()
+            module.weight.fill_(1.0)
+            module.bias.zero_()
     @property
     def dummy_inputs(self):
@@ -467,7 +468,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
         embed_tokens (nn.Embedding): output embedding
     """
-    def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embedding] = None):
+    def __init__(self, config: BlenderbotSmallConfig):
         super().__init__(config)
         self.dropout = config.dropout
@@ -478,10 +479,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
         self.max_source_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
-        if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
         self.embed_positions = BlenderbotSmallLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -612,7 +610,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
         embed_tokens (nn.Embedding): output embedding
     """
-    def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embedding] = None):
+    def __init__(self, config: BlenderbotSmallConfig):
         super().__init__(config)
         self.dropout = config.dropout
         self.layerdrop = config.decoder_layerdrop
@@ -620,10 +618,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
         self.max_target_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
-        if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
         self.embed_positions = BlenderbotSmallLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -838,7 +833,10 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
 @auto_docstring
 class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
-    _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
+    _tied_weights_keys = {
+        "encoder.embed_tokens.weight": "shared.weight",
+        "decoder.embed_tokens.weight": "shared.weight",
+    }
     def __init__(self, config: BlenderbotSmallConfig):
         super().__init__(config)
@@ -846,8 +844,8 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
         padding_idx, vocab_size = config.pad_token_id, config.vocab_size
         self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
-        self.encoder = BlenderbotSmallEncoder(config, self.shared)
-        self.decoder = BlenderbotSmallDecoder(config, self.shared)
+        self.encoder = BlenderbotSmallEncoder(config)
+        self.decoder = BlenderbotSmallDecoder(config)
         # Initialize weights and apply final processing
         self.post_init()
@@ -974,7 +972,9 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
 class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel, GenerationMixin):
     base_model_prefix = "model"
     _keys_to_ignore_on_load_missing = ["final_logits_bias"]
-    _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]
+    _tied_weights_keys = {
+        "lm_head.weight": "model.shared.weight",
+    }
     def __init__(self, config: BlenderbotSmallConfig):
         super().__init__(config)
@@ -1144,7 +1144,9 @@ class BlenderbotSmallDecoderWrapper(BlenderbotSmallPreTrainedModel):
 # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BlenderbotSmall, facebook/bart-base->facebook/blenderbot_small-90M
 class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel, GenerationMixin):
-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys = {
+        "lm_head.weight": "model.decoder.embed_tokens.weight",
+    }
     def __init__(self, config):
         config.is_decoder = True
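
The `_init_weights` hunks above all follow the same recipe: decorate the method with `@torch.no_grad()` and drop the `.data` indirection, so the in-place initializers run directly on the parameters. A self-contained sketch of that pattern with a toy module and config stub (assumed names, not the transformers classes):

# Minimal sketch: in-place init on parameters is safe inside torch.no_grad(),
# so the `.data` indirection used previously is no longer needed.
import torch
from torch import nn


class TinyConfig:
    init_std = 0.02


class TinyModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.fc = nn.Linear(8, 8)
        self.emb = nn.Embedding(16, 8, padding_idx=0)
        self.apply(self._init_weights)

    @torch.no_grad()
    def _init_weights(self, module):
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            module.weight.normal_(mean=0.0, std=std)  # in-place, no autograd tracking
            if module.bias is not None:
                module.bias.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight[module.padding_idx].zero_()


model = TinyModel(TinyConfig())
assert torch.all(model.emb.weight[0] == 0)  # padding row zeroed in place

Outside of `torch.no_grad()`, calling `normal_()` or `zero_()` on a parameter that requires grad raises a runtime error; the decorator is what makes the old `.data` workaround unnecessary.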

View File

@@ -419,13 +419,14 @@ class BlipPreTrainedModel(PreTrainedModel):
     _no_split_modules = ["BlipEncoderLayer", "BlipTextEmbeddings"]
     _skip_keys_device_placement = ["past_key_values"]
+    @torch.no_grad()
     def _init_weights(self, module):
         """Initialize the weights"""
         factor = self.config.initializer_range
         if isinstance(module, (nn.Conv2d, nn.Embedding, nn.Linear)):
-            module.weight.data.normal_(mean=0.0, std=factor)
+            module.weight.normal_(mean=0.0, std=factor)
             if hasattr(module, "bias") and module.bias is not None:
-                module.bias.data.zero_()
+                module.bias.zero_()
         if isinstance(module, BlipVisionEmbeddings):
             if hasattr(self.config, "vision_config"):
@@ -443,10 +444,10 @@ class BlipPreTrainedModel(PreTrainedModel):
             )
         elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
+            module.bias.zero_()
+            module.weight.fill_(1.0)
         elif isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
+            module.bias.zero_()
 class BlipEncoder(nn.Module):
@@ -797,8 +798,11 @@ class BlipModel(BlipPreTrainedModel):
 )
 class BlipForConditionalGeneration(BlipPreTrainedModel, GenerationMixin):
     config: BlipConfig
-    _tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"]
     main_input_name = "pixel_values"
+    _tied_weights_keys = {
+        "text_decoder.cls.predictions.decoder.bias": "text_decoder.cls.predictions.bias",
+        "text_decoder.cls.predictions.decoder.weight": "text_decoder.bert.embeddings.word_embeddings.weight",
+    }  # TODO @arthurzucker check why we need this when for other models, their subPreTrainedModel handle it themselves.
     def __init__(self, config: BlipConfig):
         super().__init__(config)
@@ -963,7 +967,10 @@ class BlipForConditionalGeneration(BlipPreTrainedModel, GenerationMixin):
 )
 class BlipForQuestionAnswering(BlipPreTrainedModel, GenerationMixin):
     config: BlipConfig
-    _tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"]
+    _tied_weights_keys = {
+        "text_decoder.cls.predictions.decoder.bias": "text_decoder.cls.predictions.bias",
+        "text_decoder.cls.predictions.decoder.weight": "text_decoder.bert.embeddings.word_embeddings.weight",
+    }
     def __init__(self, config: BlipConfig):
         super().__init__(config)
@@ -971,7 +978,6 @@ class BlipForQuestionAnswering(BlipPreTrainedModel, GenerationMixin):
         self.vision_model = BlipVisionModel(config.vision_config)
         self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)
         self.text_decoder = BlipTextLMHeadModel(config.text_config)
-        self.decoder_pad_token_id = config.text_config.pad_token_id
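
For composite models such as BLIP, the tied-weights keys are dotted module paths relative to the top-level model, so a single entry can tie a prediction head inside `text_decoder` to the word embeddings of that same submodule. A small inspection helper, illustrative only and not part of transformers, that groups parameter names by underlying storage, which is a convenient way to check that such ties actually hold:

# Minimal sketch: group parameter names by underlying storage to see which
# weights are actually tied (shared) inside a composite model.
from collections import defaultdict

from torch import nn


def tied_parameter_groups(model: nn.Module) -> list[list[str]]:
    groups = defaultdict(list)
    for name, param in model.named_parameters(remove_duplicate=False):
        groups[param.data_ptr()].append(name)
    return [names for names in groups.values() if len(names) > 1]


class TextDecoder(nn.Module):
    def __init__(self, vocab=12, dim=4):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab, dim)
        self.decoder = nn.Linear(dim, vocab, bias=True)
        self.decoder.weight = self.word_embeddings.weight  # tie inside the submodule


class Composite(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision = nn.Linear(4, 4)
        self.text_decoder = TextDecoder()


print(tied_parameter_groups(Composite()))
# e.g. [['text_decoder.word_embeddings.weight', 'text_decoder.decoder.weight']]

The same check can be pointed at a real model after `post_init()` to confirm that the keys declared in `_tied_weights_keys` end up sharing storage.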

View File

@@ -473,16 +473,9 @@ class BlipTextLMPredictionHead(nn.Module):
         # The output weights are the same as the input embeddings, but there is
         # an output-only bias for each token.
-        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
-        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
-        self.decoder.bias = self.bias
-    def _tie_weights(self):
-        self.decoder.bias = self.bias
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -511,15 +504,16 @@ class BlipTextPreTrainedModel(PreTrainedModel):
     base_model_prefix = "bert"
     _no_split_modules = []
+    @torch.no_grad()
     def _init_weights(self, module):
         """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Embedding)):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.normal_(mean=0.0, std=self.config.initializer_range)
         elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
+            module.bias.zero_()
+            module.weight.fill_(1.0)
         if isinstance(module, nn.Linear) and module.bias is not None:
-            module.bias.data.zero_()
+            module.bias.zero_()
 # Adapted from https://github.com/salesforce/BLIP/blob/3a29b7410476bf5f2ba0955827390eb6ea1f4f9d/models/med.py#L571
@@ -744,7 +738,10 @@ class BlipTextModel(BlipTextPreTrainedModel):
 # Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L811
 class BlipTextLMHeadModel(BlipTextPreTrainedModel, GenerationMixin):
-    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
+    _tied_weights_keys = {
+        "cls.predictions.decoder.bias": "cls.predictions.bias",
+        "cls.predictions.decoder.weight": "bert.embeddings.word_embeddings.weight",
+    }
     def __init__(self, config):
         super().__init__(config)
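
In `BlipTextLMPredictionHead`, the decoder projection now owns a real bias (`bias=True`) and the manual `self.decoder.bias = self.bias` plus `_tie_weights` hook disappear; the relationship is instead declared once in `_tied_weights_keys` (`cls.predictions.decoder.bias` -> `cls.predictions.bias`). A toy head showing the end state that the generic tying code is expected to produce (a sketch, not the actual implementation):

# Minimal sketch: once the tie is applied, `decoder.bias` and `bias` are the
# same nn.Parameter, so there is nothing left for a manual _tie_weights hook to do.
import torch
from torch import nn


class ToyLMPredictionHead(nn.Module):
    def __init__(self, hidden_size=4, vocab_size=10):
        super().__init__()
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=True)
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        # What a {"decoder.bias": "bias"} entry resolves to once ties are applied:
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        return self.decoder(hidden_states)


head = ToyLMPredictionHead()
with torch.no_grad():
    head.bias.fill_(1.0)
assert head.decoder.bias is head.bias            # one parameter, two names
assert torch.all(head(torch.zeros(1, 4)) == 1)   # the decoder really uses it

Because the two attributes are the same parameter object, resizing or reloading one of them cannot drift out of sync with the other, which is what the old `_tie_weights` hook existed to guarantee.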

View File

@@ -53,6 +53,10 @@ class BlipProcessor(ProcessorMixin):
An instance of ['BertTokenizerFast`]. The tokenizer is a required input.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = ("BlipImageProcessor", "BlipImageProcessorFast")
tokenizer_class = ("BertTokenizer", "BertTokenizerFast")
def __init__(self, image_processor, tokenizer, **kwargs):
tokenizer.return_token_type_ids = False
super().__init__(image_processor, tokenizer)
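
The processor keeps forcing `tokenizer.return_token_type_ids = False`, so BLIP text inputs come back without `token_type_ids`. A short usage sketch (the checkpoint name is only an example; downloading it requires network access):

# Usage sketch: BlipProcessor bundles the image processor and tokenizer.
from PIL import Image
from transformers import BlipProcessor

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
image = Image.new("RGB", (224, 224), color="white")

inputs = processor(images=image, text="a photo of", return_tensors="pt")
print(sorted(inputs.keys()))
# Expected: pixel_values, input_ids, attention_mask; no token_type_ids, because
# the processor sets `tokenizer.return_token_type_ids = False` at construction time.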

View File

@@ -263,9 +263,7 @@ class Blip2Config(PreTrainedConfig):
     ```"""
     model_type = "blip-2"
-    attribute_map = {
-        "image_token_id": "image_token_index",
-    }
+    attribute_map = {"image_token_id": "image_token_index", "tie_words_embeddings": "use_decoder_only_language_model"}
     sub_configs = {"text_config": AutoConfig, "qformer_config": Blip2QFormerConfig, "vision_config": Blip2VisionConfig}
     def __init__(
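
`attribute_map` lets callers use a generic attribute name while the config stores a model-specific one; the hunk above adds an alias so that the tie-embeddings attribute resolves to `use_decoder_only_language_model`. A minimal sketch of the aliasing mechanism (not the actual `PreTrainedConfig` code):

# Minimal sketch of what an `attribute_map` buys you: reads and writes to the
# alias are forwarded to the canonical attribute name.
class ToyConfig:
    attribute_map = {"image_token_id": "image_token_index"}

    def __init__(self, image_token_index=32000):
        self.image_token_index = image_token_index

    def __getattr__(self, name):
        # Only called when normal lookup fails, i.e. for aliases.
        alias = type(self).attribute_map
        if name in alias:
            return getattr(self, alias[name])
        raise AttributeError(name)

    def __setattr__(self, name, value):
        name = type(self).attribute_map.get(name, name)
        super().__setattr__(name, value)


cfg = ToyConfig()
assert cfg.image_token_id == 32000  # read through the alias
cfg.image_token_id = 7              # write through the alias
assert cfg.image_token_index == 7

Reads fall through `__getattr__` only when normal lookup fails, and writes are redirected in `__setattr__`, so the canonical field stays the single source of truth.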

View File

@@ -409,19 +409,20 @@ class Blip2PreTrainedModel(PreTrainedModel):
     ]
     _skip_keys_device_placement = "past_key_values"
+    @torch.no_grad()
     def _init_weights(self, module):
         """Initialize the weights"""
         factor = self.config.initializer_range
         if isinstance(module, (nn.Linear, nn.Conv2d)):
-            module.weight.data.normal_(mean=0.0, std=factor)
+            module.weight.normal_(mean=0.0, std=factor)
             if module.bias is not None:
-                module.bias.data.zero_()
+                module.bias.zero_()
         elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=factor)
+            module.weight.normal_(mean=0.0, std=factor)
         elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
+            module.bias.zero_()
+            module.weight.fill_(1.0)
         elif isinstance(module, Blip2VisionEmbeddings):
             nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
             nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
@@ -435,7 +436,7 @@ class Blip2PreTrainedModel(PreTrainedModel):
                 Blip2ForImageTextRetrieval,
             ),
         ):
-            module.query_tokens.data.zero_()
+            module.query_tokens.zero_()
 # Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Blip2
@@ -1049,10 +1050,6 @@ class Blip2Model(Blip2PreTrainedModel):
         else:
             language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
-        # Update _tied_weights_keys using the base model used.
-        if language_model._tied_weights_keys is not None:
-            self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
         self.language_model = language_model
         # Initialize weights and apply final processing
@@ -1076,11 +1073,6 @@ class Blip2Model(Blip2PreTrainedModel):
     def get_decoder(self):
         return self.language_model.get_decoder()
-    def _tie_weights(self):
-        if not self.config.use_decoder_only_language_model:
-            self.language_model.encoder.embed_tokens = self.language_model.shared
-            self.language_model.decoder.embed_tokens = self.language_model.shared
     @filter_out_non_signature_kwargs()
     @auto_docstring
     def get_text_features(
@@ -1612,10 +1604,6 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
         else:
             language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)
-        # Update _tied_weights_keys using the base model used.
-        if language_model._tied_weights_keys is not None:
-            self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
         self.language_model = language_model
         # Initialize weights and apply final processing
@@ -1639,11 +1627,6 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
     def get_decoder(self):
         return self.language_model.get_decoder()
-    def _tie_weights(self):
-        if not self.config.use_decoder_only_language_model:
-            self.language_model.encoder.embed_tokens = self.language_model.shared
-            self.language_model.decoder.embed_tokens = self.language_model.shared
     def _preprocess_accelerate(self):
         r"""
         Some pre-processing hacks to make the model `accelerate` compatible. Check
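
Both Blip2 classes stop rebuilding `_tied_weights_keys` by hand with a `language_model.` prefix and drop their `_tie_weights` overrides: the inner language model already declares its own ties, and those declarations can be composed recursively. A sketch of that composition idea (assumed helper, not the transformers implementation):

# Minimal sketch of the idea that a wrapper no longer needs to copy its
# submodel's tied keys: they can be collected recursively with a prefix.
from torch import nn


def collect_tied_keys(model: nn.Module) -> dict[str, str]:
    collected = {}
    for prefix, module in model.named_modules():
        keys = getattr(module, "_tied_weights_keys", None) or {}
        dot = f"{prefix}." if prefix else ""
        for target, source in keys.items():
            collected[f"{dot}{target}"] = f"{dot}{source}"
    return collected


class ToyLM(nn.Module):
    _tied_weights_keys = {"lm_head.weight": "embed.weight"}

    def __init__(self, vocab=8, dim=4):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)


class ToyWrapper(nn.Module):
    # No per-wrapper bookkeeping: the child already declares its ties.
    def __init__(self):
        super().__init__()
        self.language_model = ToyLM()


print(collect_tied_keys(ToyWrapper()))
# {'language_model.lm_head.weight': 'language_model.embed.weight'}

With the dict convention, this kind of prefixing only needs to exist once in generic loading code, instead of every composite model copying its submodel's keys in `__init__`.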

View File

@@ -60,6 +60,10 @@ class Blip2Processor(ProcessorMixin):
Number of tokens used by the Qformer as queries, should be same as in model's config.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = ("BlipImageProcessor", "BlipImageProcessorFast")
tokenizer_class = "AutoTokenizer"
def __init__(self, image_processor, tokenizer, num_query_tokens=None, **kwargs):
tokenizer.return_token_type_ids = False
if not hasattr(tokenizer, "image_token"):
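
`Blip2Processor` takes a `num_query_tokens` argument because each image contributes that many query-token embeddings to the language model, so the image placeholder in the prompt has to be repeated accordingly before tokenization. An illustrative, stand-alone sketch of that expansion (the token string and the default count are assumptions, not the processor's actual values):

# Sketch: repeat the image placeholder once per query token so that image
# embeddings can later be merged into the text sequence by position.
def expand_image_placeholders(text: str, image_token: str = "<image>", num_query_tokens: int = 32) -> str:
    return text.replace(image_token, image_token * num_query_tokens)


prompt = "<image> Question: what is shown? Answer:"
expanded = expand_image_placeholders(prompt, num_query_tokens=4)
assert expanded.count("<image>") == 4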

Some files were not shown because too many files have changed in this diff.