Compare commits

..

89 Commits

SHA1 Message Date
38a8ce45eb add other models 2025-05-21 12:19:55 +02:00
0a246a5a5d fix test 2025-05-21 12:07:23 +02:00
0d72f20449 rm dev files 2025-05-21 11:42:17 +02:00
1cb4b7fb6a rm func from some models 2025-05-21 11:38:58 +02:00
f3fb6164f2 rm build_input function for fast, don't test prepare_for_model for fast as we don't use it 2025-05-21 11:38:58 +02:00
2b8774a7c3 rm build_input.. from old file 2025-05-21 11:38:58 +02:00
adeb8cddf1 rm build_inputs_with_special_tokens from llama and gemma 2025-05-21 11:38:58 +02:00
148e3159d4 skipping tests 2025-05-21 11:38:58 +02:00
cc76a4f113 ruff 2025-05-21 11:38:58 +02:00
0202f862ae change test 2025-05-21 11:38:58 +02:00
6829936ee0 [MODEL] Add Falcon H1 (#38249)
* Create push-important-models.yml

* feat: add falcon-h1

* fixup

* address comment

* fix

* fix copies

* fix copies

* fix

* fix

* fix

* fix

* fix copies

* fix

* fix copies

* fix test import to at least trigget the cis

* yups

* update

* fix make fix copies

* fix inits?

* fix style

* skip annoying test

* add integration test for Falcon H1

* fix copies

* fix

---------

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Co-authored-by: dhia.rhaiem <dhia.rhaiem@tii.ae>
2025-05-21 10:43:11 +02:00
e288ee00d8 tp plan should not be NONE (#38255)
* accept custom device_mesh

* fix device_map

* assert that num_heads % tp_size == 0

* todo.

* ReplicateParallel

* handle tied weights

* handle dtensor in save_pretrained with safe_serialization

* tp test works

* doesnt work

* fix shard_and_distribute_module's rank should be local_rank

* tp=4 is correct

* dp+tp is broken

* todo allreduce with dtensors on another dim is annoying

* workaround to sync dp grads when using dtensors

* loading a checkpoint works

* wandb and compare losses with different tp/dp

* cleaning

* cleaning

* .

* .

* logs

* CP2 DP2 no mask works after commenting attn_mask and is_causal from scaled_dot_product_attention

* DP=2 TP=2 now works even with tied embeddings

* model.parameters() and model.module.parameters() are empty..

* reformat sanity_check_tensor_sync

* set atol=1e-4 for CP to pass

* try populate _parameters from named_modules

* refactors
TP2 DP2 works
CP2 DP2 works

* is_causal=True and pack sequences, no attn mask, and preshuffle dataset

* fix packing

* CP=4 doesn't work

* fix labels and position_ids for CP

* DP CP works with transformers 🥳🥳🥳

* refactor

* add example cp

* fixup

* revert sdpa changes

* example cleared

* add CP, DP to the mesh init

* nit

* clean

* use `ALL_PARALLEL_STYLES`

* style

* FSDP works

* log on 1 rank

* .

* fix?

* FSDP1 also has .parameters() bug

* reported gradnorm when using FSDP1 is wrong, but loss is correct so it's okay

* .

* style and fixup

* move stuff around

* fix tests

* style

* let's make it a check

* add missing licences

* warning should be an info

* tp plan should not be NONE

* test all

* god damn it

* test all

---------

Co-authored-by: nouamanetazi <nouamane98@gmail.com>
2025-05-21 10:22:38 +02:00
711d78d104 Revert parallelism temporarily (#38240)
* Revert "Protect ParallelInterface"

This reverts commit cb513e35f9c096d60558bd43110837cbb66611ce.

* Revert "parallelism goes brrr (#37877)"

This reverts commit 1c2f36b480e02c9027d2523746d34e27b39e01a4.

* Empty commit
2025-05-20 22:43:04 +02:00
feec294dea CI reporting improvements (#38230)
update

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-05-20 19:34:58 +02:00
cb513e35f9 Protect ParallelInterface 2025-05-20 18:27:50 +02:00
f4ef41c45e v4.53.0.dev0 2025-05-20 18:12:56 +02:00
f834d368f6 [gemma3] fix bidirectional attention mask (#38080)
* fix attn mask

* attn viz doesn't show yello cubes between images

* bucketize made it hard with different number of crops

* fixup
2025-05-20 17:35:04 +02:00
2edb0e4b4d [mllama] fix loading and inference (#38223)
fix loading
2025-05-20 17:34:55 +02:00
390f153469 Add padding-free to bamba (#35861)
* add seq_idx and fa kwargs

* update tests

* docs and grad ckpt support

* fmt

* better names

* test_raise_missing_padding_free_kwarg_errs

* + seq_idx in doc strings

* padding free training docs

* add link to pr plots

* raise err on attn_mask with padding free

* rm raising missing padding free err test

* BambaFlashAttentionKwargs

* run modular util for modular_granitemoehybrid.py
2025-05-20 17:13:59 +02:00
2a79471318 Fixing Bitnet after use_rms_norm introduction (#38229)
* fix

* make style
2025-05-20 17:13:21 +02:00
9661896083 Enable Quantize KV Cache for Mistral Model (#35042)
fix #35041
2025-05-20 16:50:26 +02:00
1c2f36b480 parallelism goes brrr (#37877)
* accept custom device_mesh

* fix device_map

* assert that num_heads % tp_size == 0

* todo.

* ReplicateParallel

* handle tied weights

* handle dtensor in save_pretrained with safe_serialization

* tp test works

* doesnt work

* fix shard_and_distribute_module's rank should be local_rank

* tp=4 is correct

* dp+tp is broken

* todo allreduce with dtensors on another dim is annoying

* workaround to sync dp grads when using dtensors

* loading a checkpoint works

* wandb and compare losses with different tp/dp

* cleaning

* cleaning

* .

* .

* logs

* CP2 DP2 no mask works after commenting attn_mask and is_causal from scaled_dot_product_attention

* DP=2 TP=2 now works even with tied embeddings

* model.parameters() and model.module.parameters() are empty..

* reformat sanity_check_tensor_sync

* set atol=1e-4 for CP to pass

* try populate _parameters from named_modules

* refactors
TP2 DP2 works
CP2 DP2 works

* is_causal=True and pack sequences, no attn mask, and preshuffle dataset

* fix packing

* CP=4 doesn't work

* fix labels and position_ids for CP

* DP CP works with transformers 🥳🥳🥳

* refactor

* add example cp

* fixup

* revert sdpa changes

* example cleared

* add CP, DP to the mesh init

* nit

* clean

* use `ALL_PARALLEL_STYLES`

* style

* FSDP works

* log on 1 rank

* .

* fix?

* FSDP1 also has .parameters() bug

* reported gradnorm when using FSDP1 is wrong, but loss is correct so it's okay

* .

* style and fixup

* move stuff around

* fix tests

* style

* let's make it a check

* warning should be an info

---------

Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
2025-05-20 16:22:52 +02:00
b591d925be Fix Llama4 (#38222)
Update modeling_llama4.py
2025-05-20 16:00:46 +02:00
3f0b7d0fac Mamba2 remove unecessary test parameterization (#38227) 2025-05-20 13:54:04 +00:00
9cde2f5d42 Minor llama4 fixes (#38123)
* fix wrong scaling value/default Cache init

* style

* fix various issues on integration tests

* change expected outputs

* fixup

* fix config access

* protect default scaling
2025-05-20 13:15:54 +00:00
856f034f45 fix dead flax links modeling_flax_pytorch_utils.py (#38212) 2025-05-20 13:03:41 +00:00
bb3c6426d8 Make train_dataset attribute in _get_train_sampler optional (#38226)
make it optional
2025-05-20 12:59:53 +00:00
2ad152f84c In Llama4 fix wrongly inverted causal attention mask when using SDPA implementation (#38094)
When preparing the causal attention mask at this point the mask comes
in as a float tensor with min value as a masked value.
It is not correct to convert it to bool and treat it as a bool mask as
this inverts the mask.
`torch.nn.functional.scaled_dot_product_attention` expects that a masked value is `False`.

I suspect that the `sdpa` implementation variant may not have been
thoroughly tested and that is why this error was not caught earlier.

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
2025-05-20 14:47:59 +02:00
de70c8426e Disable torchscript tests for AriaForConditionalGenerationModelTest (#38225)
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
2025-05-20 14:37:55 +02:00
8ea61c4530 Add support to Marimo Notebooks and Enverge.ai (#38210)
* Add support to Marimo notebooks

* Consice logic

* Simplify logic

* Ruff fixes
2025-05-20 12:26:34 +00:00
d34e21e7dd New cache tests and refactored Hybrid Cache (#37972) 2025-05-20 12:46:13 +02:00
183fb3637c Add Llama4TextModel to AutoModel mapping (#38162)
Add Llama4TextModel to AutoModel mapping

using Llama4TextConfig on AutoModel.from_config raises a ValueError when it is expected to instantiate a Llama4TextModel
2025-05-20 10:01:00 +00:00
f022bf9322 Remove trust_remote_code=True tests from bnb quantization tests (MPT now integrated) (#38206)
bnb quant tests: remove obsolete trust_remote_code test

The MPT model is now natively integrated in Transformers and no longer requires trust_remote_code=True. This removes the failing test_get_keys_to_not_convert_trust_remote_code and related usage, which depended on remote code and caused CI issues due to missing dependencies (e.g., triton_pre_mlir).
2025-05-20 11:43:11 +02:00
0a52bd2403 [fix] sliding window attention mask (#38045)
* fix sliding attn

* make style

* Update tests/test_modeling_common.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* no a second throught, should default to `True` fo BC

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
2025-05-20 09:32:19 +00:00
555715f418 Fix broken example generation script for Llama3 (#38062)
Fix broken example generation script for llama3
2025-05-20 10:53:43 +02:00
7a611f0afd Fix: make docs work better with doc builder (#38213) 2025-05-20 08:23:03 +00:00
3bd1c20149 enable misc cases on XPU & use device agnostic APIs for cases in tests (#38192)
* use device agnostic APIs in tests

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* more

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* add reset_peak_memory_stats API

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* update

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-05-20 10:09:01 +02:00
dbc4b91db4 Qwen2.5-Omni: Update modeling_qwen2_5_omni.py to fix error when loading quantized weights with AutoAWQ. (#38013)
* Update modular_qwen2_5_omni.py

fix the error when loading quantized model by AuotAWQ.

* Update modeling_qwen2_5_omni.py

sync code to modular_qwen2_5_omni.py
2025-05-20 09:53:51 +02:00
46a4b7c909 Feat: save_pretrained for tensor parallel (and other parallelisms) models (#37919)
* tmp: initial save pretrained with dtensors

* Feat: add correctness tests

* Refactor: version checks

* Temp: 1:1 checkpoint llama4

* refactor

* Tests

* Feat: works

* Style

* Feat: version checks + minor fixes

* Style

* Fix: version checks in tests

* Feat: move more stuff into tensor_parallel.py
2025-05-19 18:16:21 +00:00
9ecee14378 [doc] fix bugs in how_to_hack_models.md (#38198)
fix several bugs
2025-05-19 10:37:54 -07:00
f524439cc5 Translating model_doc/bert.md to Chinese (#37806)
* Translated model_doc/bert.md

* Revise grammatical errors

* Changed _toctree.yml

* Revise some errors
2025-05-19 10:14:57 -07:00
6e738411e1 Tensor parallel docs (#38178)
* Feat: initial docs

* Feat: update doc

* Final typos/changes

* Refactor: reorder top to bottom.
2025-05-19 17:05:01 +00:00
9c500015c5 🚨🚨🚨 [pipelines] update defaults in pipelines that can generate (#38129)
* pipeline generation defaults

* add max_new_tokens=20 in test pipelines

* pop all kwargs that are used to parameterize generation config

* add class attr that tell us whether a pipeline calls generate

* tmp commit

* pt text gen pipeline tests passing

* remove failing tf tests

* fix text gen pipeline mixin test corner case

* update text_to_audio pipeline tests

* trigger tests

* a few more tests

* skips

* some more audio tests

* not slow

* broken

* lower severity of generation mode errors

* fix all asr pipeline tests

* nit

* skip

* image to text pipeline tests

* text2test pipeline

* last pipelines

* fix flaky

* PR comments

* handle generate attrs more carefully in models that cant generate

* same as above
2025-05-19 18:02:06 +01:00
6f9da7649f [image-text-to-text pipeline] Accept a chat as a positional arg (#38204)
accept chat as a positional arg
2025-05-19 17:26:09 +01:00
7c9b0ca08c [SAM-HQ] Update names in the docs (#38058)
Update names
2025-05-19 09:21:14 -07:00
04282a9ef5 Remove Deprecated verbose arg in LayerWiseDummyScheduler (#38197)
Remove Deprecated args in LayerWiseDummyScheduler
2025-05-19 13:49:11 +00:00
aef12349b6 Make HF implementation match original OLMo 2 models for lower precisions (#38131)
* Make HF implementation match OLMo models for lower precisions

* Add test of 1B logits in bfloat16

* Run make fixup
2025-05-19 15:35:23 +02:00
9644acb7cb [docs] add Audio import (#38195)
add Audio import
2025-05-19 13:16:35 +00:00
7d93f93f83 [docs] minor fixes in models.md (#38193)
minor gix
2025-05-19 13:14:21 +00:00
47f8578d96 Pass eps to Mistral3RMSNorm (#38026)
Pass eps to Mistral3RMSNorm

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
2025-05-19 15:09:25 +02:00
6c6302817d Resolve Python logger warnings (#38183)
* Resolve Python logger warnings

Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>

* Apply style fixes

---------

Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-05-19 12:53:07 +00:00
003deb16f1 Support for transformers explicit filename (#38152)
* Support for transformers explicit filename

* Tests

* Rerun tests
2025-05-19 14:33:47 +02:00
dbb9813dff [generation] Less verbose warnings by default (#38179)
* tmp commit (imports broken)

* working version; update tests

* remove line break

* shorter msg

* dola checks need num_beams=1; other minor PR comments

* update early trainer failing on bad gen config

* make fixup

* test msg
2025-05-19 10:03:37 +00:00
656e2eab3f Add adam_kwargs for Apollo Optimizer (#38168)
Add adam_kwargs for Apollo
2025-05-19 08:59:49 +00:00
6bb6821d93 Refactor get_XXX_dataloader from Trainer (#38090)
* Remove test_dataloader

* refactor
2025-05-19 10:43:27 +02:00
40a493c7ed [tests] remove test_sdpa_equivalence (redundant) (#37911)
* rm test_sdpa_equivalence

* make fixup

---------

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
2025-05-16 18:37:27 +01:00
ea29f61ed9 fix bug in distributed loss test (#38166)
* fix bug in distributed loss test and change some config to pass at both 2&8 gpus

* fix doc
2025-05-16 16:21:35 +00:00
a4389494c7 Fix import torchao.prototype.low_bit_optim since torchao v0.11 (#38174)
* Fix ModuleNotFoundError torchao.prototype.low_bit_optim since torchao v 0.11.0

* Fix space on blank line

* update torchao's AdamW4bit and AdamW8bit import for v0.11.0

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-05-16 18:02:33 +02:00
0ba95564b7 Add args support for fast image processors (#37018)
* add args support to fast image processors

* add comment for clarity

* fix-copies

* Handle child class args passed as both args or kwargs in call and preprocess functions

* revert support args passed as kwargs in overwritten preprocess

* fix image processor errors
2025-05-16 12:01:46 -04:00
d69945e5fc [ESM] Add flash-attention-2 backend for ESM-2 (#38023)
* Add flash-attention-2 backend for ESM-2

Signed-off-by: Peter St. John <pstjohn@nvidia.com>

* update extended_attention_mask for fa2

Signed-off-by: Peter St. John <pstjohn@nvidia.com>

* add test_flash_attn_2_equivalence test

Signed-off-by: Peter St. John <pstjohn@nvidia.com>

---------

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
2025-05-16 14:11:56 +01:00
7b5e327c6e Feat: add warnings for unused keys and rules in tensor parallel (#37893)
Feat: tensor parallel plan verification
2025-05-16 14:52:47 +02:00
120935234f remove some commands from fetch_tests CircleCI job (#38176)
delete

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-05-16 14:42:50 +02:00
91f6fa00f4 Disable convert to draft workflow (#38177)
delete

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-05-16 14:42:14 +02:00
5036ec8872 Disable Trigger CircleCI by ready for review (#38171)
delete

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-05-16 14:02:48 +02:00
7f28da2850 clean autoawq cases on xpu (#38163)
* clean autoawq cases on xpu

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
2025-05-16 13:56:43 +02:00
01ad9f4b49 Bart: new cache format (#35314)
* bart compile

* add mbart

* some more models touched by fix-copies

* more

* more models

* even more models

* fix copies

* fix tests

* fix copies

* fix

* biogpt accepts position ids now (breaking?)

* fix failing non-slow tests

* fix some tests

* should not be removed

* small update

* Update src/transformers/models/bart/modeling_bart.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* update for last `main`

* fix copies

* clone `update_causal_mask` from llama

* tmp

* fixup

* why? how?

* fix bart tests

* dont skip test

* address comments

* fix tests

* fix

* fixup and delete the file

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
2025-05-16 13:26:54 +02:00
3ab47b6ce3 [VLMs] add helpers to get multimodal encodings (#37743)
* add helpers in VLMs

* fix tests and copies

* fix blip tests

* make fix-copies

* fix copies

* fixup
2025-05-16 13:20:10 +02:00
1e921a3a9c Add optional RMSNorm support to BitNet quantization (config + layers) (#38087)
* enable optional RMS in BitLinear

* Fix naming

* Import RMS from Llama using config.*

* make fix-copies

* ran CI loop

* remove default BitNetQuantConfig values

* Fix BitNetQuantConfig to be Optional

* Fix config docstrings to match Optoinal

* Edit docstrings to match standards

---------

Co-authored-by: steinmetzc <codysteinmetz7@gmail.com>
Co-authored-by: codys12 <steinmetzc@dh-mgmt4.hpc.msoe.edu>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
2025-05-16 12:38:06 +02:00
57a79f51b2 Fix Qwen2.5 Omni SinusoidsPositionEmbedding precision (#38151)
* Fix Qwen2.5 Omni `SinusoidsPositionEmbedding` precision

fixes https://github.com/QwenLM/Qwen2.5-Omni/issues/271

* Update modular_qwen2_5_omni.py
2025-05-16 12:24:50 +02:00
44fa04ae8d Include output embedding as well with include_embedding flag (#37935)
* Include output embedding as well with `include_embedding` flag

Summary:
att

Test Plan:
python tests/quantization/torchao_integration/test_torchao.py -k test_include_embedding

Reviewers:

Subscribers:

Tasks:

Tags:

* format

* rename include_embedding to include_input_output_embeddings

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
2025-05-16 12:06:11 +02:00
34c1e29cdd enable autoround cases on XPU (#38167)
* enable autoround cases on XPU

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
2025-05-16 09:08:35 +00:00
0f77ca72ca [FIX] Save speed metrics to logs (#38136)
Previously, we calculated speed metrics and did not do anything with the result.
2025-05-15 16:58:50 +02:00
27ef46e846 Omit creation of positional IDs within ESM if applicable (#38089)
* omit pos emb creation

* rft

---------

Co-authored-by: sgottreich <sgottreich@absci.com>
2025-05-15 14:09:21 +00:00
fe9426f12d disable deepspeed when setting up fake trainer (#38101)
* disable deepspeed when setting up fake trainer

* Apply style fixes

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-05-15 15:34:04 +02:00
7caa57e85e enable trainer test cases on xpu (#38138)
* enable trainer test cases on xpu

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
2025-05-15 12:17:44 +00:00
b11b28cc4e Hotfix: Flash Attention 2 support in Pixtral (#38146)
setting attention_mask to None when flash_attention_2 is selected

Co-authored-by: aurelien.lac <aurelien.lac@lighton.ai>
2025-05-15 11:45:35 +02:00
0e0e5c1044 [generate] Run custom generation code from the Hub (#36405)
* mvp

* remove trust_remote_code

* generate_from_hub

* handle requirements; docs

* english

* doc PR suggestions

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* changed remote code path to generate/generate.py

* model repo has custom generate -> override base generate

* check for proper inheritance

* some doc updates (missing: tag-related docs)

* update docs to model repo

* nit

* nit

* nits

* Update src/transformers/dynamic_module_utils.py

* Apply suggestions from code review

* Update docs/source/en/generation_strategies.md

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* trust remote code is required

* use new import utils for requirements version parsing

* use  org examples

* add tests

* Apply suggestions from code review

Co-authored-by: Manuel de Prada Corral <6536835+manueldeprada@users.noreply.github.com>

* ascii file structure; tag instructions on readme.md

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Manuel de Prada Corral <6536835+manueldeprada@users.noreply.github.com>
2025-05-15 10:35:54 +01:00
955e61b0da Remove head mask in generative models (#35786)
* just squash into one commit

* delete print
2025-05-15 10:44:19 +02:00
0173a99e73 enable csm integration cases on xpu, all passed (#38140)
* enable csm test cases on XPU, all passed

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
2025-05-15 09:46:29 +02:00
e5a48785d9 [Qwen3] Qwen3 MoE add tp plan for expert mlps (#38135)
fix tp plan
2025-05-15 09:12:39 +02:00
4005e30c80 Fix incorrect attention mask truncate in WhisperFlashAttention2 (#36477)
* Fix incorrect attention mask truncate in whisper flash attention

* also fix incorrect attention mask truncate in qwen2 audio

* Nit attention mask truncate modeling_qwen2_audio.py

* Nit attention mask truncate modeling_whisper.py

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>

---------

Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
2025-05-14 20:08:31 +00:00
aa27fa75cd enable d_fine finetuning properly (#37962)
add pre_output in the front

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
2025-05-14 16:53:04 +01:00
e021bf6bf8 Add manueldeprada to run_slow whitelist (#38126)
Add manueldeprada to run_slow allowed users
2025-05-14 15:16:58 +02:00
ef27b2bc22 [docs] add uv installation instructions for source builds (#37968) 2025-05-14 13:09:41 +00:00
4a2decd192 Update trainer.md (#38113)
Fix typo in torch.compile method parameters
2025-05-14 12:40:00 +00:00
935bbbc711 Add config validation and style tweaks (#37589)
* Add config validation and style tweaks

* Fix style issues

* Fix style issues

* style

* Small fixes for copy/paste errors

---------

Co-authored-by: Cyrile <cyrile.delestre@arkea.com>
2025-05-14 12:22:10 +00:00
1b00966395 Fix auto batch size finder test (#38125)
Ensure --auto_find_batch_size is the last test arg so indexing is correct
2025-05-14 12:12:04 +00:00
fe918d13b9 Fix temporal padding in Qwen2VLImageProcessor when the number of frames is not divisible by temporal_patch_size (#38076)
Qwen2VL: Fix temporal padding in Qwen2VLImageProcessor when frames are not divisible by temporal_patch_size
2025-05-14 12:28:21 +02:00
aaf224d570 [video processor] fix tests (#38104)
* fix tests

* delete

* fix one more test

* fix qwen + some tests are failing irrespective of `VideoProcessor`

* delete file
2025-05-14 10:24:07 +00:00
391 changed files with 15072 additions and 7813 deletions

View File

@ -43,8 +43,6 @@ jobs:
parallelism: 1
steps:
- checkout
- run: git branch
- run: git log -n 1
- run: python3 utils/extract_pr_number_from_circleci.py > pr_number.txt
- run: echo $(cat pr_number.txt)
- run: if [[ "$(cat pr_number.txt)" == "" && "$CIRCLE_BRANCH" != "main" && "$CIRCLE_BRANCH" != *-release ]]; then echo "Not a PR, not the main branch and not a release branch, skip test!"; circleci-agent step halt; fi

View File

@ -1,25 +0,0 @@
name: Change PR to draft
on:
pull_request_target:
types: [opened, reopened]
jobs:
convert_pr_to_draft:
runs-on: ubuntu-22.04
name: Convert PR to draft
permissions:
pull-requests: write
contents: write
if: github.event.pull_request.draft == false
steps:
- name: Convert PR to draft
shell: bash
env:
PR_NUMBER: ${{ github.event.number }}
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
REPO: ${{ github.repository }}
run: |
echo $PR_NUMBER
gh pr ready $PR_NUMBER --repo $REPO --undo
gh pr comment $PR_NUMBER --repo $REPO --body "Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the \`Ready for review\` button (at the bottom of the PR page). This will assign reviewers and trigger CI."

View File

@ -39,55 +39,100 @@ jobs:
name: ci_results_run_models_gpu
path: /transformers/ci_results_run_models_gpu
- name: Check file
working-directory: /transformers
run: |
if [ -f ci_results_run_models_gpu/new_model_failures.json ]; then
echo "`ci_results_run_models_gpu/new_model_failures.json` exists, continue ..."
echo "process=true" >> $GITHUB_ENV
else
echo "`ci_results_run_models_gpu/new_model_failures.json` doesn't exist, abort."
echo "process=false" >> $GITHUB_ENV
fi
- uses: actions/download-artifact@v4
if: ${{ env.process == 'true' }}
with:
pattern: setup_values*
path: setup_values
merge-multiple: true
- name: Prepare some setup values
if: ${{ env.process == 'true' }}
run: |
if [ -f setup_values/prev_workflow_run_id.txt ]; then
echo "PREV_WORKFLOW_RUN_ID=$(cat setup_values/prev_workflow_run_id.txt)" >> $GITHUB_ENV
else
echo "PREV_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
fi
if [ -f setup_values/other_workflow_run_id.txt ]; then
echo "OTHER_WORKFLOW_RUN_ID=$(cat setup_values/other_workflow_run_id.txt)" >> $GITHUB_ENV
else
echo "OTHER_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
fi
- name: Update clone
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: git fetch && git checkout ${{ github.sha }}
- name: Get target commit
working-directory: /transformers/utils
if: ${{ env.process == 'true' }}
run: |
echo "END_SHA=$(TOKEN=${{ secrets.ACCESS_REPO_INFO_TOKEN }} python3 -c 'import os; from get_previous_daily_ci import get_last_daily_ci_run_commit; commit=get_last_daily_ci_run_commit(token=os.environ["TOKEN"]); print(commit)')" >> $GITHUB_ENV
echo "END_SHA=$(TOKEN=${{ secrets.ACCESS_REPO_INFO_TOKEN }} python3 -c 'import os; from get_previous_daily_ci import get_last_daily_ci_run_commit; commit=get_last_daily_ci_run_commit(token=os.environ["TOKEN"], workflow_run_id=os.environ["PREV_WORKFLOW_RUN_ID"]); print(commit)')" >> $GITHUB_ENV
- name: Checkout to `start_sha`
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: git fetch && git checkout ${{ inputs.start_sha }}
- name: Reinstall transformers in edit mode (remove the one installed during docker image build)
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: python3 -m pip uninstall -y transformers && python3 -m pip install -e .
- name: NVIDIA-SMI
if: ${{ env.process == 'true' }}
run: |
nvidia-smi
- name: Environment
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: |
python3 utils/print_env.py
- name: Show installed libraries and their versions
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: pip freeze
- name: Check failed tests
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: python3 utils/check_bad_commit.py --start_commit ${{ inputs.start_sha }} --end_commit ${{ env.END_SHA }} --file ci_results_run_models_gpu/new_model_failures.json --output_file new_model_failures_with_bad_commit.json
- name: Show results
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: |
ls -l new_model_failures_with_bad_commit.json
cat new_model_failures_with_bad_commit.json
- name: Checkout back
working-directory: /transformers
if: ${{ env.process == 'true' }}
run: |
git checkout ${{ inputs.start_sha }}
- name: Process report
shell: bash
working-directory: /transformers
if: ${{ env.process == 'true' }}
env:
ACCESS_REPO_INFO_TOKEN: ${{ secrets.ACCESS_REPO_INFO_TOKEN }}
TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN: ${{ secrets.TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN }}
run: |
python3 utils/process_bad_commit_report.py
@ -95,7 +140,9 @@ jobs:
- name: Process report
shell: bash
working-directory: /transformers
if: ${{ env.process == 'true' }}
env:
ACCESS_REPO_INFO_TOKEN: ${{ secrets.ACCESS_REPO_INFO_TOKEN }}
TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN: ${{ secrets.TRANSFORMERS_CI_RESULTS_UPLOAD_TOKEN }}
run: |
{
@ -105,7 +152,7 @@ jobs:
} >> "$GITHUB_ENV"
- name: Send processed report
if: ${{ !endsWith(env.REPORT_TEXT, '{}') }}
if: ${{ env.process == 'true' && !endsWith(env.REPORT_TEXT, '{}') }}
uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
with:
# Slack channel id, channel name, or user id to post message.

View File

@ -29,7 +29,7 @@ jobs:
runs-on: ubuntu-22.04
name: Get PR number
# For security: only allow team members to run
if: ${{ github.event.issue.state == 'open' && contains(fromJSON('["ydshieh", "ArthurZucker", "zucchini-nlp", "qubvel", "molbap", "gante", "LysandreJik", "Cyrilvallez", "Rocketknight1", "SunMarc", "muellerzr", "eustlb", "MekkCyber"]'), github.actor) && (startsWith(github.event.comment.body, 'run-slow') || startsWith(github.event.comment.body, 'run slow') || startsWith(github.event.comment.body, 'run_slow')) }}
if: ${{ github.event.issue.state == 'open' && contains(fromJSON('["ydshieh", "ArthurZucker", "zucchini-nlp", "qubvel", "molbap", "gante", "LysandreJik", "Cyrilvallez", "Rocketknight1", "SunMarc", "muellerzr", "eustlb", "MekkCyber", "manueldeprada"]'), github.actor) && (startsWith(github.event.comment.body, 'run-slow') || startsWith(github.event.comment.body, 'run slow') || startsWith(github.event.comment.body, 'run_slow')) }}
outputs:
PR_NUMBER: ${{ steps.set_pr_number.outputs.PR_NUMBER }}
steps:

View File

@ -8,8 +8,43 @@ on:
push:
branches:
- run_scheduled_ci*
workflow_dispatch:
inputs:
prev_workflow_run_id:
description: 'previous workflow run id to compare'
type: string
required: false
default: ""
other_workflow_run_id:
description: 'other workflow run id to compare'
type: string
required: false
default: ""
# Used for `push` to easily modify the target workflow runs to compare against
env:
prev_workflow_run_id: ""
other_workflow_run_id: ""
jobs:
setup:
name: Setup
runs-on: ubuntu-22.04
steps:
- name: Setup
run: |
mkdir "setup_values"
echo "${{ inputs.prev_workflow_run_id || env.prev_workflow_run_id }}" > "setup_values/prev_workflow_run_id.txt"
echo "${{ inputs.other_workflow_run_id || env.other_workflow_run_id }}" > "setup_values/other_workflow_run_id.txt"
- name: Upload artifacts
uses: actions/upload-artifact@v4
with:
name: setup_values
path: setup_values
model-ci:
name: Model CI
uses: ./.github/workflows/self-scheduled.yml

View File

@ -39,6 +39,21 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/download-artifact@v4
- name: Prepare some setup values
run: |
if [ -f setup_values/prev_workflow_run_id.txt ]; then
echo "PREV_WORKFLOW_RUN_ID=$(cat setup_values/prev_workflow_run_id.txt)" >> $GITHUB_ENV
else
echo "PREV_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
fi
if [ -f setup_values/other_workflow_run_id.txt ]; then
echo "OTHER_WORKFLOW_RUN_ID=$(cat setup_values/other_workflow_run_id.txt)" >> $GITHUB_ENV
else
echo "OTHER_WORKFLOW_RUN_ID=" >> $GITHUB_ENV
fi
- name: Send message to Slack
if: ${{ inputs.job != 'run_quantization_torch_gpu' }}
env:
@ -50,7 +65,6 @@ jobs:
ACCESS_REPO_INFO_TOKEN: ${{ secrets.ACCESS_REPO_INFO_TOKEN }}
CI_EVENT: ${{ inputs.ci_event }}
CI_SHA: ${{ github.sha }}
CI_WORKFLOW_REF: ${{ github.workflow_ref }}
CI_TEST_JOB: ${{ inputs.job }}
SETUP_STATUS: ${{ inputs.setup_status }}
# We pass `needs.setup.outputs.matrix` as the argument. A processing in `notification_service.py` to change
@ -58,7 +72,6 @@ jobs:
# For a job that doesn't depend on (i.e. `needs`) `setup`, the value for `inputs.folder_slices` would be an
# empty string, and the called script still gets one argument (which is the empty string).
run: |
sudo apt-get install -y curl
pip install huggingface_hub
pip install slack_sdk
pip show slack_sdk
@ -86,7 +99,6 @@ jobs:
# We pass `needs.setup.outputs.quantization_matrix` as the argument. A processing in `notification_service_quantization.py` to change
# `quantization/bnb` to `quantization_bnb` is required, as the artifact names use `_` instead of `/`.
run: |
sudo apt-get install -y curl
pip install huggingface_hub
pip install slack_sdk
pip show slack_sdk

View File

@ -1,20 +0,0 @@
name: Trigger CircleCI
on:
pull_request_target:
types: [ready_for_review]
jobs:
trigger-circleci:
runs-on: ubuntu-22.04
steps:
- name: trigger CircleCI pipeline via GitHub Actions
uses: CircleCI-Public/trigger-circleci-pipeline-action@v1.2.0
with:
GHA_Meta: "Trigger via GitHub Actions"
target-slug: "github/huggingface/transformers"
target-branch: "pull/${{ github.event.number }}/head"
env:
CCI_TOKEN: ${{ secrets.CIRCLECI_PAT }}

View File

@ -98,7 +98,12 @@ Install Transformers from source if you want the latest changes in the library o
```shell
git clone https://github.com/huggingface/transformers.git
cd transformers
# pip
pip install .[torch]
# uv
uv pip install .[torch]
```
## Quickstart

View File

@ -1,54 +0,0 @@
version: '3'
services:
memcached:
image: memcached:1.6.29
container_name: memcached
ports:
- "11211:11211"
environment:
- MEMCACHED_MAX_MEMORY=64m # Set the maximum memory usage
- MEMCACHED_THREADS=4 # Number of threads to use
prometheus:
image: prom/prometheus:latest
command:
- "--config.file=/etc/prometheus/prometheus.yml"
- --web.enable-otlp-receiver # Enable OTLP receiver
- --web.enable-remote-write-receiver
- --enable-feature=exemplar-storage
- --enable-feature=native-histograms
volumes:
- ./prometheus.yml:/etc/prometheus/prometheus.yml
ports:
- "9090:9090"
tempo:
image: grafana/tempo:latest
command: [ "-config.file=/etc/tempo.yaml" ]
volumes:
- ./tempo.yaml:/etc/tempo.yaml
ports:
- "14268:14268" # jaeger ingest
- "3200:3200" # tempo
- "9095:9095" # tempo grpc
- "4317:4317" # otlp grpc
- "4318:4318" # otlp http
- "9411:9411" # zipkin
depends_on:
- memcached
grafana:
image: grafana/grafana:latest
volumes:
- ./grafana-datasources.yaml:/etc/grafana/provisioning/datasources/datasources.yaml
environment:
- GF_AUTH_ANONYMOUS_ENABLED=true
- GF_AUTH_ANONYMOUS_ORG_ROLE=Admin
- GF_AUTH_DISABLE_LOGIN_FORM=true
- GF_FEATURE_TOGGLES_ENABLE=traceqlEditor metricsSummary
- GF_INSTALL_PLUGINS=https://storage.googleapis.com/integration-artifacts/grafana-exploretraces-app/grafana-exploretraces-app-latest.zip;grafana-traces-app
ports:
- "3000:3000"
depends_on:
- prometheus
- tempo

View File

@ -455,6 +455,8 @@
title: Falcon
- local: model_doc/falcon3
title: Falcon3
- local: model_doc/falcon_h1
title: FalconH1
- local: model_doc/falcon_mamba
title: FalconMamba
- local: model_doc/flan-t5

View File

@ -20,11 +20,15 @@ A decoding strategy informs how a model should select the next generated token.
This guide will help you understand the different decoding strategies available in Transformers and how and when to use them.
## Greedy search
## Basic decoding methods
Greedy search is the default decoding strategy. It selects the next most likely token at each step. Unless specified in [`GenerationConfig`], this strategy generates a maximum of 20 tokens.
These are well established decoding methods, and should be your starting point for text generation tasks.
Greedy search works well for tasks with relatively short outputs. However, it breaks down when generating longer sequences because it begins to repeat itself.
### Greedy search
Greedy search is the default decoding strategy. It selects the next most likely token at each step. Unless specified in [`GenerationConfig`], this strategy generates a maximum of 20 new tokens.
Greedy search works well for tasks with relatively short outputs where creativity is not a priority. However, it breaks down when generating longer sequences because it begins to repeat itself.
```py
import torch
@ -40,11 +44,11 @@ tokenizer.batch_decode(outputs, skip_special_tokens=True)
'Hugging Face is an open-source company that provides a suite of tools and services for building, deploying, and maintaining natural language processing'
```
## Contrastive search
### Sampling
[Contrastive search](https://huggingface.co/papers/2202.06417) is a decoding strategy that aims to reduce repetition even while generating longer sequences. This strategy compares how similar a generated token is against previous tokens, and if they're more similar, a penalty is applied.
Sampling, or multinomial sampling, randomly selects a token based on the probability distribution over the entire model's vocabulary (as opposed to the most likely token, as in greedy search). This means every token with a non-zero probability has a chance to be selected. Sampling strategies reduce repetition and can generate more creative and diverse outputs.
Enable contrastive search with the `penalty_alpha` and `top_k` parameters. The `penalty_alpha` manages the penalty applied and `top_k` is the number of most likely tokens to return.
Enable multinomial sampling with `do_sample=True` and `num_beams=1`.
```py
import torch
@ -55,14 +59,14 @@ inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt"
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to("cuda")
# explicitly set to 100 because Llama2 generation length is 4096
outputs = model.generate(**inputs, max_new_tokens=100, penalty_alpha=0.6, top_k=4)
outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, num_beams=1)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
'Hugging Face is an open-source company that provides a platform for building and deploying AI models.\nHugging Face is an open-source company that provides a platform for building and deploying AI models. The platform allows developers to build and deploy AI models, as well as collaborate with other developers.\nHugging Face was founded in 2019 by Thibault Wittemberg and Clément Delangue. The company is based in Paris, France.\nHugging Face has'
'Hugging Face is an open-source company 🤗\nWe are open-source and believe that open-source is the best way to build technology. Our mission is to make AI accessible to everyone, and we believe that open-source is the best way to achieve that.'
```
## Beam search
### Beam search
Beam search keeps track of several generated sequences (beams) at each time step. After a certain number of steps, it selects the sequence with the highest *overall* probability. Unlike greedy search, this strategy can "look ahead" and pick a sequence with a higher probability overall even if the initial tokens have a lower probability.
Beam search keeps track of several generated sequences (beams) at each time step. After a certain number of steps, it selects the sequence with the highest *overall* probability. Unlike greedy search, this strategy can "look ahead" and pick a sequence with a higher probability overall even if the initial tokens have a lower probability. It is best suited for input-grounded tasks, like describing an image or speech recognition. You can also use `do_sample=True` with beam search to sample at each step, but beam search will still greedily prune out low probability sequences between steps.
> [!TIP]
> Check out the [beam search visualizer](https://huggingface.co/spaces/m-ric/beam_search_visualizer) to see how beam search works.
@ -83,66 +87,11 @@ tokenizer.batch_decode(outputs, skip_special_tokens=True)
"['Hugging Face is an open-source company that develops and maintains the Hugging Face platform, which is a collection of tools and libraries for building and deploying natural language processing (NLP) models. Hugging Face was founded in 2018 by Thomas Wolf']"
```
## Diverse beam search
## Advanced decoding methods
[Diverse beam search](https://hf.co/papers/1610.02424) is a variant of beam search that produces more diverse output candidates to choose from. This strategy measures the dissimilarity of sequences and a penalty is applied if sequences are too similar. To avoid high computation costs, the number of beams is divided into groups.
Advanced decoding methods aim at either tackling specific generation quality issues (e.g. repetition) or at improving the generation throughput in certain situations. These techniques are more complex, and may not work correctly with all models.
Enable diverse beam search with the `num_beams`, `num_beam_groups` and `diversity_penalty` parameters (the `num_beams` parameter should be divisible by `num_beam_groups`).
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to("cuda")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to("cuda")
# explicitly set to 100 because Llama2 generation length is 4096
outputs = model.generate(**inputs, max_new_tokens=50, num_beams=6, num_beam_groups=3, diversity_penalty=1.0, do_sample=False)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
'Hugging Face is an open-source company 🤗\nWe are an open-source company. Our mission is to democratize AI and make it accessible to everyone. We believe that AI should be used for the benefit of humanity, not for the benefit of a'
```
## Multinomial sampling
Search methods selects the most likely tokens. Sampling, or multinomial sampling, randomly selects a token based on the probability distribution over the entire models vocabulary. This means every token with a non-zero probability has a chance to be selected. Sampling strategies reduce repetition and can generate more creative and diverse outputs.
Enable multinomial sampling with `do_sample=True` and `num_beams=1`.
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to("cuda")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to("cuda")
# explicitly set to 100 because Llama2 generation length is 4096
outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, num_beams=1)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
'Hugging Face is an open-source company 🤗\nWe are open-source and believe that open-source is the best way to build technology. Our mission is to make AI accessible to everyone, and we believe that open-source is the best way to achieve that.'
```
## Beam search multinomial sampling
This decoding strategy is a combination of beam search and multinomial sampling. It generates multiple beams and uses a sampling strategy for each beam.
Enable beam search multinomial sampling by setting `num_beams` to a value greater than 1 and `do_sample=True`.
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to("cuda")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to("cuda")
# explicitly set to 100 because Llama2 generation length is 4096
outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, num_beams=4)
'Hugging Face is an open-source company 100% dedicated to making AI more accessible. We believe that AI should be available to everyone, and were working hard to make that a reality.\nWere a team of passionate engineers, designers,'
```
## Speculative decoding
### Speculative decoding
[Speculative](https://hf.co/papers/2211.17192) or assistive decoding isn't a search or sampling strategy. Instead, speculative decoding adds a second smaller model to generate candidate tokens. The main model verifies the candidate tokens in a single `forward` pass, which speeds up the decoding process overall. This method is especially useful for LLMs where it can be more costly and slower to generate tokens. Refer to the [speculative decoding](./llm_optims#speculative-decoding) guide to learn more.
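A minimal sketch of assisted generation, assuming a main checkpoint and a much smaller draft checkpoint that share a tokenizer (the OPT pair below is only an example):
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to("cuda")

model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b", torch_dtype=torch.float16).to("cuda")
# the smaller draft model proposes candidate tokens; the main model verifies them in a single forward pass
assistant_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.float16).to("cuda")

outputs = model.generate(**inputs, assistant_model=assistant_model, max_new_tokens=50)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
```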
@ -203,7 +152,7 @@ tokenizer.batch_decode(outputs, skip_special_tokens=True)
</hfoption>
</hfoptions>
### Prompt lookup decoding
#### Prompt lookup decoding
[Prompt lookup decoding](./llm_optims#prompt-lookup-decoding) is a variant of speculative decoding that uses overlapping n-grams as the candidate tokens. It works well for input-grounded tasks such as summarization. Refer to the [prompt lookup decoding](./llm_optims#prompt-lookup-decoding) guide to learn more.
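For illustration, a short snippet reusing the setup from the examples above; `prompt_lookup_num_tokens` sets the n-gram size used to draft candidate tokens from the prompt itself:
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to("cuda")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to("cuda")

# candidates are drawn from n-grams already present in the prompt, so no draft model is needed;
# the speedup is largest for input-grounded tasks with long prompts (e.g. summarization)
outputs = model.generate(**inputs, prompt_lookup_num_tokens=3, max_new_tokens=50)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
```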
@ -245,7 +194,7 @@ outputs = model.generate(**inputs, assistant_early_exit=4, do_sample=False, max_
tokenizer.batch_decode(outputs, skip_special_tokens=True)
```
### Universal assisted decoding
#### Universal assisted decoding
Universal assisted decoding (UAD) enables the main and assistant models to use different tokenizers. The main model's input tokens are re-encoded into assistant model tokens. Candidate tokens are generated in the assistant encoding and then re-encoded into main model candidate tokens. The candidate tokens are verified as explained in [speculative decoding](#speculative-decoding).
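A rough sketch, assuming a main model and an assistant with a different tokenizer (the checkpoints below are placeholders); both tokenizers are passed to `generate` so candidate tokens can be re-encoded between the two vocabularies:
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to("cuda")

# assistant from a different model family, with its own tokenizer
assistant_tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
assistant_model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2").to("cuda")

inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to("cuda")
outputs = model.generate(
    **inputs,
    assistant_model=assistant_model,
    tokenizer=tokenizer,
    assistant_tokenizer=assistant_tokenizer,
    max_new_tokens=20,
)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
```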
@ -269,7 +218,27 @@ tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```
## DoLa
### Contrastive search
[Contrastive search](https://huggingface.co/papers/2202.06417) is a decoding strategy that aims to reduce repetition even while generating longer sequences. This strategy compares how similar a generated token is against previous tokens, and if they're more similar, a penalty is applied.
Enable contrastive search with the `penalty_alpha` and `top_k` parameters. The `penalty_alpha` manages the penalty applied and `top_k` is the number of most likely tokens to return.
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to("cuda")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to("cuda")
# explicitly set to 100 because Llama2 generation length is 4096
outputs = model.generate(**inputs, max_new_tokens=100, penalty_alpha=0.6, top_k=4)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
'Hugging Face is an open-source company that provides a platform for building and deploying AI models.\nHugging Face is an open-source company that provides a platform for building and deploying AI models. The platform allows developers to build and deploy AI models, as well as collaborate with other developers.\nHugging Face was founded in 2019 by Thibault Wittemberg and Clément Delangue. The company is based in Paris, France.\nHugging Face has'
```
### DoLa
[Decoding by Contrasting Layers (DoLa)](https://hf.co/papers/2309.03883) is a contrastive decoding strategy for improving factuality and reducing hallucination. This strategy works by contrasting the logit differences between the final and early layers. As a result, factual knowledge localized to particular layers is amplified. DoLa is not recommended for smaller models like GPT-2.
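A minimal sketch, reusing the setup from the examples above; `dola_layers` selects which layers are contrasted against the final layer ("low", "high", or explicit layer indices), and a repetition penalty is commonly added:
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tokenizer("What is the highest peak in the world?", return_tensors="pt").to("cuda")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to("cuda")

# contrast the final layer against the higher layers and penalize repetition
outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False, dola_layers="high", repetition_penalty=1.2)
tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True)
```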
@ -325,6 +294,210 @@ tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[-1]:], skip_special_tok
</hfoption>
</hfoptions>
### Diverse beam search
[Diverse beam search](https://hf.co/papers/1610.02424) is a variant of beam search that produces more diverse output candidates to choose from. This strategy measures the dissimilarity of sequences and a penalty is applied if sequences are too similar. To avoid high computation costs, the number of beams is divided into groups.
Enable diverse beam search with the `num_beams`, `num_beam_groups` and `diversity_penalty` parameters (the `num_beams` parameter should be divisible by `num_beam_groups`).
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to("cuda")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to("cuda")
# explicitly set to 100 because Llama2 generation length is 4096
outputs = model.generate(**inputs, max_new_tokens=50, num_beams=6, num_beam_groups=3, diversity_penalty=1.0, do_sample=False)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
'Hugging Face is an open-source company 🤗\nWe are an open-source company. Our mission is to democratize AI and make it accessible to everyone. We believe that AI should be used for the benefit of humanity, not for the benefit of a'
```
## Custom decoding methods
Custom decoding methods enable specialized generation behavior such as the following:
- have the model continue thinking if it is uncertain;
- roll back generation if the model gets stuck;
- handle special tokens with custom logic;
- enhance input preparation for advanced models.
We enable custom decoding methods through model repositories, assuming a specific model tag and file structure (see subsection below). This feature is an extension of [custom modeling code](./models.md#custom-models) and, as such, requires setting `trust_remote_code=True`.
If a model repository holds a custom decoding method, the easiest way to try it out is to load the model and generate with it:
<!-- TODO before merging: 1) better repo name (use a `generate-community` org?) 2) prettify the repo -->
```py
from transformers import AutoModelForCausalLM, AutoTokenizer
# `transformers-community/custom_generate_example` holds a copy of `Qwen/Qwen2.5-0.5B-Instruct`, but
# with custom generation code -> calling `generate` uses the custom decoding method!
tokenizer = AutoTokenizer.from_pretrained("transformers-community/custom_generate_example")
model = AutoModelForCausalLM.from_pretrained(
"transformers-community/custom_generate_example", device_map="auto", trust_remote_code=True
)
inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)
# The custom decoding method is a minimal greedy decoding implementation. It also prints a custom message at run time.
gen_out = model.generate(**inputs)
# you should now see its custom message, "✨ using a custom generation method ✨"
print(tokenizer.batch_decode(gen_out, skip_special_tokens=True))
'The quick brown fox jumps over a lazy dog, and the dog is a type of animal. Is'
```
Model repositories with custom decoding methods have a special property: their decoding method can be loaded from **any** model through [`~GenerationMixin.generate`]'s `custom_generate` argument. This means anyone can create and share their custom generation method to potentially work with any Transformers model, without requiring users to install additional Python packages.
```py
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", device_map="auto")
inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)
# `custom_generate` replaces the original `generate` by the custom decoding method defined in
# `transformers-community/custom_generate_example`
gen_out = model.generate(**inputs, custom_generate="transformers-community/custom_generate_example", trust_remote_code=True)
print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
'The quick brown fox jumps over a lazy dog, and the dog is a type of animal. Is'
```
You should read the `README.md` file of the repository containing the custom generation strategy to see what the new arguments and output type differences are, if they exist. Otherwise, you can assume it works like the base [`~GenerationMixin.generate`] method.
> [!TIP]
> You can find all custom decoding methods by [searching for their custom tag](https://huggingface.co/models?other=custom_generate), `custom_generate`.
Consider the Hub repository [transformers-community/custom_generate_example](https://huggingface.co/transformers-community/custom_generate_example) as an example. The `README.md` states that it has an additional input argument, `left_padding`, which adds a number of padding tokens before the prompt.
```py
gen_out = model.generate(
**inputs, custom_generate="transformers-community/custom_generate_example", trust_remote_code=True, left_padding=5
)
print(tokenizer.batch_decode(gen_out)[0])
'<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>The quick brown fox jumps over the lazy dog.\n\nThe sentence "The quick'
```
If the custom method has pinned Python requirements that your environment doesn't meet, you'll get an exception about missing requirements. For instance, [transformers-community/custom_generate_bad_requirements](https://huggingface.co/transformers-community/custom_generate_bad_requirements) has an impossible set of requirements defined in its `custom_generate/requirements.txt` file, and you'll see the error message below if you try to run it.
```
ImportError: Missing requirements in your local environment for `transformers-community/custom_generate_bad_requirements`:
foo (installed: None)
bar==0.0.0 (installed: None)
torch>=99.0 (installed: 2.6.0)
```
Updating your Python requirements accordingly will remove this error message.
### Creating a custom decoding method
To create a new decoding method, you need to create a new [**Model**](https://huggingface.co/new) repository and push a few files into it.
1. The model you've designed your decoding method with.
2. `custom_generate/generate.py`, which contains all the logic for your custom decoding method.
3. `custom_generate/requirements.txt`, used to optionally add new Python requirements and/or lock specific versions to correctly use your method.
4. `README.md`, where you should add the `custom_generate` tag and document any new arguments or output type differences of your custom method here.
After you've added all required files, your repository should look like this
```
your_repo/
├── README.md # include the 'custom_generate' tag
├── config.json
├── ...
└── custom_generate/
├── generate.py
└── requirements.txt
```
#### Adding the base model
The starting point for your custom decoding method is a model repository just like any other. The model to add to this repository should be the model you've designed your method with, and it is meant to be part of a working self-contained model-generate pair. When the model in this repository is loaded, your custom decoding method will override `generate`. Don't worry -- your decoding method can still be loaded with any other Transformers model, as explained in the section above.
If you simply want to copy an existing model, you can do
```py
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("source/model_repo")
model = AutoModelForCausalLM.from_pretrained("source/model_repo")
tokenizer.save_pretrained("your/decoding_method", push_to_hub=True)
model.save_pretrained("your/decoding_method", push_to_hub=True)
```
#### generate.py
This is the core of your decoding method. It *must* contain a method named `generate`, and this method *must* contain a `model` argument as its first argument. `model` is the model instance, which means you have access to all attributes and methods in the model, including the ones defined in [`GenerationMixin`] (like the base `generate` method).
> [!WARNING]
> `generate.py` must be placed in a folder named `custom_generate`, and not at the root level of the repository. The file paths for this feature are hardcoded.
Under the hood, when the base [`~GenerationMixin.generate`] method is called with a `custom_generate` argument, it first checks its Python requirements (if any), then locates the custom `generate` method in `generate.py`, and finally calls the custom `generate`. All received arguments and `model` are forwarded to your custom `generate` method.
This means your `generate` can have a mix of original and custom arguments (as well as a different output type) as shown below.
```py
import torch
def generate(model, input_ids, generation_config=None, left_padding=None, **kwargs):
generation_config = generation_config or model.generation_config # default to the model generation config
cur_length = input_ids.shape[1]
max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
# Example of custom argument: add `left_padding` (integer) pad tokens before the prompt
if left_padding is not None:
if not isinstance(left_padding, int) or left_padding < 0:
raise ValueError(f"left_padding must be an integer larger than 0, but is {left_padding}")
pad_token = kwargs.pop("pad_token", None) or generation_config.pad_token_id or model.config.pad_token_id
if pad_token is None:
raise ValueError("pad_token is not defined")
batch_size = input_ids.shape[0]
pad_tensor = torch.full(size=(batch_size, left_padding), fill_value=pad_token).to(input_ids.device)
input_ids = torch.cat((pad_tensor, input_ids), dim=1)
cur_length = input_ids.shape[1]
# Simple greedy decoding loop
while cur_length < max_length:
logits = model(input_ids).logits
next_token_logits = logits[:, -1, :]
next_tokens = torch.argmax(next_token_logits, dim=-1)
input_ids = torch.cat((input_ids, next_tokens[:, None]), dim=-1)
cur_length += 1
return input_ids
```
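With this `generate.py` pushed to the Hub, callers route the base `generate` through it. Below is a minimal sketch, assuming the method lives in a placeholder repository named `your/decoding_method` (the repository name and `left_padding` value are illustrative):
```py
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your/decoding_method")
model = AutoModelForCausalLM.from_pretrained("your/decoding_method")

inputs = tokenizer(["The quick brown fox"], return_tensors="pt")
gen_out = model.generate(
    **inputs,
    custom_generate="your/decoding_method",  # loads custom_generate/generate.py from that repo
    trust_remote_code=True,
    left_padding=2,  # the custom argument defined in the example above
)
print(tokenizer.batch_decode(gen_out)[0])
```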
Follow the recommended practices below to ensure your custom decoding method works as expected.
- Feel free to reuse the logic for validation and input preparation in the original [`~GenerationMixin.generate`].
- Pin the `transformers` version in the requirements if you use any private method/attribute in `model`.
- You can add other files in the `custom_generate` folder, and use relative imports.
- Consider adding model validation, input validation, or even a separate test file to help users sanity-check your code in their environment, as in the sketch below.
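For instance, a lightweight validation helper inside `custom_generate/generate.py` could look like this (the specific checks are illustrative, not part of the required interface):
```py
def validate_inputs(model, input_ids):
    # Illustrative sanity checks -- adapt them to whatever your method actually assumes.
    if input_ids.dim() != 2:
        raise ValueError(f"`input_ids` must have shape (batch_size, seq_len), got {tuple(input_ids.shape)}")
    if model.generation_config.pad_token_id is None and model.config.pad_token_id is None:
        raise ValueError("This decoding method requires a defined `pad_token_id`.")
```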
#### requirements.txt
You can optionally specify additional Python requirements in a `requirements.txt` file inside the `custom_generate` folder. These are checked at runtime and an exception will be thrown if they're missing, nudging users to update their environment accordingly.
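For example, a minimal `custom_generate/requirements.txt` might look like the snippet below (the package names and version pins are purely illustrative; only list what your method actually needs):
```
torch>=2.0
numpy>=1.26
```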
#### README.md
The root level `README.md` in the model repository usually describes the model therein. However, since the focus of the repository is the custom decoding method, we highly recommend shifting its focus towards describing the custom decoding method. In addition to a description of the method, we recommend documenting any input and/or output differences relative to the original [`~GenerationMixin.generate`]. This way, users can focus on what's new, and rely on the Transformers docs for generic implementation details.
For discoverability, we highly recommend adding the `custom_generate` tag to your repository. To do so, the top of your `README.md` file should look like the example below. After you push the file, you should see the tag in your repository!
```
---
library_name: transformers
tags:
- custom_generate
---
(your markdown content here)
```
Recommended practices:
- Document input and output differences in [`~GenerationMixin.generate`].
- Add self-contained examples to enable quick experimentation.
- Describe soft requirements, such as whether the method only works well with a certain family of models.
## Resources
Read the [How to generate text: using different decoding methods for language generation with Transformers](https://huggingface.co/blog/how-to-generate) blog post for an explanation of how common decoding strategies work.

View File

@ -90,11 +90,6 @@ class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
attn_weights = (query * self.scale) @ key.transpose(-2, -1)
if self.use_rel_pos:
attn_weights = self.add_decomposed_rel_pos(
attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
)
attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
@ -114,13 +109,14 @@ Load the model with [`~PreTrainedModel.from_pretrained`].
```py
from transformers import SamModel
from transformers.models.sam import modeling_sam
# replace the attention class in the modeling_sam module
modeling_sam.SamVisionAttention = SamVisionAttentionSplit
# load the pretrained SAM model
model = SamModel.from_pretrained("facebook/sam-vit-base")
# replace the attention class in the vision_encoder module
for layer in model.vision_encoder.layers:
if hasattr(layer, "attn"):
layer.attn = SamVisionAttentionSplit(model.config.vision_config, model.config.vision_config.window_size)
```
## LoRA
@ -138,7 +134,7 @@ config = LoraConfig(
# apply LoRA to q and v
target_modules=["q", "v"],
lora_dropout=0.1,
task_type="mask-generation"
task_type="FEATURE_EXTRACTION"
)
```
@ -152,5 +148,5 @@ Call [print_trainable_parameters](https://huggingface.co/docs/peft/package_refer
```py
model.print_trainable_parameters()
"trainable params: 608,256 || all params: 94,343,728 || trainable%: 0.6447"
"trainable params: 589,824 || all params: 94,274,096 || trainable%: 0.6256"
```

View File

@ -57,6 +57,7 @@ This model was contributed by [lysandre](https://huggingface.co/lysandre). This
- Embedding size E is different from hidden size H. This is justified because the embeddings are context independent (one embedding vector represents one token), whereas hidden states are context dependent (one hidden state represents a sequence of tokens), so it's more logical to have H >> E. Also, the embedding matrix is large since it's V x E (V being the vocab size); if E < H, it has fewer parameters.
- Layers are split in groups that share parameters (to save memory).
- Next sentence prediction is replaced by a sentence ordering prediction: in the inputs, we have two consecutive sentences A and B, and we either feed A followed by B or B followed by A. The model must predict whether they have been swapped.
- The `head_mask` argument is ignored when using any attention implementation other than "eager". If you have a `head_mask` and want it to have an effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`.
### Using Scaled Dot Product Attention (SDPA)

View File

@ -39,7 +39,7 @@ Checkout all Bamba-9B model checkpoints [here](https://github.com/foundation-mod
<!---
## Usage Tips
Tips:
Tips:
- The architecture is based on Mamba-2 models.
@ -63,7 +63,35 @@ response = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
```
## Padding-Free Training
Bamba supports padding-free training in which distinct training examples can be concatenated
together while nevertheless processing the inputs as though they belonged to separate batches. When
the examples are of varying lengths, padding-free training can provide significant speed ups and
memory savings compared to batching the examples together and using padding, as the unnecessary
compute and memory due to padding is avoided entirely. The performance gains depend on factors such
as the model and the data distribution, but throughput gains up to [~2x are commonly
seen](https://github.com/huggingface/transformers/pull/35861#issue-2807873129).
Using padding-free training with Bamba requires the `flash-attn`, `mamba-ssm`, and `causal-conv1d`
packages, and the following arguments must be passed to the model in addition to `input_ids` and
`labels`:
* `position_ids: torch.LongTensor`: the position index of each token in each sequence.
* `seq_idx: torch.IntTensor`: the index of each sequence in the batch.
* Each of the [`FlashAttentionKwargs`]
* `cu_seq_lens_q: torch.LongTensor`: The cumulative sequence lengths of all queries.
* `cu_seq_lens_k: torch.LongTensor`: The cumulative sequence lengths of all keys.
* `max_length_q: int`: the longest query length in the batch.
* `max_length_k: int`: the longest key length in the batch.
The `attention_mask` inputs should not be provided. The [`DataCollatorWithFlattening`] can be used
to programmatically generate the above set of additional arguments using `return_seq_idx=True` and
`return_flash_attn_kwargs=True`. See [this blog post](https://huggingface.co/blog/packing-with-FA2)
for additional information.
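As a rough sketch (assuming toy pre-tokenized examples; adapt to your own dataset), the collator can be used like this to build a padding-free batch:
```python
from transformers import DataCollatorWithFlattening

# Flatten examples into one packed batch and emit seq_idx plus the FlashAttention kwargs.
collator = DataCollatorWithFlattening(return_seq_idx=True, return_flash_attn_kwargs=True)

features = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]  # toy pre-tokenized examples
batch = collator(features)

# The batch should now carry the extra arguments listed above (position_ids, seq_idx,
# cu_seq_lens_q/k, max_length_q/k) alongside input_ids and labels -- no attention_mask.
print(batch.keys())
```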
[[autodoc]] BambaForCausalLM
- forward
This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).
This HF implementation is contributed by [ani300](https://github.com/ani300) and [fabianlim](https://github.com/fabianlim).

View File

@ -55,6 +55,7 @@ This model was contributed by [sshleifer](https://huggingface.co/sshleifer). The
* mask a span of k tokens with a single mask token (a span of 0 tokens is an insertion of a mask token)
* permute sentences
* rotate the document to make it start at a specific token
- The `head_mask` argument is ignored when using any attention implementation other than "eager". If you have a `head_mask` and want it to have an effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`.
## Implementation Notes

View File

@ -36,6 +36,7 @@ This model was contributed by [kamalkraj](https://huggingface.co/kamalkraj). The
- BioGPT is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than the left.
- BioGPT was trained with a causal language modeling (CLM) objective and is therefore powerful at predicting the next token in a sequence. Leveraging this feature allows BioGPT to generate syntactically coherent text as it can be observed in the run_generation.py example script.
- The model can take the `past_key_values` (for PyTorch) as input, which is the previously computed key/value attention pairs. Using this (past_key_values or past) value prevents the model from re-computing pre-computed values in the context of text generation. For PyTorch, see past_key_values argument of the BioGptForCausalLM.forward() method for more information on its usage.
- The `head_mask` argument is ignored when using any attention implementation other than "eager". If you have a `head_mask` and want it to have an effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`.
### Using Scaled Dot Product Attention (SDPA)

View File

@ -53,6 +53,7 @@ The original code for vision can be found [here](https://github.com/facebookrese
- For Data2VecAudio, preprocessing is identical to [`Wav2Vec2Model`], including feature extraction
- For Data2VecText, preprocessing is identical to [`RobertaModel`], including tokenization.
- For Data2VecVision, preprocessing is identical to [`BeitModel`], including feature extraction.
- The `head_mask` argument is ignored when using any attention implementation other than "eager". If you have a `head_mask` and want it to have an effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`.
### Using Scaled Dot Product Attention (SDPA)

View File

@ -0,0 +1,65 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# FalconH1
## Overview
The FalconH1 model was developed by the TII Pretraining team. A comprehensive research paper covering the architecture, pretraining dynamics, experimental results, and conclusions is forthcoming. You can read more about this series of models on [this website](https://github.com/tiiuae/Falcon-H1).
## Contributors
This model was contributed by [DhiyaEddine](https://huggingface.co/DhiyaEddine), [ybelkada](https://huggingface.co/ybelkada), [JingweiZuo](https://huggingface.co/JingweiZuo), [IlyasChahed](https://huggingface.co/IChahed), and [MaksimVelikanov](https://huggingface.co/yellowvm).
The original code can be found [here](https://github.com/tiiuae/Falcon-H1).
## FalconH1Config
| Model | Depth | Dim | Attn Heads | KV | Mamba Heads | d_head | d_state | Ctx Len |
|-----------|--------|------|------------|----|--------------|--------------|------|-----------------|
| H1 0.5B | 36 | 1024 | 8 | 2 | 24 | 64 / 64 | 128 | 4K, 16K-SFT |
| H1 1.5B | 24 | 2048 | 8 | 2 | 48 | 128 / 64 | 256 | 128K |
| H1 1.5B-d | 66 | 1280 | 6 | 2 | 24 | 128 / 64 | 256 | 128K |
| H1 3B | 32 | 2560 | 10 | 2 | 32 | 128 / 128 | 256 | 128K |
| H1 7B | 44 | 3072 | 12 | 2 | 24 | 128 / 128 | 256 | 256K |
| H1 34B | 72 | 5120 | 20 | 4 | 32 | 128 / 128 | 256 | 256K |
[[autodoc]] FalconH1Config
<!---
## Usage Tips
Tips:
- The architecture is based on Mamba-2 models.
## FalconH1Model
[[autodoc]] FalconH1Model
- forward
-->
## FalconH1ForCausalLM
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("tiiuae/Falcon-H1-7B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("tiiuae/Falcon-H1-7B-Instruct")
message = ["Mamba is a snake with following properties "]
inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False)
response = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
```
[[autodoc]] FalconH1ForCausalLM
- forward
This HF implementation is contributed by [younesbelkada](https://github.com/younesbelkada) and [DhiaEddineRhaiem](https://github.com/dhiaEddineRhaiem).

View File

@ -46,8 +46,12 @@ The main differences compared to GPT2.
- Merge the key and value caches into one (this changes the format of `layer_past`/`present`; does it risk creating problems?)
- Use the memory layout `(self.num_heads, 3, self.head_dim)` instead of `(3, self.num_heads, self.head_dim)` for the QKV tensor with MHA (prevents an overhead with the merged key and values, but makes the checkpoints incompatible with the original openai-community/gpt2 model).
You can read more about the optimizations in the [original pull request](https://github.com/huggingface/transformers/pull/22575).
> [!NOTE]
> The `head_mask` argument is ignored when using any attention implementation other than "eager". If you have a `head_mask` and want it to have an effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`.
## Combining Starcoder and Flash Attention 2
First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature.

View File

@ -50,7 +50,7 @@ This model was contributed by [patrickvonplaten](https://huggingface.co/patrickv
- Hubert is a speech model that accepts a float array corresponding to the raw waveform of the speech signal.
- Hubert model was fine-tuned using connectionist temporal classification (CTC) so the model output has to be decoded
using [`Wav2Vec2CTCTokenizer`].
- The `head_mask` argument is ignored when using any attention implementation other than "eager". If you have a `head_mask` and want it to have an effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`.
## Using Flash Attention 2

View File

@ -51,6 +51,9 @@ multilingual it expects the sequences in a certain format: A special language id
source and target text. The source text format is `[lang_code] X [eos]`, where `lang_code` is source language
id for source text and target language id for target text, with `X` being the source or target text.
> [!NOTE]
> The `head_mask` argument is ignored when using any attention implementation other than "eager". If you have a `head_mask` and want it to have an effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`.
The [`M2M100Tokenizer`] depends on `sentencepiece` so be sure to install it before running the
examples. To install `sentencepiece` run `pip install sentencepiece`.

View File

@ -35,6 +35,9 @@ You can find all the original mBART checkpoints under the [AI at Meta](https://h
> [!TIP]
> Click on the mBART models in the right sidebar for more examples of applying mBART to different language tasks.
> [!NOTE]
> The `head_mask` argument is ignored when using any attention implementation other than "eager". If you have a `head_mask` and want it to have an effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`.
The example below demonstrates how to translate text with [`Pipeline`] or the [`AutoModel`] class.
<hfoptions id="usage">

View File

@ -62,6 +62,9 @@ python src/transformers/models/musicgen/convert_musicgen_transformers.py \
--checkpoint small --pytorch_dump_folder /output/path --safe_serialization
```
> [!NOTE]
> The `head_mask` argument is ignored when using any attention implementation other than "eager". If you have a `head_mask` and want it to have an effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`.
## Generation
MusicGen is compatible with two generation modes: greedy and sampling. In practice, sampling leads to significantly

View File

@ -44,6 +44,9 @@ There are two key differences with MusicGen:
1. The audio prompt is used here as a conditional signal for the generated audio sample, whereas it's used for audio continuation in [MusicGen](https://huggingface.co/docs/transformers/main/en/model_doc/musicgen).
2. Conditional text and audio signals are concatenated to the decoder's hidden states instead of being used as a cross-attention signal, as in MusicGen.
> [!NOTE]
> The `head_mask` argument is ignored when using any attention implementation other than "eager". If you have a `head_mask` and want it to have an effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`.
## Generation
MusicGen Melody is compatible with two generation modes: greedy and sampling. In practice, sampling leads to significantly better results than greedy, thus we encourage sampling mode to be used where possible. Sampling is enabled by default, and can be explicitly specified by setting `do_sample=True` in the call to [`MusicgenMelodyForConditionalGeneration.generate`], or by overriding the model's generation config (see below).

View File

@ -41,6 +41,9 @@ Tips:
- OPT has the same architecture as [`BartDecoder`].
- Contrary to GPT2, OPT adds the EOS token `</s>` to the beginning of every prompt.
> [!NOTE]
> The `head_mask` argument is ignored when using any attention implementation other than "eager". If you have a `head_mask` and want it to have an effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`.
## Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with OPT. If you're

View File

@ -40,6 +40,9 @@ The abstract from the paper is the following:
`Qwen2-Audio-7B` and `Qwen2-Audio-7B-Instruct` can be found on the [Huggingface Hub](https://huggingface.co/Qwen)
> [!NOTE]
> The `head_mask` argument is ignored when using any attention implementation other than "eager". If you have a `head_mask` and want it to have an effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`.
### Inference
```python

View File

@ -43,8 +43,8 @@ import requests
from transformers import SamHQModel, SamHQProcessor
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device)
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base").to(device)
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
@ -69,8 +69,8 @@ import requests
from transformers import SamHQModel, SamHQProcessor
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamHQModel.from_pretrained("sushmanth/sam_hq_vit_b").to(device)
processor = SamHQProcessor.from_pretrained("sushmanth/sam_hq_vit_b")
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-base").to(device)
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-base")
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

View File

@ -46,6 +46,9 @@ This model was contributed by [anton-l](https://huggingface.co/anton-l).
- SEWForCTC is fine-tuned using connectionist temporal classification (CTC) so the model output has to be decoded using
[`Wav2Vec2CTCTokenizer`].
> [!NOTE]
> The `head_mask` argument is ignored when using any attention implementation other than "eager". If you have a `head_mask` and want it to have an effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`.
## Resources
- [Audio classification task guide](../tasks/audio_classification)

View File

@ -54,6 +54,9 @@ found [here](https://github.com/microsoft/UniSpeech/tree/main/UniSpeech-SAT).
decoded using [`Wav2Vec2CTCTokenizer`].
- UniSpeechSat performs especially well on speaker verification, speaker identification, and speaker diarization tasks.
> [!NOTE]
> The `head_mask` argument is ignored when using any attention implementation other than "eager". If you have a `head_mask` and want it to have an effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`.
## Resources
- [Audio classification task guide](../tasks/audio_classification)

View File

@ -49,6 +49,9 @@ found [here](https://github.com/microsoft/UniSpeech/tree/main/UniSpeech).
- UniSpeech model can be fine-tuned using connectionist temporal classification (CTC) so the model output has to be
decoded using [`Wav2Vec2CTCTokenizer`].
> [!NOTE]
> The `head_mask` argument is ignored when using any attention implementation other than "eager". If you have a `head_mask` and want it to have an effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`.
## Resources
- [Audio classification task guide](../tasks/audio_classification)

View File

@ -50,6 +50,9 @@ Note: Meta (FAIR) released a new version of [Wav2Vec2-BERT 2.0](https://huggingf
- Wav2Vec2 model was trained using connectionist temporal classification (CTC) so the model output has to be decoded
using [`Wav2Vec2CTCTokenizer`].
> [!NOTE]
> The `head_mask` argument is ignored when using any attention implementation other than "eager". If you have a `head_mask` and want it to have an effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`.
## Using Flash Attention 2
Flash Attention 2 is a faster, optimized version of the attention computation used by the model.

View File

@ -32,6 +32,9 @@ rendered properly in your Markdown viewer.
You can find all the original Whisper checkpoints under the [Whisper](https://huggingface.co/collections/openai/whisper-release-6501bba2cf999715fd953013) collection.
> [!NOTE]
> The `head_mask` argument is ignored when using any attention implementation other than "eager". If you have a `head_mask` and want it to have an effect, load the model with `XXXModel.from_pretrained(model_id, attn_implementation="eager")`.
> [!TIP]
> Click on the Whisper models in the right sidebar for more examples of how to apply Whisper to different audio tasks.

View File

@ -54,8 +54,8 @@ For each model type, there is a separate class for each machine learning framewo
from transformers import AutoModelForCausalLM, MistralForCausalLM
# load with AutoClass or model-specific class
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", , torch_dtype="auto", device_map="auto")
model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", , torch_dtype="auto", device_map="auto")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype="auto", device_map="auto")
model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype="auto", device_map="auto")
```
</hfoption>
@ -272,6 +272,7 @@ Explicitly set the [torch_dtype](https://pytorch.org/docs/stable/tensor_attribut
<hfoption id="specific dtype">
```py
import torch
from transformers import AutoModelForCausalLM
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b", torch_dtype=torch.float16)

View File

@ -13,9 +13,15 @@ rendered properly in your Markdown viewer.
-->
# Distributed GPU inference
# Tensor parallelism in transformers
[Tensor parallelism](./perf_train_gpu_many#tensor-parallelism) shards a model onto multiple GPUs and parallelizes computations such as matrix multiplication. It enables fitting larger model sizes into memory and is faster because each GPU can process a tensor slice.
This document assumes that you are already familiar with the basics of tensor parallelism. If you are not, please refer to the [Ultra-Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=tensor_parallelism) section on tensor parallelism.
> [!TIP]
> Tensor parallelism is very communication intensive, therefore it is recommended to use it on a single machine with multiple GPUs, utilizing fast intra-node communication. For multi-node training, methods such as pipeline or data parallelism are more efficient (depending on your use case).
Tensor parallelism requires slight changes to the model parameters, so in Transformers we support some of the popular models out of the box.
> [!TIP]
> Expand the list below to see which models support tensor parallelism. Open a GitHub issue or pull request to add support for a model not currently listed below.
@ -37,9 +43,218 @@ rendered properly in your Markdown viewer.
</details>
Set `tp_plan="auto"` in [`~AutoModel.from_pretrained`] to enable tensor parallelism for inference.
## Using 🤗 transformers
```py
Transformers provides a simple interface for tensor parallelism. We provide multiple classes implementing different partitioning
strategies and a simple entrypoint to parallelize an `nn.Module` instance. You won't have to interact with this interface directly; everything is done for you in the `PreTrainedModel.from_pretrained` method. This section first covers the partitioning strategies
we support, then the user interface you will be interacting with, and finally teaches you how to extend it with your own partitioning
strategies.
### Partitioning strategies
In Transformers, partitioning strategies reside in the `ParallelInterface` class, which works like a mapping from strings to strategy implementations.
```python
class ParallelInterface(MutableMapping):
"""
Dict-like object keeping track of allowed attention functions. You can easily add a new attention function
with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`,
it needs to declare a new instance of this class inside the `modeling_<model>.py`, and declare it on that instance.
"""
_global_mapping = {
"colwise": ColwiseParallel(),
"rowwise": RowwiseParallel(),
"colwise_rep": ColwiseParallel(output_layouts=Replicate()),
"rowwise_rep": RowwiseParallel(input_layouts=Replicate()),
"local_colwise": ColwiseParallel(use_dtensor=False),
"local_rowwise": RowwiseParallel(use_dtensor=False),
"local": IsolatedParallel(),
"gather": GatherParallel(),
"local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
"sequence_parallel": SequenceParallel(),
"replicate": ReplicateParallel(),
}
```
We support the following strategies:
- `ColwiseParallel` - A simple column-wise partitioning that can handle both weights and biases; it does exactly what we've discussed before.
- `RowwiseParallel` - Row-wise partitioning as discussed before; it supports weights and biases, and additionally `nn.Embedding` modules.
- `SequenceParallel` - Sequence parallel implementation to support `LayerNorm` and `Dropout` layers. Also supports Python implementations of `RMSNorm` (see [this](https://github.com/facebookresearch/llama/blob/main/llama/model.py#L34))
- `PackedColwiseParallel` - A variant of column-wise partitioning that works on packed weights (i.e. `up_proj` and `gate_proj` packed together). For more details, see [this comment](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/tensor_parallel.py#L79-#L108)
- `PackedRowwiseParallel` - A variant of row-wise partitioning that works on packed weights; for more details check the comment linked above.
- `GatherParallel` - A very simple class that only gathers the outputs of the module across devices.
- `IsolatedParallel` - A special case where we want to *isolate* the module from the rest of the devices (world). This is used for experts in MoE layers, essentially creating expert parallelism of sorts.
- `ReplicateParallel` - Many `torch.distributed` APIs break if the model is only partially sharded, so this class is used to replicate the module across all devices.
### Sharding a model
We provide two ways to shard a model. The first is to use the `auto` tensor parallelism plan, which automatically shards the model based on a predefined configuration. This requires the model to have a predefined tensor parallel plan in Transformers.
```python
import torch

from transformers import AutoModelForCausalLM
# model_id = "meta-llama/Meta-Llama-3-8B-Instruct" # better for smaller number of GPUs
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" # better to visualize all the possible strategies
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan="auto")
print(model._tp_plan)
```
> [!TIP]
> For a list of models that support tensor parallelism, see the [Supported models](#supported-models) section above.
The second way is to manually specify your own partitioning plan.
```python
import torch

from transformers import AutoModelForCausalLM

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"  # same checkpoint as in the example above
tp_plan = {
"model.layers.*.self_attn.q_proj": "colwise",
"model.layers.*.self_attn.k_proj": "colwise",
"model.layers.*.self_attn.v_proj": "colwise",
"model.layers.*.self_attn.o_proj": "rowwise",
...
}
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan=tp_plan)
print(model._tp_plan)
```
You might have noticed that there are some special cases in the `ParallelInterface` mapping. Let's now talk about them; this will help you understand their purpose and help with extending to other strategies.
### PackedRowwiseParallel
This class is a special case of `RowwiseParallel` used to shard packed weights. Weight packing is a common technique where multiple linear layers are packed into a single, bigger one.
For example, in the `Llama4` model, we pack `up_proj` and `gate_proj` into a single `gate_up_proj` module.
```python
class Llama4TextExperts(nn.Module):
...
self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))
```
Then in forward, we can use batch matrix multiplication to compute the output of the `gate_up_proj` module.
```python
def forward(self, hidden_states):
...
gate_up = torch.bmm(hidden_states, self.gate_up_proj) # Compute the output of the gate_up_proj module
gate, up = gate_up.chunk(2, dim=-1) # Split the output into gate and up
```
In this case, we need to use the `PackedRowwiseParallel` strategy to shard the `gate_up_proj` module, as using a simple `RowwiseParallel` would shard the layers incorrectly.
> [!TIP]
> If this is a bit difficult to wrap your head around, check out [this comment](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/tensor_parallel.py#L79-#L108) for an amazing visual representation of why `Packed*` needs to be used.
### `local*` strategies
You might have noticed that there are `local*` strategies, which use the same layers as the `*` strategies but don't use `DTensor` at all.
This is because `DTensor` is not supported for some operations, such as `torch.chunk`. Therefore, we sometimes need the `local*` strategies, which use vanilla `torch.Tensor` and handle some of the distributed logic manually.
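For example, if one of your modules hits an unsupported `DTensor` operation, you could swap its entries in the plan for the `local*` variants. A small sketch (the module patterns below are illustrative):
```python
# Illustrative plan fragment: same partitioning as before, but without DTensor.
tp_plan = {
    "model.layers.*.self_attn.q_proj": "local_colwise",
    "model.layers.*.self_attn.o_proj": "local_rowwise",
}
```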
<!---
Readd this when I get the exact error message
> [!TIP]
> If you are using a custom partitioning strategy, and it's not working with `... is not supported` error, try using the `local*` strategies to see if they work better.
-->
> [!WARNING]
> Manually specifying your own partitioning plan requires a good understanding of the model architecture and how the partitioning strategies interact. If you are not sure about this, the resulting model can be very slow, or even fail or be incorrect. Again, refer to the [Ultra-Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=tensor_parallelism), which can teach you everything required.
### Extending the interface with your own partitioning strategies
This is a very advanced topic, which requires a good understanding of distributed collectives and the model architecture.
Your custom partitioning strategy should inherit from `TensorParallelLayer` defined in [integrations/tensor_parallel.py](https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/tensor_parallel.py) and implement: `partition_tensor`, `_prepare_input_fn` and `_prepare_output_fn`. Then it should be registered in the `ParallelInterface` mapping, so our dispatching logic can find it when specified in the `tp_plan`.
Let's go through this workflow step by step, on an already existing example: `ColwiseParallel`.
1. Inherit from `TensorParallelLayer` and implement the initialization
```python
class ColwiseParallel(TensorParallelLayer):
def __init__(
self,
*,
input_layouts: Optional[Placement] = None, # The input layout coming from the previous layer
output_layouts: Optional[Placement] = None, # The output layout we want to achieve
use_local_output: bool = True, # Whether to use local output or not
use_dtensor=True, # Whether to use DTensor or not
):
self.input_layouts = (input_layouts or Replicate(),) # The input sharding coming from the previous layer
self.output_layouts = (output_layouts or Shard(-1),) # Desired output sharding
self.desired_input_layouts = (Replicate(),) # Desired input sharding, inputs should be replicated across GPUs
self.use_local_output = use_local_output
self.use_dtensor = use_dtensor
```
In the `__init__` method, we define these attributes: `input_layouts` and `output_layouts` describe how the input and output tensors should be placed on the devices, and `desired_input_layouts` specifies how the input *should* be placed on the devices.
2a. Implement `partition_tensor` method
```python
def partition_tensor(
self,
param, # Full tensor of the parameter
empty_param, # Empty tensor of the parameter, will be filled with the partitioned tensor
param_type, # Type of the parameter, `bias` or `weight`
param_casting_dtype, # The type to cast the parameter to
to_contiguous, # Whether to convert the tensor to a contiguous memory layout
rank, # The rank of the current device
device_mesh, # The device mesh
) -> nn.Parameter: # Return the partitioned parameter
...
```
This method is used to partition the tensor and fill `empty_param` with the partitioned tensor.
We provide some utility functions to help you with this, such as `get_tensor_shard`, which gets you the correct shard of the original parameter for this rank, or `get_packed_weights`, which helps with packed weights.
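As a rough illustration of the idea (this is a sketch, not the actual helper or its signature), a column-wise partition boils down to keeping this rank's slice of the weight:
```python
import torch
import torch.nn as nn


def colwise_partition_sketch(param: torch.Tensor, rank: int, world_size: int) -> nn.Parameter:
    # Keep only this rank's slice along the output dimension (dim 0 for a Linear weight).
    # The real `partition_tensor` additionally handles dtype casting, contiguity,
    # packed weights, and DTensor placement via the utilities mentioned above.
    shard = torch.tensor_split(param, world_size, dim=0)[rank]
    return nn.Parameter(shard)
```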
2b. Implement `_prepare_input_fn` and `_prepare_output_fn` methods
These methods are used as [`pre-forward`](https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.module.register_module_forward_pre_hook.html) and [`forward`](https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.module.register_module_forward_hook.html) hooks respectively. Their purpose is to re-distribute the inputs and outputs to the desired layout, passed in the `__init__` method.
```python
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
...
# Do some custom logic, cast to DTensor etc.
...
return inputs.redistribute(placements=desired_input_layouts, device_mesh=device_mesh)
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
...
# Do some custom logic, cast to DTensor etc.
...
return outputs.redistribute(placements=output_layouts, device_mesh=device_mesh)
```
3. Register the strategy
Congratulations! You've implemented your own partitioning strategy. Now, to use it with your own `tp_plan`, you need to register it in the `ParallelInterface` mapping.
```python
from transformers.integrations.tensor_parallel import ParallelInterface
ParallelInterface.register_strategy("colwise_custom", ColwiseParallel)
```
You can now use it in your `tp_plan` as follows:
```python
tp_plan = {
"model.layers.*.self_attn.q_proj": "colwise_custom",
...
}
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan=tp_plan)
```
## Full example
Let's go through a full example of inference with tensor parallelism.
```python
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
@ -66,17 +281,49 @@ Launch the inference script above on [torchrun](https://pytorch.org/docs/stable/
torchrun --nproc-per-node 4 demo.py
```
For CPU, bind each rank to a different socket. For example, if you are using an Intel 4th Gen Xeon:
```bash
export OMP_NUM_THREADS=56
numactl -C 0-55 -m 0 torchrun --nnodes=2 --node_rank=0 --master_addr="127.0.0.1" --master_port=29500 --nproc-per-node 1 demo.py & numactl -C 56-111 -m 1 torchrun --nnodes=2 --node_rank=1 --master_addr="127.0.0.1" --master_port=29500 --nproc-per-node 1 demo.py & wait
```
The CPU benchmark data will be released soon.
You can benefit from considerable speed ups for inference, especially for inputs with large batch size or long sequences.
For a single forward pass on [Llama](./model_doc/llama) with a sequence length of 512 and various batch sizes, you can expect the following speed ups.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/Meta-Llama-3-8B-Instruct%2C%20seqlen%20%3D%20512%2C%20python%2C%20w_%20compile.png">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/Meta-Llama-3-8B-Instruct%2C%20seqlen%20%3D%20512%2C%20python%2C%20w_%20compile.png">
</div>
## Tensor parallelism in-depth
Our implementation of tensor parallelism is framework-agnostic in design, but the specific implementations we've developed rely on the torch.distributed package. We heavily utilize abstractions such as `DeviceMesh` or `DTensor` to provide a simple and extensible interface to the user.
### DeviceMesh
Imagine `DeviceMesh` as a multi-dimensional grid of devices that communicate together. Different parallelization strategies require different types of communication patterns, so we can create a `DeviceMesh` with multiple submeshes:
```python
from torch.distributed.device_mesh import init_device_mesh
# Create a 1D mesh of 4 GPUs
device_mesh = init_device_mesh("cuda", (4,), mesh_dim_names=["tp"])
```
Then, most of the parallelization strategies defined in `torch.distributed` can be applied to the mesh itself or to one of its submeshes, with the communication patterns handled automatically.
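For instance, a 2D mesh combining data parallelism and tensor parallelism could be created as follows (the dimension names and sizes are illustrative; run this under `torchrun` with 4 processes):
```python
from torch.distributed.device_mesh import init_device_mesh

# 2-way data parallelism x 2-way tensor parallelism on 4 GPUs
world_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
dp_mesh = world_mesh["dp"]  # submesh used for data-parallel communication
tp_mesh = world_mesh["tp"]  # submesh used for tensor-parallel communication
```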
### DTensor
An abbreviation for Distributed Tensor, `DTensor` is a tensor subclass that handles the distributed logic on top of the usual tensor operations. In the case of tensor parallelism, most of the model weights are stored as `DTensor`s (with some exceptions, more on that later).
The most important part of `DTensor` to understand is the `placements` attribute, which tells PyTorch how the tensor is placed on the devices of the `DeviceMesh`.
It can have the following values:
- `Shard(dimension)` - Annotates that this `DTensor` is sharded across a given dimension, over the `DeviceMesh` it was constructed under. For example, if we would like to shard weights for column-wise partitioning, we would do:
```python
weight = ...
weight = DTensor.from_local(weight, device_mesh["tp"], placements=[Shard(0)]) # Shard across the 1st (column-wise) dimension
bias = ...
bias = DTensor.from_local(bias, device_mesh["tp"], placements=[Shard(-1)]) # Shard across the ONLY dimension
```
To give another example, for row-wise partitioning, we would do:
```python
weight = ...
weight = DTensor.from_local(weight, device_mesh["tp"], placements=[Shard(1)]) # Shard across the 2nd (row-wise) dimension
bias = ...
bias = DTensor.from_local(bias, device_mesh["tp"], placements=[Replicate()]) # Replicate bias across all GPUs
```
- `Replicate()` - Annotates that this `DTensor` is replicated across the `DeviceMesh`. Very straightforward; it only creates a full copy of the tensor on each device.
- `Partial()` - This placement is mostly of no interest to us; it's used to annotate that this tensor is pending a reduction operation.

View File

@ -106,6 +106,8 @@ dataset[0]["text"]
Remember to resample the audio to match the pretrained model's required sampling rate.
```py
from datasets import Audio
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
```

View File

@ -372,14 +372,14 @@ accelerate launch \
### torch.compile
[torch.compile](./perf_torch_compile) can significantly speed up training and reduce computational overhead. Configure your torch.compile settings in [`TrainingArguments`]. Set `torch.compile` to `True`, and select a backend and compile mode.
[torch.compile](./perf_torch_compile) can significantly speed up training and reduce computational overhead. Configure your torch.compile settings in [`TrainingArguments`]. Set `torch_compile` to `True`, and select a backend and compile mode.
```py
from transformers import TrainingArguments
training_args = TrainingArguments(
torch.compile=True,
torch.compile_backend="inductor",
torch_compile=True,
torch_compile_backend="inductor",
torch_compile_mode="default",
...,
)

View File

@ -157,5 +157,8 @@
title: 通用工具
- local: internal/time_series_utils
title: 时序数据工具
- sections:
- local: model_doc/bert
title: BERT
title: 内部辅助工具
title: 应用程序接口 (API)
title: 应用程序接口 (API)

View File

@ -0,0 +1,258 @@
<!--Copyright 2020 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="TensorFlow" src="https://img.shields.io/badge/TensorFlow-FF6F00?style=flat&logo=tensorflow&logoColor=white">
<img alt="Flax" src="https://img.shields.io/badge/Flax-29a79b.svg?style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAC0AAAAtCAMAAAANxBKoAAAC7lBMVEUAAADg5vYHPVgAoJH+/v76+v39/f9JbLP///9+AIgAnY3///+mcqzt8fXy9fgkXa3Ax9709fr+///9/f8qXq49qp5AaLGMwrv8/P0eW60VWawxYq8yqJzG2dytt9Wyu9elzci519Lf3O3S2efY3OrY0+Xp7PT///////+dqNCexMc6Z7AGpJeGvbenstPZ5ejQ1OfJzOLa7ejh4+/r8fT29vpccbklWK8PVa0AS6ghW63O498vYa+lsdKz1NDRt9Kw1c672tbD3tnAxt7R6OHp5vDe7OrDyuDn6vLl6/EAQKak0MgATakkppo3ZK/Bz9y8w9yzu9jey97axdvHzeG21NHH4trTwthKZrVGZLSUSpuPQJiGAI+GAI8SWKydycLL4d7f2OTi1+S9xNzL0ePT6OLGzeEAo5U0qJw/aLEAo5JFa7JBabEAp5Y4qZ2QxLyKmsm3kL2xoMOehrRNb7RIbbOZgrGre68AUqwAqZqNN5aKJ5N/lMq+qsd8kMa4pcWzh7muhLMEV69juq2kbKqgUaOTR5uMMZWLLZSGAI5VAIdEAH+ovNDHuNCnxcy3qcaYx8K8msGplrx+wLahjbYdXrV6vbMvYK9DrZ8QrZ8tqJuFms+Sos6sw8ecy8RffsNVeMCvmb43aLltv7Q4Y7EZWK4QWa1gt6meZKUdr6GOAZVeA4xPAISyveLUwtivxtKTpNJ2jcqfvcltiMiwwcfAoMVxhL+Kx7xjdrqTe60tsaNQs6KaRKACrJ6UTZwkqpqTL5pkHY4AloSgsd2ptNXPvNOOncuxxsqFl8lmg8apt8FJcr9EbryGxLqlkrkrY7dRa7ZGZLQ5t6iXUZ6PPpgVpZeJCJFKAIGareTa0+KJod3H0deY2M+esM25usmYu8d2zsJOdcBVvrCLbqcAOaaHaKQAMaScWqKBXqCXMJ2RHpiLF5NmJZAdAHN2kta11dKu1M+DkcZLdb+Mcql3TppyRJdzQ5ZtNZNlIY+DF4+voCOQAAAAZ3RSTlMABAT+MEEJ/RH+/TP+Zlv+pUo6Ifz8+fco/fz6+evr39S9nJmOilQaF/7+/f38+smmoYp6b1T+/v7++vj189zU0tDJxsGzsrKSfv34+Pf27dDOysG9t6+n/vv6+vr59uzr1tG+tZ6Qg9Ym3QAABR5JREFUSMeNlVVUG1EQhpcuxEspXqS0SKEtxQp1d3d332STTRpIQhIISQgJhODu7lAoDoUCpe7u7u7+1puGpqnCPOyZvffbOXPm/PsP9JfQgyCC+tmTABTOcbxDz/heENS7/1F+9nhvkHePG0wNDLbGWwdXL+rbLWvpmZHXD8+gMfBjTh+aSe6Gnn7lwQIOTR0c8wfX3PWgv7avbdKwf/ZoBp1Gp/PvuvXW3vw5ib7emnTW4OR+3D4jB9vjNJ/7gNvfWWeH/TO/JyYrsiKCRjVEZA3UB+96kON+DxOQ/NLE8PE5iUYgIXjFnCOlxEQMaSGVxjg4gxOnEycGz8bptuNjVx08LscIgrzH3umcn+KKtiBIyvzOO2O99aAdR8cF19oZalnCtvREUw79tCd5sow1g1UKM6kXqUx4T8wsi3sTjJ3yzDmmhenLXLpo8u45eG5y4Vvbk6kkC4LLtJMowkSQxmk4ggVJEG+7c6QpHT8vvW9X7/o7+3ELmiJi2mEzZJiz8cT6TBlanBk70cB5GGIGC1gRDdZ00yADLW1FL6gqhtvNXNG5S9gdSrk4M1qu7JAsmYshzDS4peoMrU/gT7qQdqYGZaYhxZmVbGJAm/CS/HloWyhRUlknQ9KYcExTwS80d3VNOxUZJpITYyspl0LbhArhpZCD9cRWEQuhYkNGMHToQ/2Cs6swJlb39CsllxdXX6IUKh/H5jbnSsPKjgmoaFQ1f8wRLR0UnGE/RcDEjj2jXG1WVTwUs8+zxfcrVO+vSsuOpVKxCfYZiQ0/aPKuxQbQ8lIz+DClxC8u+snlcJ7Yr1z1JPqUH0V+GDXbOwAib931Y4Imaq0NTIXPXY+N5L18GJ37SVWu+hwXff8l72Ds9XuwYIBaXPq6Shm4l+Vl/5QiOlV+uTk6YR9PxKsI9xNJny31ygK1e+nIRC1N97EGkFPI+jCpiHe5PCEy7oWqWSwRrpOvhFzcbTWMbm3ZJAOn1rUKpYIt/lDhW/5RHHteeWFN60qo98YJuoq1nK3uW5AabyspC1BcIEpOhft+SZAShYoLSvnmSfnYADUERP5jJn2h5XtsgCRuhYQqAvwTwn33+YWEKUI72HX5AtfSAZDe8F2DtPPm77afhl0EkthzuCQU0BWApgQIH9+KB0JhopMM7bJrdTRoleM2JAVNMyPF+wdoaz+XJpGoVAQ7WXUkcV7gT3oUZyi/ISIJAVKhgNp+4b4veCFhYVJw4locdSjZCp9cPUhLF9EZ3KKzURepMEtCDPP3VcWFx4UIiZIklIpFNfHpdEafIF2aRmOcrUmjohbT2WUllbmRvgfbythbQO3222fpDJoufaQPncYYuqoGtUEsCJZL6/3PR5b4syeSjZMQG/T2maGANlXT2v8S4AULWaUkCxfLyW8iW4kdka+nEMjxpL2NCwsYNBp+Q61PF43zyDg9Bm9+3NNySn78jMZUUkumqE4Gp7JmFOdP1vc8PpRrzj9+wPinCy8K1PiJ4aYbnTYpCCbDkBSbzhu2QJ1Gd82t8jI8TH51+OzvXoWbnXUOBkNW+0mWFwGcGOUVpU81/n3TOHb5oMt2FgYGjzau0Nif0Ss7Q3XB33hjjQHjHA5E5aOyIQc8CBrLdQSs3j92VG+3nNEjbkbdbBr9zm04ruvw37vh0QKOdeGIkckc80fX3KH/h7PT4BOjgCty8VZ5ux1MoO5Cf5naca2LAsEgehI+drX8o/0Nu+W0m6K/I9gGPd/dfx/EN/wN62AhsBWuAAAAAElFTkSuQmCC
">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
</div>
# BERT
[BERT](https://huggingface.co/papers/1810.04805) is a bidirectional transformer pretrained on unlabeled text to predict masked tokens in a sentence and to predict whether one sentence follows another. The core idea is that by randomly masking some tokens during pretraining, the model learns to use the left and right context to predict them, which yields a more comprehensive and deeper understanding of language. BERT is also very versatile: the language representations it learns can be fine-tuned with additional layers or heads to adapt to other downstream NLP tasks.
You can find all the original BERT checkpoints under the [BERT](https://huggingface.co/collections/google/bert-release-64ff5e7a4be99045d1896dbc) collection.
> [!TIP]
> Click on the BERT models in the right sidebar for more examples of applying BERT to different language tasks.
The examples below demonstrate how to predict the `[MASK]` token with [`Pipeline`], [`AutoModel`], and from the command line.
<hfoptions id="usage">
<hfoption id="Pipeline">
```py
import torch
from transformers import pipeline
pipeline = pipeline(
task="fill-mask",
model="google-bert/bert-base-uncased",
torch_dtype=torch.float16,
device=0
)
pipeline("Plants create [MASK] through a process known as photosynthesis.")
```
</hfoption>
<hfoption id="AutoModel">
```py
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
"google-bert/bert-base-uncased",
)
model = AutoModelForMaskedLM.from_pretrained(
"google-bert/bert-base-uncased",
torch_dtype=torch.float16,
device_map="auto",
attn_implementation="sdpa"
)
inputs = tokenizer("Plants create [MASK] through a process known as photosynthesis.", return_tensors="pt").to("cuda")
with torch.no_grad():
outputs = model(**inputs)
predictions = outputs.logits
masked_index = torch.where(inputs['input_ids'] == tokenizer.mask_token_id)[1]
predicted_token_id = predictions[0, masked_index].argmax(dim=-1)
predicted_token = tokenizer.decode(predicted_token_id)
print(f"The predicted token is: {predicted_token}")
```
</hfoption>
<hfoption id="transformers-cli">
```bash
echo -e "Plants create [MASK] through a process known as photosynthesis." | transformers-cli run --task fill-mask --model google-bert/bert-base-uncased --device 0
```
</hfoption>
</hfoptions>
## Notes
- Inputs should be padded on the right because BERT uses absolute position embeddings.
## BertConfig
[[autodoc]] BertConfig
- all
## BertTokenizer
[[autodoc]] BertTokenizer
- build_inputs_with_special_tokens
- get_special_tokens_mask
- create_token_type_ids_from_sequences
- save_vocabulary
## BertTokenizerFast
[[autodoc]] BertTokenizerFast
## BertModel
[[autodoc]] BertModel
- forward
## BertForPreTraining
[[autodoc]] BertForPreTraining
- forward
## BertLMHeadModel
[[autodoc]] BertLMHeadModel
- forward
## BertForMaskedLM
[[autodoc]] BertForMaskedLM
- forward
## BertForNextSentencePrediction
[[autodoc]] BertForNextSentencePrediction
- forward
## BertForSequenceClassification
[[autodoc]] BertForSequenceClassification
- forward
## BertForMultipleChoice
[[autodoc]] BertForMultipleChoice
- forward
## BertForTokenClassification
[[autodoc]] BertForTokenClassification
- forward
## BertForQuestionAnswering
[[autodoc]] BertForQuestionAnswering
- forward
## TFBertTokenizer
[[autodoc]] TFBertTokenizer
## TFBertModel
[[autodoc]] TFBertModel
- call
## TFBertForPreTraining
[[autodoc]] TFBertForPreTraining
- call
## TFBertLMHeadModel
[[autodoc]] TFBertLMHeadModel
- call
## TFBertForMaskedLM
[[autodoc]] TFBertForMaskedLM
- call
## TFBertForNextSentencePrediction
[[autodoc]] TFBertForNextSentencePrediction
- call
## TFBertForSequenceClassification
[[autodoc]] TFBertForSequenceClassification
- call
## TFBertForMultipleChoice
[[autodoc]] TFBertForMultipleChoice
- call
## TFBertForTokenClassification
[[autodoc]] TFBertForTokenClassification
- call
## TFBertForQuestionAnswering
[[autodoc]] TFBertForQuestionAnswering
- call
## FlaxBertModel
[[autodoc]] FlaxBertModel
- __call__
## FlaxBertForPreTraining
[[autodoc]] FlaxBertForPreTraining
- __call__
## FlaxBertForCausalLM
[[autodoc]] FlaxBertForCausalLM
- __call__
## FlaxBertForMaskedLM
[[autodoc]] FlaxBertForMaskedLM
- __call__
## FlaxBertForNextSentencePrediction
[[autodoc]] FlaxBertForNextSentencePrediction
- __call__
## FlaxBertForSequenceClassification
[[autodoc]] FlaxBertForSequenceClassification
- __call__
## FlaxBertForMultipleChoice
[[autodoc]] FlaxBertForMultipleChoice
- __call__
## FlaxBertForTokenClassification
[[autodoc]] FlaxBertForTokenClassification
- __call__
## FlaxBertForQuestionAnswering
[[autodoc]] FlaxBertForQuestionAnswering
- __call__
## Bert specific outputs
[[autodoc]] models.bert.modeling_bert.BertForPreTrainingOutput
[[autodoc]] models.bert.modeling_tf_bert.TFBertForPreTrainingOutput
[[autodoc]] models.bert.modeling_flax_bert.FlaxBertForPreTrainingOutput

examples/3D_parallel.py
View File

@ -0,0 +1,435 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""":
This script is used to test training a model using Tensor Parallelism and Data Parallelism.
Usage:
export CUDA_VISIBLE_DEVICES=0,1,2,3
export CUDA_VISIBLE_DEVICES=4,5,6,7
export CUDA_VISIBLE_DEVICES=5,6,7
TP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 examples/3D_parallel.py
CP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 examples/3D_parallel.py
CP_SIZE=2 TP_SIZE=2 torchrun --nproc_per_node=4 examples/3D_parallel.py
DP_SIZE=2 CP_SIZE=2 TP_SIZE=2 torchrun --nproc_per_node=8 examples/3D_parallel.py
TP_SIZE=1 CP_SIZE=4 torchrun --nproc_per_node=4 examples/3D_parallel.py
TP_SIZE=1 DP_SIZE=4 torchrun --nproc_per_node=4 examples/3D_parallel.py
TP_SIZE=4 DP_SIZE=1 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 examples/3D_parallel.py
IGNORE_SANITY=1 CP_SIZE=1 TP_SIZE=1 DP_SIZE=1 torchrun --nproc_per_node=1 --rdzv_endpoint=localhost:29504 examples/3D_parallel.py
"""
import logging
import os
from contextlib import nullcontext
from typing import Iterable
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.optim as optim
import wandb
from datasets import load_dataset
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.experimental import context_parallel
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer
# torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
# Set up logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)
# from torch.distributed.tensor.experimental._attention import set_rotate_method
# set_rotate_method("alltoall") # CP rotate shards using all-to-all
def main():
tp_size = int(os.environ.get("TP_SIZE", 1))
dp_size = int(os.environ.get("DP_SIZE", 1))
cp_size = int(os.environ.get("CP_SIZE", 1)) # Add CP size configuration
sdpa_backend = SDPBackend.FLASH_ATTENTION # For CP
# sdpa_backend = SDPBackend.MATH # For CP
global_batch_size = 8 # Desired global batch size
seq_len = 1024 # Sequence length
num_train_steps = 10000 # Number of training steps
LR = 1e-5
model_name = "HuggingFaceTB/SmolLM2-1.7B"
# model_name = "unsloth/Llama-3.2-1B"
CHECKPOINT_DIR = f"checkpoint_tp{tp_size}_dp{dp_size}_cp{cp_size}"
# Initialize distributed environment
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
dist.init_process_group("nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
assert world_size == tp_size * dp_size * cp_size, (
f"World size ({world_size}) must equal TP size ({tp_size}) * DP size ({dp_size}) * CP size ({cp_size})"
)
mesh = torch.arange(world_size).reshape(dp_size, tp_size, cp_size)
world_mesh = DeviceMesh(device_type="cuda", mesh=mesh, mesh_dim_names=("dp", "tp", "cp"))
tp_mesh = world_mesh["tp"]
dp_mesh = world_mesh["dp"]
cp_mesh = world_mesh["cp"]
world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
logger.info(f"Created DeviceMesh: {world_mesh}")
logger.info(
f"Distributed setup - Rank: {rank}, World size: {world_size}, Local rank: {local_rank}, DP: {dp_mesh.get_local_rank()}, TP: {tp_mesh.get_local_rank()}, CP: {cp_mesh.get_local_rank()}"
)
if dist.get_rank() == 0:
wandb.init(
project="tp_dp_test",
config={
"tp_size": tp_size,
"dp_size": dp_size,
"cp_size": cp_size,
"global_batch_size": global_batch_size,
"model_name": model_name,
"dataset": "roneneldan/TinyStories-1M",
"seq_len": seq_len,
"lr": LR,
"weight_decay": 0.1,
},
name=f"llama_tp{tp_size}_dp{dp_size}_cp{cp_size}"
if model_name == "unsloth/Llama-3.2-1B"
else f"tp{tp_size}_dp{dp_size}_cp{cp_size}",
)
logger.info("Wandb initialized.")
# Log the current file to wandb
wandb.save("test_train.py")
# Load model and tokenizer
logger.info(f"Loading model and tokenizer from {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info(f"Set pad_token to eos_token: {tokenizer.pad_token}")
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_mesh=tp_mesh if dist.is_initialized() else None,
tp_plan="auto",
torch_dtype=torch.bfloat16,
)
logger.info(f"Model loaded onto device mesh: {tp_mesh}")
device = torch.device(f"cuda:{local_rank}")
logger.info(f"Using device: {device} for non-model tensors")
use_ddp = False
if dist.is_initialized() and dp_mesh.size() > 1:
model = FSDP(model, device_mesh=dp_mesh, sharding_strategy=ShardingStrategy.NO_SHARD)
use_ddp = True
model.train()
logger.info("Loading TinyStories dataset...")
raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]") # Use 1% for faster testing
def tokenize_function(examples):
# Tokenize the text without padding
tokenized_batch = tokenizer(
examples["text"], padding=False, truncation=True, max_length=seq_len, return_tensors=None
)
# Set labels to be the same as input_ids for Causal LM
tokenized_batch["labels"] = tokenized_batch["input_ids"].copy()
return tokenized_batch
tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
logger.info(f"Dataset loaded and tokenized. Size: {len(tokenized_dataset)}")
# Create packed sequences
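# All tokenized documents are concatenated and cut into chunks of seq_len + 1 tokens; the extra
# token lets input_ids (drop the last token) and labels (drop the first token) come from the same
# chunk, so the labels are already shifted by one position.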
def create_packed_sequences(examples):
# Flatten all sequences
all_tokens = []
for input_ids in examples["input_ids"]:
all_tokens.extend(input_ids)
# Split into sequences of seq_len + 1 (for input + label)
num_sequences = len(all_tokens) // (seq_len + 1)
packed_input_ids = []
packed_labels = []
for i in range(num_sequences):
start_idx = i * (seq_len + 1)
end_idx = start_idx + (seq_len + 1)
# Get the full sequence
full_sequence = all_tokens[start_idx:end_idx]
# For input_ids, remove the last token
packed_input_ids.append(full_sequence[:-1])
# For labels, remove the first token
packed_labels.append(full_sequence[1:])
return {"input_ids": packed_input_ids, "labels": packed_labels}
# Apply packing to the dataset
packed_dataset = tokenized_dataset.map(
create_packed_sequences,
batched=True,
remove_columns=tokenized_dataset.column_names,
batch_size=1000, # Process in batches for efficiency
num_proc=60,
)
logger.info(f"Dataset packed. New size: {len(packed_dataset)}")
# Shuffle the packed dataset
packed_dataset = packed_dataset.shuffle(seed=42)
logger.info("Packed dataset shuffled")
# Calculate local batch size
if dist.is_initialized():
assert global_batch_size % dp_mesh.size() == 0, (
f"Global batch size ({global_batch_size}) must be divisible by DP size ({dp_mesh.size()})"
)
local_batch_size = global_batch_size // dp_mesh.size()
else:
local_batch_size = global_batch_size
logger.info(
f"Global batch size: {global_batch_size}, DP size: {dp_size if dist.is_initialized() else 1}, Local batch size: {local_batch_size}"
)
# Simple collate function since sequences are already packed
def collate_fn(batch):
input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
return {"input_ids": input_ids, "labels": labels}
if dist.is_initialized():
sampler = DistributedSampler(
packed_dataset, num_replicas=dp_mesh.size(), rank=dp_mesh.get_local_rank(), shuffle=False
)
else:
sampler = None
dataloader = DataLoader(
packed_dataset,
batch_size=local_batch_size,
sampler=sampler,
shuffle=False,
collate_fn=collate_fn,
pin_memory=True,
)
logger.info(f"DataLoader created. Distributed: {dist.is_initialized()}")
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.1)
# Training loop
logger.info(f"Starting training for {num_train_steps} steps...")
model.train()
step = 0
while step < num_train_steps:
for batch in dataloader:
if step >= num_train_steps:
break # Exit loop if max steps reached
# Move batch to appropriate device
batch = {k: v.to(device) for k, v in batch.items()}
optimizer.zero_grad()
# Add position_ids to batch before CP sharding
batch_size = batch["input_ids"].shape[0]
position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
batch["position_ids"] = position_ids
from torch.distributed.tensor.experimental._attention import _cp_options
_cp_options.enable_load_balance = False
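# Load balancing would reorder sequence chunks across CP ranks; it is disabled here, presumably so
# each rank keeps one contiguous slice that stays aligned with its position_ids and labels.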
with sdpa_kernel(sdpa_backend): # TODO: ideally move this to attention implementation
cp_context = (
nullcontext()
if cp_mesh.size() == 1
else context_parallel(
cp_mesh,
buffers=[
batch["input_ids"],
batch["labels"],
batch["position_ids"],
],
buffer_seq_dims=[1, 1, 1],
)
)
with cp_context:
# Pop labels from batch before model forward pass
labels = batch.pop("labels")
outputs = model(**batch) # [mbs, seq_len/cp]
loss = outputs.loss
logits = outputs.logits
# Compute loss with shifted labels
loss = model.loss_function(
logits=logits, labels=None, shift_labels=labels, vocab_size=model.config.vocab_size
)
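# The labels were already shifted during packing, so they are passed as `shift_labels`
# and `labels=None` skips the model's internal shifting.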
loss.backward()
# all reduce grads across dp_cp if applicable
all_reduce_grads(model, world_mesh, use_ddp=use_ddp)
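# When FSDP (NO_SHARD) already syncs gradients across DP, all_reduce_grads() only averages
# over the CP ranks; otherwise it averages over the flattened dp_cp mesh.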
if hasattr(model, "clip_grad_norm_"):
gradnorm = model.clip_grad_norm_(max_norm=1.0, norm_type=2.0) # TODO: fix reported gradnorm
else:
# only works with FSDP's NO_SHARD; otherwise FSDP's own clip_grad_norm_ should be used
assert len(list(model.parameters())) > 5, "No parameters found in model. Probably DDP bug.."
gradnorm = clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2.0, foreach=True)
optimizer.step()
# allreduce loss across cp_dp before logging
if dist.is_initialized() and (cp_mesh.size() > 1 or dp_mesh.size() > 1):
dist.all_reduce(loss, group=world_mesh["dp_cp"].get_group(), op=dist.ReduceOp.AVG)
current_loss = loss.item()
# Log loss and gradnorm to wandb (only on rank 0 of dp group)
if not dist.is_initialized() or dist.get_rank() == 0:
logger.info(
f"Step: {step} | GBS: {global_batch_size} | DP: {dp_mesh.size()} | TP: {tp_mesh.size()} | CP: {cp_mesh.size()} | Loss: {current_loss} | Gradnorm: {gradnorm} | lr: {LR}"
)
wandb.log(
{
"train/loss": current_loss,
"train/gradnorm": gradnorm,
"step": step,
"lr": LR,
"GBS": global_batch_size,
}
)
step += 1 # Increment step count
logger.info("Training loop finished.")
# Save model using DCP (only if distributed)
if dist.is_initialized():
state_dict = {"app": AppState(model, optimizer)}
dcp.save(
state_dict=state_dict,
checkpoint_id=CHECKPOINT_DIR,
)
logger.info(f"Saved checkpoint to {CHECKPOINT_DIR}")
else:
# Fallback to regular save for non-distributed case
save_dir = "test_model_nondist"
model.save_pretrained(save_dir, safe_serialization=False)
tokenizer.save_pretrained(save_dir) # Save tokenizer too
logger.info(f"Saved model to {save_dir}")
dist.destroy_process_group()
logger.info("Cleaned up distributed process group")
# Finish wandb run on rank 0
if dist.get_rank() == 0:
wandb.finish()
logger.info("Wandb run finished.")
def all_reduce_grads(model, world_mesh, use_ddp):
"""All reduce gradients across dp_cp if applicable."""
cp_mesh = world_mesh["cp"]
if use_ddp:
# DDP/FSDP takes care of syncing grads
mesh = cp_mesh
else:
mesh = world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
if dist.is_initialized() and mesh.size() > 1:
for name, param in model.named_parameters():
if param.grad is not None:
# Workaround for cross-mesh communication limitation with DTensor gradients
if isinstance(param.grad, DTensor):
local_grad = param.grad.to_local()
# Ensure grad requires grad for inplace modification checks (might not be needed)
# local_grad = local_grad.detach().requires_grad_(True)
torch.distributed.all_reduce(local_grad, op=torch.distributed.ReduceOp.SUM, group=mesh.get_group())
local_grad = local_grad / mesh.size()
# Assign averaged grad back - need careful handling if DTensor structure is complex
# This simple assignment might work if the grad structure matches param structure
param.grad = DTensor.from_local(
local_grad, device_mesh=param.grad.device_mesh, placements=param.grad.placements
)
else:
# Handle regular tensors if any exist (e.g. buffers not converted to DTensor)
torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG, group=mesh.get_group())
class AppState(Stateful):
"""Wrapper for checkpointing the Application State including model and optimizer."""
def __init__(self, model, optimizer=None):
self.model = model
self.optimizer = optimizer
def state_dict(self):
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {"model": model_state_dict, "optim": optimizer_state_dict}
def load_state_dict(self, state_dict):
set_state_dict(
self.model, self.optimizer, model_state_dict=state_dict["model"], optim_state_dict=state_dict["optim"]
)
def clip_grad_norm_(
parameters: Iterable[torch.Tensor],
max_norm: float,
norm_type: float = 2.0,
error_if_nonfinite: bool = False,
foreach: bool | None = None,
) -> torch.Tensor:
"""
Clip the gradient norm of an iterable of parameters.
"""
# Filter out parameters with no gradients
parameters = [p for p in parameters if p.grad is not None]
assert len(parameters) > 0, "No parameters with gradients found"
# Calculate total norm
if norm_type == float("inf"):
total_norm = max(p.grad.detach().abs().max() for p in parameters)
else:
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]), norm_type)
# Convert DTensor to local tensor if needed
if isinstance(total_norm, DTensor):
total_norm = total_norm.full_tensor()
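# full_tensor() gathers the sharded DTensor norm so every rank clips with the same global value.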
# Clip gradients
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for p in parameters:
p.grad.detach().mul_(clip_coef)
return total_norm
if __name__ == "__main__":
main()


@ -60,7 +60,7 @@ from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
Array = Any
Dataset = datasets.arrow_dataset.Dataset


@ -59,7 +59,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")


@ -55,7 +55,7 @@ from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
Array = Any
Dataset = datasets.arrow_dataset.Dataset


@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")


@ -0,0 +1,793 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""":
This script is used to test training a model using Tensor Parallelism and Data Parallelism.
Usage:
export CUDA_VISIBLE_DEVICES=0,1,2,3
export CUDA_VISIBLE_DEVICES=4,5,6,7
export CUDA_VISIBLE_DEVICES=5,6,7
TP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 test_train.py
CP_SIZE=2 DP_SIZE=2 torchrun --nproc_per_node=4 test_train.py
CP_SIZE=2 TP_SIZE=2 torchrun --nproc_per_node=4 test_train.py
TP_SIZE=1 CP_SIZE=4 torchrun --nproc_per_node=4 test_train.py
TP_SIZE=1 DP_SIZE=4 torchrun --nproc_per_node=4 test_train.py
TP_SIZE=4 DP_SIZE=1 torchrun --nproc_per_node=4 --rdzv_endpoint=localhost:29503 test_train.py
IGNORE_SANITY=1 CP_SIZE=1 TP_SIZE=1 DP_SIZE=1 torchrun --nproc_per_node=1 --rdzv_endpoint=localhost:29504 test_train.py
"""
import logging
import os
from contextlib import nullcontext
from typing import Dict, Iterable, Optional
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
import torch.optim as optim
import wandb
from datasets import load_dataset
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.experimental import context_parallel
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.utils.data import DataLoader, default_collate
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoModelForCausalLM, AutoTokenizer
ignore_sanity_checks = int(os.environ.get("IGNORE_SANITY", 0)) == 1
# torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
# Set up logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)
# from torch.distributed.tensor.experimental._attention import set_rotate_method
# set_rotate_method("alltoall") # rotate shards using all-to-all
def main():
tp_size = int(os.environ.get("TP_SIZE", 1))
dp_size = int(os.environ.get("DP_SIZE", 4))
cp_size = int(os.environ.get("CP_SIZE", 1)) # Add CP size configuration
sdpa_backend = SDPBackend.FLASH_ATTENTION # For CP
# sdpa_backend = SDPBackend.MATH # For CP
global_batch_size = 8 # Desired global batch size
seq_len = 1024 # Sequence length
num_train_steps = 10000 # Number of training steps
LR = 1e-5
model_name = "HuggingFaceTB/SmolLM2-1.7B"
# model_name = "unsloth/Llama-3.2-1B"
CHECKPOINT_DIR = f"checkpoint_tp{tp_size}_dp{dp_size}_cp{cp_size}"
# Initialize distributed environment
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
dist.init_process_group("nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
assert world_size == tp_size * dp_size * cp_size, (
f"World size ({world_size}) must equal TP size ({tp_size}) * DP size ({dp_size}) * CP size ({cp_size})"
)
mesh = torch.arange(world_size).reshape(dp_size, tp_size, cp_size)
world_mesh = DeviceMesh(device_type="cuda", mesh=mesh, mesh_dim_names=("dp", "tp", "cp"))
tp_mesh = world_mesh["tp"]
dp_mesh = world_mesh["dp"]
cp_mesh = world_mesh["cp"]
world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
logger.info(f"Created DeviceMesh: {world_mesh}")
logger.info(
f"Distributed setup - Rank: {rank}, World size: {world_size}, Local rank: {local_rank}, DP: {dp_mesh.get_local_rank()}, TP: {tp_mesh.get_local_rank()}, CP: {cp_mesh.get_local_rank()}"
)
if dist.get_rank() == 0:
wandb.init(
project="tp_dp_test",
config={
"tp_size": tp_size,
"dp_size": dp_size,
"cp_size": cp_size,
"global_batch_size": global_batch_size,
"model_name": model_name,
"dataset": "roneneldan/TinyStories-1M",
"seq_len": seq_len,
"lr": LR,
"weight_decay": 0.1,
},
name=f"llama_tp{tp_size}_dp{dp_size}_cp{cp_size}"
if model_name == "unsloth/Llama-3.2-1B"
else f"tp{tp_size}_dp{dp_size}_cp{cp_size}",
)
logger.info(f"ignore_sanity_checks is set to: {ignore_sanity_checks}")
logger.info("Wandb initialized.")
# Log the current file to wandb
wandb.save("test_train.py")
else:
logger.info("Running in non-distributed mode. DeviceMesh not applicable.")
rank = 0
world_size = 1
local_rank = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
wandb.init(
project="tp_dp_test",
config={
"tp_size": 1,
"dp_size": 1,
"global_batch_size": global_batch_size,
"model_name": model_name,
"dataset": "roneneldan/TinyStories-1M",
"seq_len": seq_len,
},
name="llama_tp1_dp1_nondist" if model_name == "unsloth/Llama-3.2-1B" else "tp1_dp1_nondist",
)
logger.info("Wandb initialized for non-distributed run.")
# Load model and tokenizer
logger.info(f"Loading model and tokenizer from {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
logger.info(f"Set pad_token to eos_token: {tokenizer.pad_token}")
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_mesh=tp_mesh if dist.is_initialized() else None,
tp_plan="auto",
torch_dtype=torch.bfloat16,
)
logger.info(f"Model loaded onto device mesh: {tp_mesh}")
if dist.is_initialized():
assert model.config.num_key_value_heads % tp_mesh.size() == 0, (
f"num_key_value_heads={model.config.num_key_value_heads} must be divisible by tp_size={tp_mesh.size()}"
)
device = torch.device(f"cuda:{local_rank}")
else:
model = model.to(device)
logger.info(f"Using device: {device} for non-model tensors")
use_ddp = False
if dist.is_initialized() and dp_mesh.size() > 1:
# FSDP1
model = FSDP(model, device_mesh=dp_mesh, sharding_strategy=ShardingStrategy.NO_SHARD)
# FSDP2
# for transformer_block in model.model.layers:
# fully_shard(transformer_block, mesh=dp_mesh, reshard_after_forward=False)
# fully_shard(model.model, mesh=dp_mesh, reshard_after_forward=False)
# DDP
# replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
# assert len(list(model.parameters()))>5, "No parameters found in model. Probably DDP/FSDP bug.." # TODO: we should be cautious abt using model.parameters()
use_ddp = True
model.train()
assert len(list(model.parameters())) > 0, "No parameters found in model. Probably DDP bug.."
assert len([p for p in model.parameters() if p.requires_grad]) > 0, (
"No gradients found in model. Probably DDP bug.."
)
if dist.is_initialized() and not ignore_sanity_checks:
# assert model is replicated across all dp
for name, param in model.named_parameters():
sanity_check_tensor_sync(param, dp_mesh)
# assert model is different across tp (only for sharded params)
for name, param in model.named_parameters():
if isinstance(param, DTensor) and param.placements[0].is_shard():
# Only check sharded parameters for non-sync across TP
sanity_check_tensor_sync(param, tp_mesh, not_sync=True)
elif isinstance(param, DTensor) and param.placements[0].is_replicate():
# Replicated parameters should be the same across TP
sanity_check_tensor_sync(param, tp_mesh)
# assert model is replicated across cp
for name, param in model.named_parameters():
sanity_check_tensor_sync(param, cp_mesh)
# Load and preprocess TinyStories dataset
logger.info("Loading TinyStories dataset...")
raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]") # Use 1% for faster testing
def tokenize_function(examples):
# Tokenize the text without padding
tokenized_batch = tokenizer(
examples["text"], padding=False, truncation=True, max_length=seq_len, return_tensors=None
)
# Set labels to be the same as input_ids for Causal LM
tokenized_batch["labels"] = tokenized_batch["input_ids"].copy()
return tokenized_batch
tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
logger.info(f"Dataset loaded and tokenized. Size: {len(tokenized_dataset)}")
# Create packed sequences
def create_packed_sequences(examples):
# Flatten all sequences
all_tokens = []
for input_ids in examples["input_ids"]:
all_tokens.extend(input_ids)
# Split into sequences of seq_len + 1 (for input + label)
num_sequences = len(all_tokens) // (seq_len + 1)
packed_input_ids = []
packed_labels = []
for i in range(num_sequences):
start_idx = i * (seq_len + 1)
end_idx = start_idx + (seq_len + 1)
# Get the full sequence
full_sequence = all_tokens[start_idx:end_idx]
# For input_ids, remove the last token
packed_input_ids.append(full_sequence[:-1])
# For labels, remove the first token
packed_labels.append(full_sequence[1:])
return {"input_ids": packed_input_ids, "labels": packed_labels}
# Apply packing to the dataset
packed_dataset = tokenized_dataset.map(
create_packed_sequences,
batched=True,
remove_columns=tokenized_dataset.column_names,
batch_size=1000, # Process in batches for efficiency
num_proc=60,
)
logger.info(f"Dataset packed. New size: {len(packed_dataset)}")
# Shuffle the packed dataset
packed_dataset = packed_dataset.shuffle(seed=42)
logger.info("Packed dataset shuffled")
# Calculate local batch size
if dist.is_initialized():
assert global_batch_size % dp_mesh.size() == 0, (
f"Global batch size ({global_batch_size}) must be divisible by DP size ({dp_mesh.size()})"
)
local_batch_size = global_batch_size // dp_mesh.size()
else:
local_batch_size = global_batch_size
logger.info(
f"Global batch size: {global_batch_size}, DP size: {dp_size if dist.is_initialized() else 1}, Local batch size: {local_batch_size}"
)
# Simple collate function since sequences are already packed
def collate_fn(batch):
input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
return {"input_ids": input_ids, "labels": labels}
if dist.is_initialized():
sampler = DistributedSampler(
packed_dataset, num_replicas=dp_mesh.size(), rank=dp_mesh.get_local_rank(), shuffle=False
)
else:
sampler = None
dataloader = DataLoader(
packed_dataset,
batch_size=local_batch_size,
sampler=sampler,
shuffle=False,
collate_fn=collate_fn,
)
logger.info(f"DataLoader created. Distributed: {dist.is_initialized()}")
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.1)
# Training loop
logger.info(f"Starting training for {num_train_steps} steps...")
model.train()
step = 0
while step < num_train_steps:
for batch in dataloader:
if step >= num_train_steps:
break # Exit loop if max steps reached
# Move batch to appropriate device
batch = {k: v.to(device) for k, v in batch.items()}
# Sanity checks for batch distribution (only if distributed)
if dist.is_initialized() and not ignore_sanity_checks:
# check batch is same across all tp
sanity_check_tensor_sync(batch["input_ids"], tp_mesh)
# check batch is different across dp
sanity_check_tensor_sync(batch["input_ids"], dp_mesh, not_sync=True)
optimizer.zero_grad()
# Add position_ids to batch before CP sharding
batch_size = batch["input_ids"].shape[0]
position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
batch["position_ids"] = position_ids
from torch.distributed.tensor.experimental._attention import _cp_options
_cp_options.enable_load_balance = False
with sdpa_kernel(sdpa_backend): # TODO: ideally move this to attention implementation
cp_context = (
nullcontext()
if cp_mesh.size() == 1
else context_parallel(
cp_mesh,
buffers=[
batch["input_ids"],
batch["labels"],
batch["position_ids"],
], # TODO: need to add attention mask
buffer_seq_dims=[1, 1, 1],
)
)
with cp_context:
# Pop labels from batch before model forward pass
labels = batch.pop("labels")
outputs = model(**batch) # [mbs, seq_len/cp]
loss = outputs.loss
logits = outputs.logits
# Compute loss with shifted labels
loss = model.loss_function(
logits=logits, labels=None, shift_labels=labels, vocab_size=model.config.vocab_size
)
# Sanity checks for logits
if dist.is_initialized() and not ignore_sanity_checks:
# sanity_check_tensor_sync(logits, tp_mesh) # TODO: only true without sequence parallel
sanity_check_tensor_sync(logits, dp_mesh, not_sync=True)
sanity_check_tensor_sync(logits, cp_mesh, not_sync=True)
loss.backward()
# all reduce grads across dp_cp if applicable
all_reduce_grads(model, world_mesh, use_ddp=use_ddp)
# Sanity checks for gradients (only if distributed)
if dist.is_initialized() and not ignore_sanity_checks:
# check grads are not same across all tp (for sharded grads)
for name, param in model.named_parameters():
if param.grad is not None and isinstance(param.grad, DTensor):
if param.grad.placements[0].is_shard():
sanity_check_tensor_sync(param.grad, tp_mesh, not_sync=True)
elif param.grad.placements[0].is_replicate():
sanity_check_tensor_sync(param.grad, tp_mesh)
# check grads are same across dp
for name, param in model.named_parameters():
if param.grad is not None and dp_mesh.size() > 1:
sanity_check_tensor_sync(param.grad, dp_mesh)
# check grads are same across cp
for name, param in model.named_parameters():
if param.grad is not None and cp_mesh.size() > 1:
sanity_check_tensor_sync(param.grad, cp_mesh)
# Calculate gradient norm and clip gradients
if hasattr(model, "clip_grad_norm_"):
# when using FSDP or DDP, model.parameters() doesn't work
gradnorm = model.clip_grad_norm_(max_norm=1.0, norm_type=2.0)
else:
assert len(list(model.parameters())) > 2, "No parameters found in model. Probably DDP bug.."
assert len([p for p in model.parameters() if p.requires_grad]) > 2, (
"No gradients found in model. Probably DDP bug.."
)
assert len([p for p in model.parameters() if p.grad is not None]) > 2, (
"No gradients found in model. Probably DDP bug.."
)
# only works with FSDP's NO_SHARD; otherwise FSDP's own clip_grad_norm_ should be used
gradnorm = clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2.0, foreach=True)
optimizer.step()
# Sanity checks for updated model parameters (only if distributed)
if dist.is_initialized() and not ignore_sanity_checks:
# check updated model is different across all tp (for sharded params)
for name, param in model.named_parameters():
if isinstance(param, DTensor):
if param.placements[0].is_shard():
sanity_check_tensor_sync(param, tp_mesh, not_sync=True)
elif param.placements[0].is_replicate():
sanity_check_tensor_sync(param, tp_mesh)
# check updated model is same across dp
for name, param in model.named_parameters():
sanity_check_tensor_sync(param, dp_mesh)
# check updated model is same across cp
for name, param in model.named_parameters():
sanity_check_tensor_sync(param, cp_mesh)
# allreduce loss across cp_dp before logging
if dist.is_initialized() and (cp_mesh.size() > 1 or dp_mesh.size() > 1):
dist.all_reduce(loss, group=world_mesh["dp_cp"].get_group(), op=dist.ReduceOp.AVG)
current_loss = loss.item()
# Log loss and gradnorm to wandb (only on rank 0 of dp group)
if not dist.is_initialized() or dist.get_rank() == 0:
logger.info(
f"Step: {step} | GBS: {global_batch_size} | DP: {dp_mesh.size()} | TP: {tp_mesh.size()} | CP: {cp_mesh.size()} | Loss: {current_loss} | Gradnorm: {gradnorm} | lr: {LR}"
)
wandb.log(
{
"train/loss": current_loss,
"train/gradnorm": gradnorm,
"step": step,
"lr": LR,
"GBS": global_batch_size,
}
)
step += 1 # Increment step count
logger.info("Training loop finished.")
# Save model using DCP (only if distributed)
if dist.is_initialized():
state_dict = {"app": AppState(model, optimizer)}
dcp.save(
state_dict=state_dict,
checkpoint_id=CHECKPOINT_DIR,
)
logger.info(f"Saved checkpoint to {CHECKPOINT_DIR}")
else:
# Fallback to regular save for non-distributed case
save_dir = "test_model_nondist"
model.save_pretrained(save_dir, safe_serialization=False)
tokenizer.save_pretrained(save_dir) # Save tokenizer too
logger.info(f"Saved model to {save_dir}")
# Example of loading the checkpoint (only if distributed)
if dist.is_initialized():
# Create a new model instance
logger.info("Creating new model instance for verification")
new_model = AutoModelForCausalLM.from_pretrained(
model_name,
device_mesh=tp_mesh,
torch_dtype=torch.bfloat16, # Use same dtype
)
new_optimizer = optim.AdamW(new_model.parameters(), lr=LR)
# Load checkpoint into new model
state_dict = {"app": AppState(new_model, new_optimizer)}
dcp.load(
state_dict=state_dict,
checkpoint_id=CHECKPOINT_DIR,
)
logger.info("Loaded checkpoint into new model")
# Verify model weights match
logger.info("Verifying model weights match...")
for (name1, param1), (name2, param2) in zip(model.named_parameters(), new_model.named_parameters()):
torch.testing.assert_close(
param1.to_local(),
param2.to_local(),
rtol=1e-3,
atol=1e-3,
msg=f"Weights mismatch in {name1} vs {name2}",
)
# Verify optimizer states match
logger.info("Verifying optimizer states match...")
for name1, state1 in optimizer.state_dict().items():
state2 = new_optimizer.state_dict()[name1]
if name1 == "state":
# Compare state dictionaries for each parameter
for param_id, param_state1 in state1.items():
param_state2 = state2[param_id]
# Compare each state component (step, exp_avg, exp_avg_sq)
for key, value1 in param_state1.items():
value2 = param_state2[key]
if isinstance(value1, DTensor):
# Convert DTensors to local tensors for comparison
torch.testing.assert_close(
value1.to_local(),
value2.to_local(),
rtol=1e-5,
atol=1e-5,
msg=f"Optimizer state mismatch in state[{param_id}][{key}]",
)
else:
torch.testing.assert_close(
value1,
value2,
rtol=1e-5,
atol=1e-5,
msg=f"Optimizer state mismatch in state[{param_id}][{key}]",
)
elif name1 == "param_groups":
# Compare param_groups (excluding the actual params list)
for i, (group1, group2) in enumerate(zip(state1, state2)):
for key in group1:
if key != "params": # Skip comparing the params list
assert group1[key] == group2[key], f"Param group mismatch in param_groups[{i}][{key}]"
# Run a forward pass with both models to verify outputs match
logger.info("Running forward pass verification...")
with torch.no_grad():
# Use the last batch for verification
batch = {k: v.to(device) for k, v in batch.items()} # Ensure batch is on correct device
original_outputs = model(**batch)
new_outputs = new_model(**batch)
torch.testing.assert_close(
original_outputs.logits.to_local(),
new_outputs.logits.to_local(),
rtol=1e-3,
atol=1e-3,
msg="Model outputs do not match!",
) # Increased tolerance slightly for bf16
# Clean up distributed environment and finish wandb run
if dist.is_initialized():
dist.destroy_process_group()
logger.info("Cleaned up distributed process group")
# Finish wandb run on rank 0
if dist.get_rank() == 0:
wandb.finish()
logger.info("Wandb run finished.")
else:
wandb.finish()
logger.info("Wandb run finished.")
def all_reduce_grads(model, world_mesh, use_ddp):
"""All reduce gradients across dp_cp if applicable."""
cp_mesh = world_mesh["cp"]
if use_ddp:
# DDP takes care of syncing grads
mesh = cp_mesh
else:
mesh = world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
if dist.is_initialized() and mesh.size() > 1:
for name, param in model.named_parameters():
if param.grad is not None:
# Workaround for cross-mesh communication limitation with DTensor gradients
if isinstance(param.grad, DTensor):
local_grad = param.grad.to_local()
# Ensure grad requires grad for inplace modification checks (might not be needed)
# local_grad = local_grad.detach().requires_grad_(True)
torch.distributed.all_reduce(local_grad, op=torch.distributed.ReduceOp.SUM, group=mesh.get_group())
local_grad = local_grad / mesh.size()
# Assign averaged grad back - need careful handling if DTensor structure is complex
# This simple assignment might work if the grad structure matches param structure
param.grad = DTensor.from_local(
local_grad, device_mesh=param.grad.device_mesh, placements=param.grad.placements
)
else:
# Handle regular tensors if any exist (e.g. buffers not converted to DTensor)
torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG, group=mesh.get_group())
class ContextParallelCollator:
"""Collator for context parallel training that splits sequences into chunks."""
def __init__(self, cp_mesh: Optional[DeviceMesh] = None):
self.cp_mesh = cp_mesh
def __call__(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
batch = default_collate(batch)
if self.cp_mesh is not None and self.cp_mesh.size() > 1:
# Get sequence length from the input batch
seq_len = batch["input_ids"].shape[1]
assert seq_len % self.cp_mesh.size() == 0, (
f"Sequence length {seq_len} must be divisible by CP size {self.cp_mesh.size()}"
)
chunk_size = seq_len // self.cp_mesh.size()
cp_rank = self.cp_mesh.get_local_rank()
start_idx = cp_rank * chunk_size
end_idx = start_idx + chunk_size
# Keep only the local chunk of the sequence
batch["input_ids"] = batch["input_ids"][:, start_idx:end_idx]
batch["attention_mask"] = batch["attention_mask"][:, start_idx:end_idx]
batch["labels"] = batch["labels"][:, start_idx:end_idx]
return batch
class AppState(Stateful):
"""Wrapper for checkpointing the Application State including model and optimizer."""
def __init__(self, model, optimizer=None):
self.model = model
self.optimizer = optimizer
def state_dict(self):
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
return {"model": model_state_dict, "optim": optimizer_state_dict}
def load_state_dict(self, state_dict):
set_state_dict(
self.model, self.optimizer, model_state_dict=state_dict["model"], optim_state_dict=state_dict["optim"]
)
def sanity_check_tensor_sync(
tensor: torch.Tensor, mesh: DeviceMesh, rtol: float = 1e-4, atol: float = 1e-4, not_sync: bool = False
) -> None:
"""
Verify that a tensor is synchronized (or not synchronized) across all processes in the mesh's process group.
Handles both regular tensors and DTensors.
Args:
tensor (torch.Tensor): The tensor to check for synchronization (can be DTensor)
mesh (DeviceMesh): The device mesh containing the process group
rtol (float): Relative tolerance for comparison
atol (float): Absolute tolerance for comparison
not_sync (bool): If True, asserts that tensors are NOT synchronized. If False, asserts they are synchronized.
"""
if not dist.is_initialized() or mesh.size() == 1:
return # No need to check in non-distributed mode
# Get the process group from the mesh
pg = mesh.get_group()
# Convert DTensor to local tensor if needed
if hasattr(tensor, "to_local"):
local_tensor = tensor.to_local()
else:
local_tensor = tensor
# Gather tensors from all processes
world_size = dist.get_world_size(pg)
gathered_tensors = [torch.empty_like(local_tensor) for _ in range(world_size)]
dist.all_gather(gathered_tensors, local_tensor, group=pg)
# Compare each tensor with the first one
for i in range(1, world_size):
try:
torch.testing.assert_close(gathered_tensors[0], gathered_tensors[i], rtol=rtol, atol=atol)
except AssertionError as e:
if not_sync:
continue
# # Add detailed debugging for logit synchronization issues
# print(f"\nLogit synchronization error between rank 0 and rank {i}:")
# print(f"Tensor shape: {gathered_tensors[0].shape}")
# print(f"Number of mismatched elements: {(gathered_tensors[0] != gathered_tensors[i]).sum()}")
# print(f"Percentage of mismatched elements: {((gathered_tensors[0] != gathered_tensors[i]).sum() / gathered_tensors[0].numel() * 100):.2f}%")
# # Find the first few mismatches
# mismatches = torch.nonzero(gathered_tensors[0] != gathered_tensors[i])
# print("\nFirst few mismatches:")
# for idx in mismatches[:5]:
# idx = tuple(idx.tolist())
# print(f"Index {idx}:")
# print(f"Rank 0 value: {gathered_tensors[0][idx]}")
# print(f"Rank {i} value: {gathered_tensors[i][idx]}")
# print(f"Absolute difference: {abs(gathered_tensors[0][idx] - gathered_tensors[i][idx])}")
# print(f"Relative difference: {abs(gathered_tensors[0][idx] - gathered_tensors[i][idx]) / max(abs(gathered_tensors[0][idx]), abs(gathered_tensors[i][idx]))}")
# # Check if differences are systematic (e.g., all positive or negative)
# diff = gathered_tensors[0] - gathered_tensors[i]
# print(f"\nDifference statistics:")
# print(f"Mean difference: {diff.mean()}")
# print(f"Std difference: {diff.std()}")
# print(f"Max positive difference: {diff.max()}")
# print(f"Max negative difference: {diff.min()}")
raise e
def clip_grad_norm_(
parameters: Iterable[torch.Tensor],
max_norm: float,
norm_type: float = 2.0,
error_if_nonfinite: bool = False,
foreach: bool | None = None,
) -> torch.Tensor:
"""
Clip the gradient norm of an iterable of parameters.
"""
# Filter out parameters with no gradients
parameters = [p for p in parameters if p.grad is not None]
assert len(parameters) > 0, "No parameters with gradients found"
# Calculate total norm
if norm_type == float("inf"):
total_norm = max(p.grad.detach().abs().max() for p in parameters)
else:
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]), norm_type)
# Convert DTensor to local tensor if needed
if isinstance(total_norm, DTensor):
total_norm = total_norm.full_tensor()
# Clip gradients
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for p in parameters:
p.grad.detach().mul_(clip_coef)
return total_norm
def check_params_sync(model_params, original_params):
"""
Check if original_params are being updated in sync with model parameters.
Args:
model_params: Iterator of model parameters after update
original_params: List of original parameters before DDP wrapping
"""
for mp, op in zip(model_params, original_params):
if isinstance(mp, DTensor):
mp = mp.to_local()
if isinstance(op, DTensor):
op = op.to_local()
if not torch.allclose(mp.data, op.data, rtol=0, atol=0):
raise RuntimeError(f"Parameters out of sync: model param {mp.data} != original param {op.data}")
return True
def get_parameters(model: nn.Module) -> Iterable[torch.Tensor]:
"""
Get all parameters from a model by iterating over its modules.
This is an alternative to model.parameters() that works with DTensor models.
Args:
model (nn.Module): The model to get parameters from
Returns:
Iterable[torch.Tensor]: An iterator over all parameters in the model
"""
for name, module in model._modules.items():
# Look for parameters in module attributes
for attr_name, attr in module.__dict__.items():
if isinstance(attr, torch.Tensor) and attr.requires_grad:
yield attr
# Recursively get parameters from submodules
for param in get_parameters(module):
yield param
def update_model_parameters(model: nn.Module) -> None:
"""
Update model._parameters using named_modules() to ensure all parameters are properly tracked.
Args:
model (nn.Module): The model to update parameters for
"""
# Clear existing parameters
model._parameters = {}
# Add parameters from named_modules
for name, module in model.named_modules():
# Skip the root module itself
if name == "":
continue
# Get the parameter name by removing 'module.' prefix if it exists
param_name = name.replace("module.", "")
# Add weight and bias parameters if they exist
if hasattr(module, "weight") and module.weight is not None:
model._parameters[f"{param_name}.weight"] = module.weight
if hasattr(module, "bias") and module.bias is not None:
model._parameters[f"{param_name}.bias"] = module.bias
if __name__ == "__main__":
main()


@ -44,7 +44,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")


@ -0,0 +1,94 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import AutoModelForCausalLM
from transformers.loss.loss_utils import ForCausalLMLoss
world_size = int(os.environ.get("WORLD_SIZE", "1"))
cp_mesh = init_device_mesh("cuda", (world_size,))
rank = torch.distributed.get_node_local_rank()
device = "cuda"
dtype = torch.bfloat16
sdpa_backend = SDPBackend.FLASH_ATTENTION
# prepare inputs
batch_size = 1
seq_len = 128
input_ids = torch.randint(low=8, high=64, size=(batch_size, seq_len), device=device)
ignore_index = -100
# When using CP, we need to use `shift_labels`
shift_labels = torch.nn.functional.pad(input_ids, (0, 1), value=ignore_index)
shift_labels = shift_labels[..., 1:].contiguous()
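# shift_labels is input_ids shifted left by one position, with the final slot filled by ignore_index.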
position_ids = (
torch.cumsum(torch.ones(size=input_ids.size(), dtype=input_ids.dtype, device=input_ids.device), dim=1) - 1
)
# sync the randomly generated inputs across ranks so every process trains on the same data
dist.broadcast(input_ids, src=0)
dist.broadcast(shift_labels, src=0)
dist.broadcast(position_ids, src=0)
# model and optimizer
repo_id = "Qwen/Qwen2.5-Coder-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=dtype, device_map=device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
model.train()
model.zero_grad()
optimizer.zero_grad()
# For loss
vocab_size = model.config.vocab_size
# wrap in DDP so gradients stay synchronized across ranks
model = DDP(model, device_ids=[rank])
# prepare for CP
buffers = (input_ids, shift_labels, position_ids)
buffer_seq_dims = (1, 1, 1)
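# context_parallel() shards each buffer along its sequence dimension (dim 1), so every CP rank
# only sees its local chunk of input_ids, shift_labels and position_ids.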
# `no_restore_buffers=set(buffers)` is required if `loss.backward` is outside `context_parallel`.
# no_restore_buffers = set(buffers)
no_restore_buffers = None
# run with CP
with sdpa_kernel(sdpa_backend):
with context_parallel(
cp_mesh,
buffers=buffers,
buffer_seq_dims=buffer_seq_dims,
no_restore_buffers=no_restore_buffers,
):
outputs = model(input_ids, shift_labels=shift_labels, position_ids=position_ids)
print(outputs.logits.shape)
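# With CP the logits only cover this rank's local chunk: [batch_size, seq_len // world_size, vocab_size]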
# So far we need to compute `loss` outside `model.forward` when using `shift_labels`
# loss = outputs.loss
loss = ForCausalLMLoss(logits=outputs.logits, labels=None, shift_labels=shift_labels, vocab_size=vocab_size)
# This could be outside `context_parallel` context if `no_restore_buffers` is specified
loss.backward()
optimizer.step()


@ -1,132 +0,0 @@
import time
import datasets
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
_TEST_PROMPTS = [
"A man is a walking his dog down the street, and a the turn he sees",
"Describe a fruit that is of orange color and round. It is a sweet fruit and a great source of Vitamine C. The fruit I'm thinking of is an",
"A plane is flying high in the sky, out of the window are clouds and mountains. Where could the plane be located?",
"Please fill in the form to",
"For safety reasons, the train is stopped in the middle of the",
]
# --- Common Setup ---
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-3b-Instruct", attn_implementation="paged_attention", torch_dtype=torch.bfloat16, device_map="auto"
).eval()
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3b-Instruct", padding_side="left")
device = "cuda"
model.use_cache = False
# Set pad token if missing (common for Llama models)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
# Configure generation parameters
generation_config = GenerationConfig(
max_new_tokens=512,
top_k=0,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
use_cache=False,
num_blocks=512,
block_size=128,
max_batch_tokens=512, # Maximum number of tokens to process in a single batch
)
# Prepare data (using a smaller subset for demonstration)
train_dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
train_dataset = train_dataset.select(range(5)) # Use only 5 examples for the simple version
# tokenized_test_prompts = tokenizer(_TEST_PROMPTS, padding=True, padding_side="left", truncation=True, max_length=512)
# simple_batch_inputs = list(tokenized_test_prompts["input_ids"])
# def tokenize_function(examples):
# # Truncate to avoid overly long prompts exceeding max context length
# return tokenizer(examples["question"], padding=True, truncation=True, max_length=512)
# tokenized_datasets = train_dataset.map(tokenize_function, batched=True)
# simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]
# model.config.attn_implementation = "sdpa"
# start_time_simple = time.time()
# batch_size = 64
# full_outputs = []
# from tqdm import tqdm
# for i in tqdm(range(0, len(simple_batch_inputs)-batch_size, batch_size)):
# outputs = model.generate(
# torch.tensor(simple_batch_inputs[i:i+batch_size], device=model.device),
# generation_config=GenerationConfig(
# max_new_tokens=16, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id
# ),
# )
# full_outputs.extend(outputs.tolist())
# end_time_simple = time.time()
# print(f"\nSimple batch generation took: {end_time_simple - start_time_simple:.2f} seconds")
# print("\nResults from simple generate_batch:")
# for i, request in enumerate(full_outputs):
# output_text = tokenizer.decode(request, skip_special_tokens=False)
# print("-" * 20)
# print(f" Output: {output_text}")
# print("-" * 20)
# print("--- Finished Simple Batch Generation Example ---\n\n")
# --- Example 1: Simple Version using generate_batch ---
print("--- Running CB Generation Example ---")
model.config.attn_implementation = "paged_attention"
def tokenize_function(examples):
# Tokenize the questions as-is (no padding or truncation here)
return tokenizer(examples["question"])
tokenized_datasets = train_dataset.map(tokenize_function, batched=True)
simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]
# tokenized_test_prompts = tokenizer(_TEST_PROMPTS, truncation=True, max_length=512)
# simple_batch_inputs = list(tokenized_test_prompts["input_ids"])
start_time_simple = time.time()
# Call the simple batch generation function
# This handles manager initialization, request adding, result retrieval, and shutdown internally.
batch_outputs = model.generate_batch(
inputs=simple_batch_inputs,
generation_config=generation_config,
# You can pass request-specific overrides here, e.g., max_new_tokens=100
)
end_time_simple = time.time()
model.__call__ = torch.compile(model.__call__, mode="reduce-overhead")
print(f"CB generation took: a{end_time_simple - start_time_simple:.2f} seconds")
# Decode and print results
print("\nResults from simple generate_batch:")
for request in batch_outputs:
input_text = tokenizer.decode(batch_outputs[request].full_prompt_ids, skip_special_tokens=False)
try:
# Decode the static outputs
output_text = tokenizer.decode(batch_outputs[request].static_outputs, skip_special_tokens=False)
except Exception as e:
# Handle the case where decoding fails
print(f"Decoding failed for request {request}: {e}")
output_text = tokenizer.decode(batch_outputs[request].static_outputs[1:], skip_special_tokens=False)
if len(output_text) > 0:
print("-" * 20)
print(f"{request} Input: {input_text}")
print(f"{request} Output: {output_text}")
else:
print("", end="\r\r\r\r")
print("-" * 20)
print("--- Finished Simple Batch Generation Example ---\n\n")


@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")


@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")


@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)


@ -42,7 +42,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")


@ -47,7 +47,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")


@ -52,7 +52,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")


@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")


@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")


@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")


@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)


@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")


@ -59,7 +59,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)


@ -53,7 +53,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")


@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")


@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")


@ -45,7 +45,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = logging.getLogger(__name__)


@ -53,7 +53,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)
# You should update this to your particular problem to have better documentation of `model_type`


@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/object-detection/requirements.txt")


@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logging.basicConfig(level=logging.INFO)
logger = get_logger(__name__)


@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")


@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")


@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")


@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")


@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/question-answering/requirements.txt")


@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")


@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)


@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")


@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

View File

@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -46,7 +46,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)

View File

@ -47,7 +47,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt")

View File

@ -34,7 +34,6 @@ from transformers import (
GPT2Tokenizer,
GPTJForCausalLM,
LlamaForCausalLM,
LlamaTokenizer,
OpenAIGPTLMHeadModel,
OpenAIGPTTokenizer,
OPTForCausalLM,
@ -63,7 +62,7 @@ MODEL_CLASSES = {
"xlm": (XLMWithLMHeadModel, XLMTokenizer),
"gptj": (GPTJForCausalLM, AutoTokenizer),
"bloom": (BloomForCausalLM, BloomTokenizerFast),
"llama": (LlamaForCausalLM, LlamaTokenizer),
"llama": (LlamaForCausalLM, AutoTokenizer),
"opt": (OPTForCausalLM, GPT2Tokenizer),
}

View File
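The MODEL_CLASSES change above relies on `AutoTokenizer` resolving the correct (fast) tokenizer class for Llama checkpoints, which is why the dedicated `LlamaTokenizer` entry can be dropped. A minimal usage sketch, where the checkpoint id is a placeholder and not taken from this diff:

from transformers import AutoTokenizer, LlamaForCausalLM

model_id = "meta-llama/Llama-3.2-1B"  # placeholder id; substitute any Llama-family checkpoint you can access
tokenizer = AutoTokenizer.from_pretrained(model_id)  # resolves the appropriate tokenizer class automatically
model = LlamaForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The hybrid cache keeps", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))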

@ -48,7 +48,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")

View File

@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

View File

@ -56,7 +56,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = get_logger(__name__)
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

View File

@ -50,7 +50,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version(
"datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/contrastive-image-text/requirements.txt"

View File

@ -54,7 +54,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

View File

@ -49,7 +49,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = logging.getLogger(__name__)

View File

@ -61,7 +61,7 @@ except (ModuleNotFoundError, ImportError):
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
logger = logging.getLogger(__name__)

View File

@ -52,7 +52,7 @@ from transformers.utils.versions import require_version
# region Checking dependencies
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -46,7 +46,7 @@ from transformers.utils import check_min_version, send_example_telemetry
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
task_to_keys = {
"cola": ("sentence", None),

View File

@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
# region Dependencies and constants
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")
check_min_version("4.53.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

View File

@ -1,14 +0,0 @@
apiVersion: 1
datasources:
- name: Prometheus
type: prometheus
access: proxy
url: http://prometheus:9090
isDefault: true
- name: Tempo
type: tempo
access: proxy
url: http://tempo:3200
uid: tempo

View File

@ -1,12 +0,0 @@
Potential issues in transformers:
- beam sampling
- logit processor
- penalties for repetitions -> introduce bias
-> slow (for loops / non-vectorized)
Potential TODOs:
1. improve the DX of transformers to broaden reach and simplify usage (typically on local devices)
2. make better use of torch within transformers (jit/torchscript, executorch, etc.)
3. tokenizers
- take papers and implement improvements (typically better Byte Pair Encoding)
- maintenance work (improve python API, help out on various issues / improvements, etc)

View File

@ -1,3 +0,0 @@
global:
scrape_interval: 15s

View File

@ -201,9 +201,6 @@ _deps = [
"pytest-rich",
"libcst",
"rich",
"opentelemetry-api",
"opentelemetry-exporter-otlp",
"opentelemetry-sdk",
]
@ -438,9 +435,6 @@ extras["torchhub"] = deps_list(
extras["benchmark"] = deps_list("optimum-benchmark")
# OpenTelemetry dependencies for metrics collection in continuous batching
extras["open-telemetry"] = deps_list("opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-sdk")
# when modifying the following list, make sure to update src/transformers/dependency_versions_check.py
install_requires = [
deps["filelock"], # filesystem locks, e.g., to prevent parallel downloads
@ -457,7 +451,7 @@ install_requires = [
setup(
name="transformers",
version="4.52.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.53.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)",
author_email="transformers@huggingface.co",
description="State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow",

View File

@ -18,7 +18,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.52.0.dev0"
__version__ = "4.53.0.dev0"
from pathlib import Path
from typing import TYPE_CHECKING

View File

@ -21,6 +21,104 @@ if is_hqq_available():
logger = logging.get_logger(__name__)
# Utility functions for static/sliding cache update logic
def _static_cache_update(
k_cache: torch.Tensor,
v_cache: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
cache_position: Optional[torch.LongTensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the static cache tensors in place.
Args:
k_cache (`torch.Tensor`): The key cache tensor to update.
v_cache (`torch.Tensor`): The value cache tensor to update.
key_states (`torch.Tensor`): The new key states to add.
value_states (`torch.Tensor`): The new value states to add.
cache_position (`Optional[torch.LongTensor]`): The position indices where the new states should be inserted.
If None, the entire cache is overwritten (prefill).
Returns:
Tuple[`torch.Tensor`, `torch.Tensor`]: The updated key and value cache tensors (modified in-place).
"""
if cache_position is None:
# Prefill phase where seq_len potentially equals max_cache_len. Directly copy.
k_cache.copy_(key_states)
v_cache.copy_(value_states)
else:
# Generation phase. Update specific positions.
# Use index_copy_ for in-place update (compile-friendly).
try:
k_cache.index_copy_(2, cache_position, key_states)
v_cache.index_copy_(2, cache_position, value_states)
except NotImplementedError:
# Fallback for devices like MPS where index_copy_ might not be supported.
k_cache[:, :, cache_position] = key_states
v_cache[:, :, cache_position] = value_states
return k_cache, v_cache
def _sliding_cache_update(
k_cache: torch.Tensor,
v_cache: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
cache_position: torch.LongTensor,
max_cache_len: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the sliding window cache tensors, returning the potentially modified tensors.
Args:
k_cache (`torch.Tensor`): The key cache tensor to update.
v_cache (`torch.Tensor`): The value cache tensor to update.
key_states (`torch.Tensor`): The new key states to add.
value_states (`torch.Tensor`): The new value states to add.
cache_position (`torch.LongTensor`): The position indices where the new states should be inserted.
max_cache_len (`int`): The maximum length of the sliding window cache.
Returns:
Tuple[`torch.Tensor`, `torch.Tensor`]: The key and value tensors representing the cache state after the update.
For prefill > window, these are the full input states.
Otherwise, they are the updated cache tensors.
"""
# Handle prefill phase when prompt length > sliding_window_size
if cache_position.shape[0] > max_cache_len:
new_k = key_states[:, :, -max_cache_len:, :]
new_v = value_states[:, :, -max_cache_len:, :]
k_cache.copy_(new_k)
v_cache.copy_(new_v)
return key_states, value_states
# Sliding window logic for generation phase or prefill < window
slicing = torch.arange(max_cache_len, device=value_states.device)
current_seq_len = cache_position[-1] + 1 # Use last position to determine current length
to_shift = current_seq_len > max_cache_len
indices = (slicing + to_shift.sum()) % max_cache_len
k_out_shifted = k_cache[:, :, indices]
v_out_shifted = v_cache[:, :, indices]
# Clamp cache_position to determine the *target index* within the shifted cache view
update_position = cache_position.clamp(min=0, max=max_cache_len - 1)
try:
k_out_updated = k_out_shifted.index_copy(2, update_position, key_states)
v_out_updated = v_out_shifted.index_copy(2, update_position, value_states)
except NotImplementedError:
# Fallback for MPS: clone and modify the clone
k_out_updated = k_out_shifted.clone()
v_out_updated = v_out_shifted.clone()
k_out_updated[:, :, update_position] = key_states
v_out_updated[:, :, update_position] = value_states
k_cache.copy_(k_out_updated)
v_cache.copy_(v_out_updated)
return k_out_updated, v_out_updated
class Cache:
"""
Base, abstract class for all caches. The actual data structure is specific to each subclass.
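To make the in-place semantics of `_static_cache_update` concrete, here is a minimal, self-contained sketch of the same `index_copy_` pattern; the tensor sizes are illustrative assumptions, not values from this diff:

import torch

# A standalone mirror of the `_static_cache_update` flow above, with assumed sizes
# (batch, num_kv_heads, max_cache_len, head_dim) = (2, 4, 16, 8).
batch, heads, max_cache_len, head_dim = 2, 4, 16, 8
k_cache = torch.zeros(batch, heads, max_cache_len, head_dim)
v_cache = torch.zeros(batch, heads, max_cache_len, head_dim)

# Prefill: a 5-token prompt is written in place at positions 0..4 with the compile-friendly index_copy_.
prompt_len = 5
key_states = torch.randn(batch, heads, prompt_len, head_dim)
value_states = torch.randn(batch, heads, prompt_len, head_dim)
cache_position = torch.arange(prompt_len)
k_cache.index_copy_(2, cache_position, key_states)
v_cache.index_copy_(2, cache_position, value_states)

# Decode: the next token lands at position 5, again without allocating a new cache tensor.
new_k = torch.randn(batch, heads, 1, head_dim)
new_v = torch.randn(batch, heads, 1, head_dim)
step_position = torch.tensor([prompt_len])
k_cache.index_copy_(2, step_position, new_k)
v_cache.index_copy_(2, step_position, new_v)
assert torch.equal(k_cache[:, :, prompt_len], new_k[:, :, 0])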
@ -1264,28 +1362,16 @@ class StaticCache(Cache):
"""
if cache_kwargs is None:
cache_kwargs = {}
cache_position = cache_kwargs.get("cache_position")
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
key_states = key_states.to(k_out.dtype)
value_states = value_states.to(v_out.dtype)
if cache_position is None:
k_out.copy_(key_states)
v_out.copy_(value_states)
else:
# Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
# `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place
# operation, that avoids copies and uses less memory.
try:
k_out.index_copy_(2, cache_position, key_states)
v_out.index_copy_(2, cache_position, value_states)
except NotImplementedError:
# The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
return k_out, v_out
key_states = key_states.to(self.key_cache[layer_idx].dtype)
value_states = value_states.to(self.value_cache[layer_idx].dtype)
return _static_cache_update(
self.key_cache[layer_idx],
self.value_cache[layer_idx],
key_states,
value_states,
cache_kwargs.get("cache_position"),
)
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model."""
@ -1314,7 +1400,7 @@ class SlidingWindowCache(StaticCache):
The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`:
indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window
indices = (slicing + to_shift[-1].sum()-1) % self.config.sliding_window
tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
@ -1398,46 +1484,21 @@ class SlidingWindowCache(StaticCache):
if cache_kwargs is None:
cache_kwargs = {}
cache_position = cache_kwargs.get("cache_position")
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
key_states = key_states.to(k_out.dtype)
value_states = value_states.to(v_out.dtype)
# assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
if cache_position.shape[0] >= self.max_cache_len:
k_out = key_states[:, :, -self.max_cache_len :, :]
v_out = value_states[:, :, -self.max_cache_len :, :]
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
# we should return the whole states instead of k_out, v_out to take the whole prompt
# into consideration when building kv cache instead of just throwing away tokens outside of the window
return key_states, value_states
if cache_position is None:
raise ValueError("`cache_position` must be provided for SlidingWindowCache.")
slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
to_shift = cache_position > self.max_cache_len - 1
cache_position = cache_position.clamp(0, self.max_cache_len - 1)
indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len
key_states = key_states.to(self.key_cache[layer_idx].dtype)
value_states = value_states.to(self.value_cache[layer_idx].dtype)
k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]
try:
k_out.index_copy_(2, cache_position, key_states)
v_out.index_copy_(2, cache_position, value_states)
except NotImplementedError:
# The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
# `.zero_()` followed by `+=` is equivalent to `=`, but compile-friendly (without graph breaks due to assignment)
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
return k_out, v_out
return _sliding_cache_update(
self.key_cache[layer_idx],
self.value_cache[layer_idx],
key_states,
value_states,
cache_position,
self.max_cache_len,
)
def get_max_cache_shape(self) -> Optional[int]:
return self.max_cache_len
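The index arithmetic behind `_sliding_cache_update` is easier to see on a tiny example. A minimal sketch, assuming a 4-slot window that is already full when the token at position 4 arrives:

import torch

# Tiny illustration of the roll-and-overwrite trick in `_sliding_cache_update`;
# the 4-slot window and token positions are assumptions chosen for readability.
max_cache_len = 4
k_cache = torch.arange(max_cache_len).view(1, 1, max_cache_len, 1).float()  # slots hold tokens 0..3

cache_position = torch.tensor([4])                       # the incoming token sits past the window
slicing = torch.arange(max_cache_len)
to_shift = (cache_position[-1] + 1) > max_cache_len      # True: the window must slide by one
indices = (slicing + to_shift.sum()) % max_cache_len     # tensor([1, 2, 3, 0]): oldest slot rotated to the end
shifted = k_cache[:, :, indices]

update_position = cache_position.clamp(max=max_cache_len - 1)  # the new token overwrites that last slot
new_k = torch.full((1, 1, 1, 1), 4.0)
updated = shifted.index_copy(2, update_position, new_k)
print(updated.flatten())                                 # tensor([1., 2., 3., 4.]): token 0 evicted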
@ -1680,12 +1741,13 @@ class HybridCache(Cache):
super().__init__()
if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError(
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
"Setting `cache_implementation` to 'hybrid' requires the model config supporting "
"sliding window attention, please check if there is a `sliding_window` field in the model "
"config and it's not set to None."
)
self.max_cache_len = max_cache_len
self._sliding_window_max_len = min(config.sliding_window, max_cache_len)
self.max_cache_len = max_cache_len if max_cache_len is not None else config.max_position_embeddings
# Sliding layers can't be larger than the overall max cache len
self.sliding_window_len = min(config.sliding_window, self.max_cache_len)
self.max_batch_size = max_batch_size
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = (
@ -1694,22 +1756,17 @@ class HybridCache(Cache):
self._dtype = dtype
self.num_key_value_heads = (
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
else config.num_key_value_heads
)
layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
self.is_sliding = torch.tensor(
[bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
)
self.is_sliding_list = [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)]
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
sliding_cache_shape = (
self.max_batch_size,
self.num_key_value_heads,
self._sliding_window_max_len,
self.head_dim,
)
global_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
sliding_cache_shape = (self.max_batch_size, self.num_key_value_heads, self.sliding_window_len, self.head_dim)
device = torch.device(device) if device is not None else None
for i in range(config.num_hidden_layers):
if layer_device_map is not None:
@ -1718,7 +1775,7 @@ class HybridCache(Cache):
layer_device = device
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing CUDA graph
# breaks when updating the cache.
cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
cache_shape = sliding_cache_shape if self.is_sliding_list[i] else global_cache_shape
new_layer_key_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self._dtype, device=layer_device)
torch._dynamo.mark_static_address(new_layer_key_cache)
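A small sketch of how the per-layer cache shapes follow from `is_sliding_list`; the config values below (6 layers, `sliding_window_pattern=3`, `sliding_window=128`, `max_cache_len=1024`) are assumptions for illustration:

num_hidden_layers, layer_switch = 6, 3      # assume every 3rd layer uses global attention
sliding_window, max_cache_len = 128, 1024   # assumed config values

is_sliding_list = [bool((i + 1) % layer_switch) for i in range(num_hidden_layers)]
# -> [True, True, False, True, True, False]

sliding_window_len = min(sliding_window, max_cache_len)
for i, is_sliding in enumerate(is_sliding_list):
    seq_len = sliding_window_len if is_sliding else max_cache_len
    print(f"layer {i}: cache length {seq_len}")   # sliding layers get 128 slots, global layers 1024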
@ -1726,42 +1783,6 @@ class HybridCache(Cache):
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)
def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
if cache_position.shape[0] >= max_cache_len:
k_out = key_states[:, :, -max_cache_len:, :]
v_out = value_states[:, :, -max_cache_len:, :]
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
# we should return the whole states instead of k_out, v_out to take the whole prompt
# into consideration when building kv cache instead of just throwing away tokens outside of the window
return key_states, value_states
slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
to_shift = cache_position > max_cache_len - 1
cache_position = cache_position.clamp(0, max_cache_len - 1)
indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
# `.zero_()` followed by `+=` is equivalent to `=`, but compile-friendly (without graph breaks due to assignment)
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()
self.key_cache[layer_idx] += k_out
self.value_cache[layer_idx] += v_out
return k_out, v_out
def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
self.key_cache[layer_idx] = k_out
self.value_cache[layer_idx] = v_out
return k_out, v_out
def update(
self,
key_states: torch.Tensor,
@ -1772,7 +1793,10 @@ class HybridCache(Cache):
if cache_kwargs is None:
cache_kwargs = {}
cache_position = cache_kwargs.get("cache_position")
sliding_window = cache_kwargs.get("sliding_window")
if cache_position is None:
raise ValueError("`cache_position` must be provided for HybridCache.")
is_sliding_layer = self.is_sliding_list[layer_idx]
# These two `if` blocks are only reached in multi-GPU settings when `layer_device_map` is not passed. They are used
# when the cache is initialized in the forward pass (e.g. Gemma2)
@ -1781,25 +1805,22 @@ class HybridCache(Cache):
if self.value_cache[layer_idx].device != value_states.device:
self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
key_states = key_states.to(k_out.dtype)
value_states = value_states.to(v_out.dtype)
k_cache = self.key_cache[layer_idx]
v_cache = self.value_cache[layer_idx]
key_states = key_states.to(k_cache.dtype)
value_states = value_states.to(v_cache.dtype)
if sliding_window:
update_fn = self._sliding_update
if is_sliding_layer:
return _sliding_cache_update(
k_cache,
v_cache,
key_states,
value_states,
cache_position,
k_cache.shape[2], # Use actual cache dim as max cache len
)
else:
update_fn = self._static_update
return update_fn(
cache_position,
layer_idx,
key_states,
value_states,
k_out,
v_out,
k_out.shape[2],
)
return _static_cache_update(k_cache, v_cache, key_states, value_states, cache_position)
def get_max_cache_shape(self) -> Optional[int]:
return self.max_cache_len
@ -2033,7 +2054,7 @@ class OffloadedHybridCache(HybridChunkedCache):
# TODO (joao): to enable this cache on multiple devices, use the pattern from `OffloadedCache`, which keeps
# track of the original device of each layer
unique_devices = set(layer_device_map.values())
unique_devices = set(layer_device_map.values()) if layer_device_map else set()
if len(unique_devices) > 1:
raise ValueError(f"OffloadedHybridCache does not support multiple devices. Got devices: {unique_devices}")
@ -2292,7 +2313,7 @@ class OffloadedStaticCache(StaticCache):
# TODO (joao): to enable this cache on multiple devices, use the pattern from `OffloadedCache`, which keeps
# track of the original device of each layer
unique_devices = set(layer_device_map.values())
unique_devices = set(layer_device_map.values()) if layer_device_map else set()
if len(unique_devices) > 1:
raise ValueError(f"OffloadedStaticCache does not support multiple devices. Got devices: {unique_devices}")
@ -2369,6 +2390,9 @@ class OffloadedStaticCache(StaticCache):
A tuple containing the updated key and value states.
"""
key_states = key_states.to(self.key_cache[layer_idx].dtype)
value_states = value_states.to(self.value_cache[layer_idx].dtype)
if layer_idx == 0:
# Update seen tokens.
# TODO(gante): Remove this.

Some files were not shown because too many files have changed in this diff.