Compare commits

...

121 Commits
v4.56.1 ... v5

Author SHA1 Message Date
2969b17bc6 Merge branch 'main' of github.com:huggingface/transformers into v5 2025-09-09 15:38:00 +02:00
ed100211cb [generate] PromptLookupCandidateGenerator won't generate forbidden tokens (#40726)
* no longer flaky :)

* PR comments

* any token-blocking logits processor works

* ?

* default

* -_-

* create fake tensors once
2025-09-09 11:04:01 +00:00
82d66e5dd0 Fix: swanlab public.cloud.experiment_url api error (#40763)
fix
2025-09-09 09:28:13 +00:00
a871f6f58d Add EfficientLoFTRImageProcessorFast for GPU-accelerated image processing (#40215)
* Add EfficientLoFTRImageProcessorFast for GPU-accelerated image processing

* Fix fast processor output format and add comprehensive tests

* Fix trailing whitespace in test file

* Apply ruff formatting to test file

* simplify pair validation logic

* add superglue tests to fast image processor

---------

Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co>
2025-09-08 21:08:02 +00:00
aee5000f16 Fix Bark failing tests (#39478)
* Fix vocab size for Bark generation.

* Fix Bark processor tests.

* Fix style.

* Address comments.

* Fix formatting.

---------

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
2025-09-08 20:24:51 +02:00
126264d015 🌐 [i18n-KO] Translated 'xclip.md' to Korean (#39594)
* feat: nmt draft

* fix: manual edits

* docs: ko: xclip.md

* feat: nmt draft

* fix: manual edits

* fix: Modify _toctree.yml file to reflect review

* fix: Modify _toctree.yml file to reflect review

* jungnerd_suggestion_modified_01 ko_xclip.md

Co-authored-by: Woojun Jung <46880056+jungnerd@users.noreply.github.com>

* jungnerd_suggestion_modified_02 ko_xclip.md

Co-authored-by: Woojun Jung <46880056+jungnerd@users.noreply.github.com>

---------

Co-authored-by: Woojun Jung <46880056+jungnerd@users.noreply.github.com>
2025-09-08 11:19:10 -07:00
5a468e56b7 Fix continue_final_message in apply_chat_template to prevent substring matching issues (#40732)
* Fix continue_final_message parameter in apply_chat_template

* after run fixup

* Handle trim in the template

* after fixup

* Update src/transformers/utils/chat_template_utils.py

---------

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
2025-09-08 17:25:12 +00:00
e8db153599 Fix inconsistency in SeamlessM4T and SeamlessM4Tv2 docs (#39364) 2025-09-08 10:01:44 -07:00
fd2a29d468 Fix more typos (#40627)
Fix typos

Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
2025-09-08 16:05:40 +00:00
bb8e9cd675 Remove unnecessary tildes from documentation (#40748) 2025-09-08 08:56:35 -07:00
a9b313a0c2 docs: add continuous batching to serving (#40758)
* docs: tmp

* docs: add continuous batching to serving

* docs: reword after @lysandrejik review
2025-09-08 15:50:28 +00:00
2077f17547 feat: err when unsupported attn impl is set w/ --continuous_batching (#40618)
* feat: err when unsupported attn impl is set w/ `--continuous_batching`

* refactor: move defaults and support list to CB code

* feat: add action item in error msg

* fix(serve): add default attn implementation

* feat(serve): add log when `attn_implementation` is `None`

* feat: raise Exception when attn_implementation is not supported by CB
2025-09-08 14:31:49 +00:00
dc262ee6f5 remove FSDP prefix when using save_pretrained with FSDP2 (#40207)
* remove FSDP prefix when using save_pretrained with FSDP2

* Fix: use removeprefix correctly

---------

Co-authored-by: Matej Sirovatka <54212263+S1ro1@users.noreply.github.com>
Co-authored-by: S1ro1 <matej.sirovatka@gmail.com>
2025-09-08 14:52:31 +02:00
9ab6078323 remove gemmas eager training warning (#40744)
* removed warning

* removed remaining warnings
2025-09-08 14:41:52 +02:00
2a1eb5b508 Add BF16 support check for MUSA backend (#40576)
add musa bf16 supported

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
2025-09-08 12:39:14 +00:00
7b8d40ea7a Set accepts_loss_kwargs to False for ConvNext(|V2)ForImageClassification (#40746) 2025-09-08 14:25:43 +02:00
def7558f74 Fix np array typing (#40741)
Fix typing

Signed-off-by: cyy <cyyever@outlook.com>
Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
2025-09-08 11:30:40 +00:00
44b3888d2a Fix order of mask functions when using and/or_mask_function (#40753)
fix order
2025-09-08 12:31:42 +02:00
3f7bda4209 [Continous Batching] fix do_Sample=True in continuous batching (#40692)
* fix do_Sample=True in continous batching

* added test

* fix top_p

* test

* Update examples/pytorch/continuous_batching.py
2025-09-08 10:30:15 +02:00
bb45d3631e refactor(serve): move request_id to headers (#40722)
* refactor(serve): move `request_id` to headers

* fix(serve): typo in middleware fn name

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
2025-09-05 17:50:04 +02:00
12b8e10dbf Skip VitMatteImageProcessingTest::test_fast_is_faster_than_slow (#40713)
* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-05 17:36:20 +02:00
6b232618b6 Keypoint matching docs (#40541)
---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: StevenBucaille <steven.bucaille@gmail.com>
2025-09-05 17:24:56 +02:00
948bc0fa34 [Gemma Embedding] Fix SWA (#40700)
* fix gemma embedding flash attention

* fix sdpa

* fix atttempt number 2

* alternative gemma fix

* fix modular
2025-09-05 17:12:00 +02:00
828044cadb Add Optional typing (#40686)
* Add Optional typing

Signed-off-by: cyy <cyyever@outlook.com>

* Fix typing

Signed-off-by: cyy <cyyever@outlook.com>

* Format

Signed-off-by: cyy <cyyever@outlook.com>

---------

Signed-off-by: cyy <cyyever@outlook.com>
2025-09-05 15:05:51 +00:00
e9d6a6907b [tests] remove overwrites of removed test (#40720)
rm tests from method moved to hub
2025-09-05 16:04:22 +01:00
96a5774f2e [serve] re-enable tests (#40717)
run tests
2025-09-05 15:15:34 +01:00
c76387e580 Fix arguments (#40605)
* Fix invalid arguments

Signed-off-by: cyy <cyyever@outlook.com>

* Fix typing

Signed-off-by: cyy <cyyever@outlook.com>

* Add missing self

Signed-off-by: cyy <cyyever@outlook.com>

* Add missing self and other fixes

Signed-off-by: cyy <cyyever@outlook.com>

*  More fixes

Signed-off-by: cyy <cyyever@outlook.com>

*  More fixes

Signed-off-by: cyy <cyyever@outlook.com>

---------

Signed-off-by: cyy <cyyever@outlook.com>
2025-09-05 13:50:04 +00:00
21f09032db 🔴 Update Glm4V to use config values (#40712)
* update to use config

* just fix it

* fixup want this to be reformatted
2025-09-05 13:19:50 +00:00
b62e5b6051 Fix parent classes of AllKwargsForChatTemplate (#40685)
Fix parent classes of AllKwargsForChatTemplate because the *Kwargs are members

Signed-off-by: cyy <cyyever@outlook.com>
2025-09-05 11:08:51 +00:00
313effa7ad [onnx] use logical or for grounding dino mask (#40625)
* change |= operator to use torch logical or for friendly export to different backends

* change |= operator to use torch logical or for friendly export to different backends in grounding dino model

---------

Co-authored-by: Lewis Marshall <lewism@elderda.co.uk>
2025-09-05 10:55:20 +00:00
f3211b5db7 [moduar] Add missing self in post-process methods (#40711) 2025-09-05 10:49:52 +00:00
a2a8a3ca1e [tests] fix blip2 edge case (#40699) 2025-09-05 11:35:29 +01:00
4e195f1949 🚨 Allow check_model_inputs in core VLMs (#40342)
* allow `check_model_inputs` in core VLMs

* address comments

* fix style

* why this didnt fail prev?

* chec for Noneness instead

* batch update vlms

* fix some tests

* fix copies

* oops delete

* fix efficientloftr

* fix copies

* i am stupid, fix idefics

* fix GC

* return type and other comments

* we shouldn't manually change attention anymore

* fix style

* fix copies

* fix the test
2025-09-05 10:05:56 +00:00
93df343def Fix parent classes of ProcessingKwargs (#40676)
FIx parent classes of ProcessingKwargs

Signed-off-by: cyy <cyyever@outlook.com>
2025-09-05 10:01:16 +00:00
89e103c15e feat(serve): add healthcheck test (#40697) 2025-09-05 11:56:34 +02:00
a2fffa505d Fetch more test data with hf_hub_download (#40710)
[test-all] tests

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-05 09:49:31 +00:00
4a88e81532 Add Fast Image Processor for ImageGPT (#39592)
* initial commit

* initial setup

* Overiding imageGPT specific functions

* imported is_torch_available and utilized it for importing torch in imageGPT fast

* Created init and ImageGPTFastImageProcessorKwargs

* added return_tensors, data_format, and input_data_format to ImageGPTFastImageProcessorKwargs

* set up arguments and process and _preprocess definitions

* Added arguments to _preprocess

* Added additional optional arguments

* Copied logic over from base imageGPT processor

* Implemented 2nd draft of fast imageGPT preprocess using batch processing

* Implemented 3rd draft of imageGPT fast _preprocessor. Pulled logic from BaseImageProcessorFast

* modified imageGPT test file to properly run fast processor tests

* converts images to torch.float32 from torch.unit8

* fixed a typo with self.image_processor_list in the imagegpt test file

* updated more instances of image_processing = self.image_processing_class in the test file to test fast processor

* standardized normalization to not use image mean or std

* Merged changes from solution2 branch

* Merged changes from solution2 test file

* fixed testing through baseImageGPT processor file

* Fixed check_code_quality test. Removed unncessary list comprehension.

* reorganized imports in image_processing_imagegpt_fast

* formatted image_processing_imagegpt_fast.py

* Added arg documentation

* Added FastImageProcessorKwargs class + Docs for new kwargs

* Reformatted previous

* Added F to normalization

* fixed ruff linting and cleaned up fast processor file

* implemented requested changes

* fixed ruff checks

* fixed formatting issues

* fix(ruff after merging main)

* simplify logic and reuse standard equivalenec tests

---------

Co-authored-by: Ethan Ayaay <ayaayethan@gmail.com>
Co-authored-by: chris <christine05789@gmail.com>
Co-authored-by: Ethan Ayaay <98191976+ayaayethan@users.noreply.github.com>
Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co>
2025-09-04 22:45:06 +00:00
9db11b728b Fetch one missing test data (#40703)
fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-04 23:05:23 +02:00
acd820561f Align assisted generate for unified signature in decoding methods (#40657)
* Squashed previous branch

* unify assisted generate to common decoding method signature

* move checks to validate steps where possible

* fix csm and other models that override _sample

* ops dia you again

* opsie

* joao review
2025-09-04 22:47:44 +02:00
16b821c542 Avoid T5GemmaModelTest::test_eager_matches_sdpa_inference being flaky (#40702)
fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-04 20:44:40 +00:00
519c2524af Fix broken Llama4 accuracy in MoE part (#40609)
* Fix broken Llama4 accuracy in MoE part

Llama4 accuracy is broken by a bug in
https://github.com/huggingface/transformers/pull/39501 . It forgot to
transpose the router_scores before applying it to routed_in, causing
Llama4 to generate garbage output.

This PR fixes that issue by adding back the transpose() and adding some
comments explaining why the transpose() is needed.

Signed-off-by: Po-Han Huang <pohanh@nvidia.com>

* remove comment

---------

Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
2025-09-04 22:14:44 +02:00
586dc5d06e [Glm4.5V] fix vLLM support (#40696)
* fix

* add a test case
2025-09-04 22:09:20 +02:00
ad2da3ea83 Fix self.dropout_p is not defined for SamAttention/Sam2Attention (#40667)
Fix dropout_p is not defined for SamAttention/Sam2Attention
2025-09-04 19:32:39 +02:00
e39f222096 Fix backward compatibility with accelerate in Trainer (#40668) 2025-09-04 18:15:15 +02:00
d8f670583e Change docker image to preview for the MI355 CI (#40693)
* Change docker image to preview for the MI355 CI

* Use pushed image
2025-09-04 17:23:09 +02:00
4cbca0d1af Fixing bug in Voxtral when merging text and audio embeddings (#40671)
* Fixing bug when replacing text-audio token placeholders with audio embeddings

* apply changes

---------

Co-authored-by: Eustache Le Bihan <eulebihan@gmail.com>
Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
2025-09-04 15:11:23 +00:00
9a6c6568db feat: support request cancellation (#40599)
* feat: support request cancellation

* test: add cancellation test

* refactor: use exisitng fn to check req cancellation

* feat(cb): make cancellation thread safe

* refactor(serve): update test to use `requests` instead of `httpx`
2025-09-04 17:01:29 +02:00
87f38dbfce add: embedding model (#40694)
* Gemma 3 for Embeddings

* Style fixes

* Rename conversion file for consistency

* Default padding side emb vs gen

* Corrected 270m config

* style fixes

* EmbeddingGemma config

* TODO for built-in prompts

* Resolving the sentence similarity bug and updating the architecture

* code style

* Add query prompt for SentenceTransformers

* Code quality

* Fixing or_mask_function return types

* Adding placeholder prompts for document and passage

* Finalizing prompt templates

* Adding Retrieval ro preconfigured prompts

* Add Gemma 3 270M Config

* Correcting num_linear_layers flag default

* Export Sentence Transformer in correct dtype

---------

Co-authored-by: Sindhu Raghuram <sindhuraghuram@google.com>
2025-09-04 16:16:15 +02:00
5b0c01b5e2 Final test data cache - inside CI docker images (#40689)
* run

* build

* build

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-04 13:12:49 +00:00
1f3cc935cc Load a tiny video to make CI faster (#40684)
* load a tiny video to make CI faster

* add video in url_to_local_path
2025-09-04 14:49:00 +02:00
669230a86f fix broken offline mode when loading tokenizer from hub (#40669)
* fix broken offline mode when loading tokenizer from hub

* formatting

* make quality

* fix import order
2025-09-04 12:15:56 +00:00
91b34be9cf Add codebook_dim attribute to DacVectorQuantize for DacResidualVectorQuantize.from_latents() (#40665)
* Add instance attribute to DacVectorQuantize for use in DacResidualVectorQuantize.from_latents

* add from_latent tests

* style fix

* Fix style for test_modeling_dac.py
2025-09-04 11:29:53 +00:00
25b4a0d8ae Add sequence classification support for small Gemma 3 text models (#40562)
* add seq class for gemma3 text model

* add Gemma3TextForSequenceClassification to modeling file

* After run make fixup

* let's just check

* thiis is why it was crashing, tests were just failing...

* skip it, tested only for seq clf

---------

Co-authored-by: Raushan Turganbay <raushan@huggingface.co>
2025-09-04 09:44:59 +00:00
30a4b8707d CircleCI docker images cleanup / update / fix (#40681)
* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-04 10:42:18 +02:00
7f92e1f91a Mark Aimv2ModelTest::test_eager_matches_sdpa_inference_04_fp16_pad_right_sdpa_kernels as flaky (#40683)
* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-04 10:30:14 +02:00
ca9b36a9c1 Avoid night torch CI not run because of irrelevant docker image failing to build (#40677)
fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-04 09:06:37 +02:00
d40e7ea52d Skip more fast v.s slow image processor tests (#40675)
fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-04 06:35:44 +02:00
34595cf296 Even more test data cached (#40636)
fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-03 21:20:37 +00:00
f22ec7f174 Benchmarking V2: framework impl (#40486)
* Start revamping benchmarking

* Start refactoring benchmarking

* Use Pandas for CSV

* import fix

* Remove benchmark files

* Remove sample data

* Address review comments

* Benchmarking v2

* Fix llama bench parameters

* Working checkpoint

* Readme touchups

* Remove unnecessary test

* Massage the framework a bit

* Small cleanup

* Remove unnecessary flushes

* Remove references to mock benchmark

* Take commit ID from CLI

* Address review comments

* Use Events for thread comms

* Tiny renaming
2025-09-03 22:26:32 +02:00
459c1fa47a refactor: use tolist instead of list comprehension calling .item() (#40646) 2025-09-03 19:25:29 +02:00
afd1393df1 Remove overwritten GitModelTest::test_beam_search_generate (#40666)
fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-03 18:55:45 +02:00
68b9cbb7f5 Skip test_prompt_lookup_decoding_matches_greedy_search for qwen2_audio (#40664)
* Skip `test_prompt_lookup_decoding_matches_greedy_search` for `qwen2_audio`

* Skip `test_prompt_lookup_decoding_matches_greedy_search` for `qwen2_audio`

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-03 18:43:35 +02:00
55676d7d4c Fix warning for output_attentions=True (#40597)
* Fix attn_implementation for output_attentions

* remove setting attention, just raise warning

* improve message

* Update src/transformers/utils/generic.py
2025-09-03 16:25:13 +00:00
b67608f587 Skip test_fast_is_faster_than_slow for Owlv2ImageProcessingTest (#40663)
fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-03 17:49:10 +02:00
30d66dc3bc Update check_determinism inside test_determinism (#40661)
* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-03 17:30:39 +02:00
3f40ebf620 Allow custom args in custom_generate Callables and unify generation args structure (#40586)
* Squashed commit of the following:

commit beb2b5f7a04ea9e12876696db66f3589fbae10c5
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Mon Sep 1 16:03:25 2025 +0200

    also standardize _get_stopping_criteria

commit 15c25663fa991e0a215a7f3cdcf13a9d3a989faa
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Mon Sep 1 15:48:38 2025 +0200

    watch super.generate() usages

commit 67dd845be2202d191a54b2872f1cb3f71b74b7d6
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Mon Sep 1 14:44:32 2025 +0200

    ops

commit 4655dfa28fd59d5dc083a41d8396de042d99858c
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Mon Sep 1 14:41:36 2025 +0200

    wrong merge

commit 46478143994e7b27d51c972a7881e0fea3cb6e3c
Merge: a72c2c4b2f 8564e210ca
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Mon Sep 1 14:36:15 2025 +0200

    Merge branch 'main' of github.com:huggingface/transformers into fix-custom-gen-from-function2

commit a72c2c4b2f9c0e09fe6ec7992d4d02bfa279da2a
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Mon Sep 1 14:04:59 2025 +0200

    ops5

commit e72f91411b961979bb3d271810f57905cee5b577
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Mon Sep 1 12:06:19 2025 +0200

    ops4

commit 12ca97b1078a42167143e0243036f6ef87d5fdac
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Mon Sep 1 11:58:59 2025 +0200

    ops3

commit 8cac6c60a318dd381793d4bf1ef3775823f3c95b
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Mon Sep 1 11:43:03 2025 +0200

    ops2

commit 4681a7d5dc6c8b96a515d9d79f06380c096b9a9f
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Mon Sep 1 11:40:51 2025 +0200

    ops

commit 0d72aa6cbd99a5933c5a95a39bea9088ee21e50f
Merge: e0d47e980e 5bb6186b8e
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Mon Sep 1 11:37:28 2025 +0200

    Merge branch 'remove-constrained-bs' into fix-custom-gen-from-function2

commit 5bb6186b8efbd5fdb8e3464a22f958343b9c450c
Merge: 44973dac7d b0db5a02f3
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Mon Sep 1 11:36:30 2025 +0200

    Merge branch 'main' into remove-constrained-bs

commit 44973dac7df4b4e2111c71f5fac918be21f3de52
Merge: 1ddab4bee1 893d89e5e6
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Mon Sep 1 11:29:48 2025 +0200

    Merge commit '893d89e5e6fac7279fe4292bfa3b027172287162' into remove-constrained-bs

commit e0d47e980e26d32b028c2b402ccb71262637a7a7
Merge: 88128e4563 1ddab4bee1
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Mon Sep 1 10:52:50 2025 +0200

    Merge branch 'remove-constrained-bs' into fix-custom-gen-from-function2

commit 88128e4563c0be583728e1d3c639bc93143c4029
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Mon Sep 1 10:44:38 2025 +0200

    fix custom generate args, refactor gen mode args

commit 1ddab4bee159f6c20722e7ff5cd41d5041fab0aa
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Sun Aug 31 21:03:53 2025 +0200

    fix

commit 6095fdda677ef7fbeb06c05f4f914a11b45257b4
Merge: 4a8b6d2ce1 04addbc9ec
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Thu Aug 28 17:49:16 2025 +0200

    Merge branch 'remove-constrained-bs' of github.com:manueldeprada/transformers into remove-constrained-bs

commit 4a8b6d2ce18b3a8b52c5261fea427e2416f65187
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Thu Aug 28 17:48:25 2025 +0200

    restore and deprecate beam obkects

commit 04addbc9ec62dd4f59d15128e8cd9499e2cda3bb
Merge: e800c7841e becab2c601
Author: Manuel de Prada Corral <6536835+manueldeprada@users.noreply.github.com>
Date:   Thu Aug 28 14:38:29 2025 +0200

    Merge branch 'main' into remove-constrained-bs

commit e800c7841e5c46ce5698fc9be309d0808f85d23c
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Thu Aug 28 14:38:10 2025 +0200

    tests gone after green

commit 33971d21ac40aef76a7e1122f4a98ef28beadbe8
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Thu Aug 28 14:07:11 2025 +0200

    tests green, changed handling of deprecated methods

commit ab303835c184d0a87789da7aed7d8de5ba85d867
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Thu Aug 28 12:58:01 2025 +0200

    tests fix

commit ec74274ca52a6aa0b5f300374fda838609680506
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Thu Aug 28 12:32:05 2025 +0200

    ops

commit 0fb19004ccd285dcad485fce0865b355ce5493e0
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Thu Aug 28 11:45:16 2025 +0200

    whoops

commit c946bea5e45aea021c8878c57fcabc2a13f06fe5
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Thu Aug 28 11:35:36 2025 +0200

    testing...

commit 924c0dec6d9ea6b4890644fe7f711dc778f820bb
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Thu Aug 28 11:22:46 2025 +0200

    sweeep ready for tests

commit b05aa771d3994b07cd460cda74b274c9e4f315e6
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Thu Aug 28 11:13:01 2025 +0200

    restore and deprecate constraints

commit 9c7962d10efa7178b69d3c99e69663756e1cd979
Merge: fceeb383f9 c17bf304d5
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Wed Aug 27 20:44:21 2025 +0200

    Merge branch 'remove-group-bs' into remove-constrained-bs

commit c17bf304d5cf33af7f34f9f6057915d5f5821dae
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Wed Aug 27 17:00:50 2025 +0200

    fix test

commit d579aeec6706b77fcc24c1f6806cd7277d7db56e
Merge: 822efd8c3c ed5dd2999c
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Wed Aug 27 16:04:31 2025 +0200

    Merge branch 'main' of github.com:huggingface/transformers into remove-group-bs

commit 822efd8c3cf475d079e64293aa06e4ab59740fd7
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Wed Aug 27 15:59:51 2025 +0200

    aaand remove tests after all green!!

commit 62cb274a4acb9f24201902242f1b0dc4e46daac1
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Wed Aug 27 11:48:19 2025 +0200

    fix

commit c89c892e7b24a7d71831f2b35264456005030925
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Wed Aug 27 11:45:20 2025 +0200

    testing that hub works the same

commit fceeb383f99e4a836679d67b1d2a8520152eaf49
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Tue Aug 26 20:06:59 2025 +0200

    draft

commit 6a9b384078f3798587ba865ac7ddfefc9a79e41c
Merge: 8af3af13ab 58cebc848b
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Tue Aug 26 15:00:05 2025 +0200

    Merge branch 'main' of github.com:huggingface/transformers into remove-group-bs

commit 8af3af13abb85ca60e795d0390832f398a56c34f
Author: Manuel de Prada Corral <manueldeprada@gmail.com>
Date:   Tue Aug 26 11:55:45 2025 +0200

    Squashed commit remove-constrastive-search

* ops

* fix

* ops

* review

* fix

* fix dia

* review
2025-09-03 17:30:09 +02:00
a8f400367d Avoid attention_mask copy in qwen2.5 (#40658)
Signed-off-by: cyy <cyyever@outlook.com>
2025-09-03 15:17:22 +00:00
57f5668d0b Fix Metaclip modular conversion (#40660)
* Fix Metaclip modular conversion

* manually run check_copies
2025-09-03 16:13:50 +01:00
238a8274b4 feat(serving): add healthcheck (#40653) 2025-09-03 16:43:12 +02:00
f2416b4fd2 fix pipeline dtype (#40638)
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
2025-09-03 16:05:48 +02:00
5ea5c8179b Mark LongformerModelTest::test_attention_outputs as flaky (#40655)
* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-03 13:19:02 +00:00
fe1a9e0dba Remove TF/Flax examples (#40654)
* Remove TF/Flax examples

* Remove check_full_copies

* Trigger CI
2025-09-03 14:15:57 +01:00
5e2e496149 fix MetaCLIP 2 wrong link & wrong model names in the docstrings (#40565)
* fix MetaCLIP 2 wrong link & wrong model names in the documentation and docstrings

* ruff reformatted

* update files generated by modular

* update meta_clip2 to metaclip_2 to match the original

* _supports_flash_attn = False

---------

Co-authored-by: Yung-Sung Chuang <yungsung@meta.com>
2025-09-03 13:53:56 +01:00
03708ccf6f add DeepseekV3ForTokenClassification (#40641)
* add DeepseekV3ForTokenClassification

* fix typo

---------

Co-authored-by: json.bourne <json.bourne@kakaocorp.com>
2025-09-03 12:30:09 +00:00
c485c52db4 Skip test_prompt_lookup_decoding_matches_greedy_search for voxtral (#40643)
* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-03 11:45:29 +00:00
2bbf98a83d Fix: PIL image load in Processing utils apply_chat_template (#40622) 2025-09-03 13:06:05 +02:00
acc968c581 [CP] Add attention_mask to the buffer when the mask is causal (#40619)
Fix attention mask validation for context parallelism

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
2025-09-03 10:19:35 +00:00
cb54ce4ec6 [auto-model] propagate kwargs (#40491)
propagate kwargs
2025-09-03 09:59:20 +00:00
ye
0f5e45a6d1 fix: gas for gemma fixed (#40591)
* fix: gas for gemma fixed

* feat: run fix-copies

* feat: added issue label
2025-09-03 08:44:14 +00:00
e690fe61e8 Fix too many requests in TestMistralCommonTokenizer (#40623)
* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-03 05:05:03 +02:00
00a8364271 🌐 [i18n-KO] Translated deepseek_v3.md to Korean (#39649)
* docs: ko: deepseek_v3.md

* feat: nmt draft

* fix: manual edits

* fix: glossary edits

* docs : 4N3MONE recommandced modified contents

* Update docs/source/ko/model_doc/deepseek_v3.md

Co-authored-by: Kim Juwon <81630351+Kim-Ju-won@users.noreply.github.com>

* Update docs/source/ko/model_doc/deepseek_v3.md

Co-authored-by: Kim Juwon <81630351+Kim-Ju-won@users.noreply.github.com>

* add_toctree.yml

---------

Co-authored-by: Kim Juwon <81630351+Kim-Ju-won@users.noreply.github.com>
2025-09-02 13:35:56 -07:00
ed49376a42 Remove random flag (#40629)
remove flag
2025-09-02 19:10:02 +02:00
d47ad91c3c Support TF32 flag for MUSA backend (#33187)
* Support MUSA (Moore Threads GPU) backend in transformers
Add accelerate version check, needs accelerate>=0.33.0

* Support TF32 flag for MUSA backend

* fix typo
2025-09-02 16:27:10 +00:00
a470f21396 Enable more ruff UP rules (#40579)
* Import Sequence from collections.abc

Signed-off-by: cyy <cyyever@outlook.com>

* Apply ruff UP rules

Signed-off-by: cyy <cyyever@outlook.com>

---------

Signed-off-by: cyy <cyyever@outlook.com>
2025-09-02 17:29:59 +02:00
37103d6f22 Fix invalid typing (#40612)
Signed-off-by: cyy <cyyever@outlook.com>
2025-09-02 13:10:22 +00:00
4f542052b9 Remove unnecessary pillow version check (#40604)
Signed-off-by: cyy <cyyever@outlook.com>
2025-09-02 12:59:22 +00:00
8c60a7c385 Add collated reports job to Nvidia CI (#40470)
* Add collated reports job to Nvidia CI

* machine_type

* Move collated reports job to model_jobs

* Propagate repo id variable

* assifgn runner_type is self-scheduled-caller
2025-09-02 14:25:22 +02:00
97266dfd50 Fix flaky JambaModelTest.test_load_balancing_loss (#40617)
* fix

* fix

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-02 13:58:16 +02:00
91be12bdc6 Avoid too many request caused by AutoModelTest::test_dynamic_saving_from_local_repo (#40614)
* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-02 12:08:52 +02:00
bbd8085b0b Fix processor chat template (#40613)
fix tests
2025-09-02 10:59:48 +02:00
b2b1c30b1b fix: continuous batching in transformers serve (#40479)
* fix: continuous batching in `transformers serve`

* fix: short circuit inner gen loop when prepare_next_batch prepared nothing

* docs: add comment explaining FastAPI lifespan

* test: add CB serving tests

* refactor: remove gen cfg max new tokens override bc unnecessary

* docs: add docstring for `ServeCommand::run`

* feat: use new `DecodeStream` API
2025-09-02 10:45:05 +02:00
8a091cc07c Disable cache for TokenizerTesterMixin temporarily (#40611)
* try no cache

* try no cache

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-02 08:40:04 +02:00
514b3e81b7 Multiple fixes to FA tests in AMD (#40498)
* Expectations for gemma3

* Fixes for Qwen2_5_VL tests

* Added expectation but underlying pb is still there

* Better handling of mrope section for Qwen2_5_vl

* Fixes for FA2 tests and reformat batch test for Qwen2_5_Omni

* Fix multi-device error in qwen2_5_omni

* Styel and repo-consistency

* Removed inherited test because fix in common

* slow tests fixes

* Style

* Fixes for qwen2_5_vl or omni for FA test
2025-09-01 20:49:50 +02:00
b3655507bb Pin torchcodec to 0.5 in AMD docker (#40598) 2025-09-01 20:39:55 +02:00
4da03d7f57 Reduce more test data fetch (#40595)
* example

* fix

* fix

* add to fetch script

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-01 18:07:18 +02:00
abf5900a76 [Tests] Fixup duplicated mrope logic (#40592)
cleanup duplicated logic
2025-09-01 17:22:34 +02:00
3beac9c659 Fix quite a lot of FA tests (#40548)
* fix_rope_change

* fix

* do it dynamically

* style

* simplify a lot

* better fix

* fix

* fix

* fix

* fix

* style

* fix
2025-09-01 16:42:50 +02:00
21e708c8fd Fix for missing default values in encoder decoder (#40517)
* Added default_value for is_updated and type check

* Forgot one

* Repo consistency
2025-09-01 16:11:23 +02:00
c99d43e6ec Fix siglip flaky test_eager_matches_sdpa_inference (#40584)
fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-01 15:17:25 +02:00
3c3dac3c12 Add Copilot instructions (#40432)
* Add copilot-instructions.md

* Fix typo

* Update .github/copilot-instructions.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
2025-09-01 14:09:54 +01:00
2b71c5b7a6 Fix inexistent imports (#40580)
Signed-off-by: cyy <cyyever@outlook.com>
2025-09-01 13:05:00 +00:00
8e0b2c8baf Skip TvpImageProcessingTest::test_slow_fast_equivalence (#40593)
fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-01 15:03:34 +02:00
a543095c99 Fix typos (#40585)
Signed-off-by: cyy <cyyever@outlook.com>
2025-09-01 12:58:23 +00:00
8564e210ca 🚨 Remove Constrained Beam Search decoding strategy (#40518)
* Squashed remove-constrastive-search

* sweeep ready for tests

* testing...

* whoops

* ops

* tests fix

* tests green, changed handling of deprecated methods

* tests gone after green

* restore and deprecate beam obkects

* restore and deprecate constraint objects

* fix ci

* review
2025-09-01 12:34:48 +00:00
564be6d895 Support batch size > 1 image-text inference (#36682)
* update make nested image list

* fix make flat list of images

* update type anno

* fix image_processing_smolvlm

* use first image

* add verbose comment

* fix images

* rollback

* fix ut

* Update image_processing_smolvlm.py

* Update image_processing_idefics3.py

* add tests and fix some processors

* fix copies

* fix after rebase

* make the test cover chat templates

* sjip udop, no point in fixing it

* fix after rebase

* fix a few more tests

---------

Co-authored-by: Pavel Iakubovskii <qubvel@gmail.com>
Co-authored-by: raushan <raushan@huggingface.co>
2025-09-01 12:26:07 +00:00
3bccb02616 🚨 Remove Group Beam Search decoding strategy (#40495)
* Squashed remove-constrastive-search

* testing that tests pass using hub

* fix

* aaand remove tests after all green!!
2025-09-01 13:42:48 +02:00
90953d5bc1 Fix custom generate relative imports (#40480) 2025-09-01 13:38:56 +02:00
2537ed4477 Update get_*_features methods + update doc snippets (#40555)
* siglip

* clip

* aimv2

* metaclip_2

* align

* align fixup

* altclip

* blip2 (make consistent)

* chineese clip

* clipseg

* flava

* groupvit

* owlv2

* owlvit

* vision_encoder

* clap

* x_clip

* fixup

* fix siglip2

* blip2

* fix blip2 tests (revert to original)

* fix docs
2025-09-01 12:37:43 +01:00
48ebae975e Fix llava image processor (#40588)
fix
2025-09-01 13:32:57 +02:00
db6821b79c Allow remi-or to run-slow (#40590)
fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-01 12:30:53 +02:00
6546f288a1 Fix CircleCI step passes in the case of pytest worker crash at test collection time (#40552)
fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-01 11:33:23 +02:00
cfed99d310 Fix test_eager_matches_sdpa_inference not run for CLIP (#40581)
fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-09-01 11:21:56 +02:00
1d742644c0 [qwen-vl] fix position ids (#40490)
* fix position ids

* fixup

* adjust tests since they are failing on main as well

* add a comment to make it clear
2025-09-01 09:10:41 +00:00
0b24507379 processor tests - use dummy videos (#40537)
* use dummy videos

* failing on main, new model merged had conflicts
2025-09-01 09:04:47 +00:00
b0db5a02f3 Set test_all_params_have_gradient=False for DeepseekV2ModelTest (#40566)
fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-08-30 22:46:31 +02:00
1363fceeec remove the redundant non maintained jieba and use rjieba instead (#40383)
* porting not maintained jieba to rjieba

* Fix format

* replaced the line with rjieba instead of removing it

* cut_all is not included as a parameter. cut_all is a seperate function rjieba

* rev

* jieba remove installation

* Trigger tests

* Update tokenization_cpm.py

* Update tokenization_cpm_fast.py

---------

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
2025-08-30 13:28:52 +02:00
36fddebcee pin pytest-rerunfailures<16.0 (#40561)
ping pytest-rerunfailures<16.0

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
2025-08-30 12:58:44 +02:00
2d3b8863e8 Fix collated reports upload filename (#40556) 2025-08-30 09:35:51 +02:00
ce48e9cac0 Dev version 2025-08-29 20:17:34 +02:00
155fd926d2 Fix GptOssModelTest::test_assisted_decoding_matches_greedy_search_1_same (#40551)
* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Manuel de Prada Corral <6536835+manueldeprada@users.noreply.github.com>
2025-08-29 15:53:53 +00:00
80afd1e861 fix bump! 2025-08-29 12:54:01 +02:00
815 changed files with 9646 additions and 31697 deletions

View File

@ -177,11 +177,29 @@ class CircleCIJob:
"command": f"TESTS=$(circleci tests split --split-by=timings {self.job_name}_test_list.txt) && echo $TESTS > splitted_tests.txt && echo $TESTS | tr ' ' '\n'" if self.parallelism else f"awk '{{printf \"%s \", $0}}' {self.job_name}_test_list.txt > splitted_tests.txt"
}
},
{"run": {"name": "fetch hub objects before pytest", "command": "python3 utils/fetch_hub_objects_for_ci.py"}},
# During the CircleCI docker images build time, we might already (or not) download the data.
# If it's done already, the files are inside the directory `/test_data/`.
{"run": {"name": "fetch hub objects before pytest", "command": "cp -r /test_data/* . 2>/dev/null || true; python3 utils/fetch_hub_objects_for_ci.py"}},
{"run": {
"name": "Run tests",
"command": f"({timeout_cmd} python3 -m pytest {marker_cmd} -n {self.pytest_num_workers} {junit_flags} {repeat_on_failure_flags} {' '.join(pytest_flags)} $(cat splitted_tests.txt) | tee tests_output.txt)"}
},
{"run":
{
"name": "Check for test crashes",
"when": "always",
"command": """if [ ! -f tests_output.txt ]; then
echo "ERROR: tests_output.txt does not exist - tests may not have run properly"
exit 1
elif grep -q "crashed and worker restarting disabled" tests_output.txt; then
echo "ERROR: Worker crash detected in test output"
echo "Found: crashed and worker restarting disabled"
exit 1
else
echo "Tests output file exists and no worker crashes detected"
fi"""
},
},
{"run": {"name": "Expand to show skipped tests", "when": "always", "command": f"python3 .circleci/parse_test_outputs.py --file tests_output.txt --skip"}},
{"run": {"name": "Failed tests: show reasons", "when": "always", "command": f"python3 .circleci/parse_test_outputs.py --file tests_output.txt --fail"}},
{"run": {"name": "Errors", "when": "always", "command": f"python3 .circleci/parse_test_outputs.py --file tests_output.txt --errors"}},
@ -246,7 +264,6 @@ custom_tokenizers_job = CircleCIJob(
docker_image=[{"image": "huggingface/transformers-custom-tokenizers"}],
)
examples_torch_job = CircleCIJob(
"examples_torch",
additional_env={"OMP_NUM_THREADS": 8},
@ -270,19 +287,6 @@ hub_job = CircleCIJob(
resource_class="medium",
)
onnx_job = CircleCIJob(
"onnx",
docker_image=[{"image":"huggingface/transformers-torch-tf-light"}],
install_steps=[
"uv pip install .[testing,sentencepiece,onnxruntime,vision,rjieba]",
],
pytest_options={"k onnx": None},
pytest_num_workers=1,
resource_class="small",
)
exotic_models_job = CircleCIJob(
"exotic_models",
docker_image=[{"image":"huggingface/transformers-exotic-models"}],
@ -290,7 +294,6 @@ exotic_models_job = CircleCIJob(
pytest_options={"durations": 100},
)
repo_utils_job = CircleCIJob(
"repo_utils",
docker_image=[{"image":"huggingface/transformers-consistency"}],
@ -298,7 +301,6 @@ repo_utils_job = CircleCIJob(
resource_class="large",
)
non_model_job = CircleCIJob(
"non_model",
docker_image=[{"image": "huggingface/transformers-torch-light"}],
@ -334,7 +336,7 @@ doc_test_job = CircleCIJob(
pytest_num_workers=1,
)
REGULAR_TESTS = [torch_job, hub_job, onnx_job, tokenization_job, processor_job, generate_job, non_model_job] # fmt: skip
REGULAR_TESTS = [torch_job, hub_job, tokenization_job, processor_job, generate_job, non_model_job] # fmt: skip
EXAMPLES_TESTS = [examples_torch_job]
PIPELINE_TESTS = [pipelines_torch_job]
REPO_UTIL_TESTS = [repo_utils_job]

39
.github/copilot-instructions.md vendored Normal file
View File

@ -0,0 +1,39 @@
# copilot-instructions.md Guide for Hugging Face Transformers
This copilot-instructions.md file provides guidance for code agents working with this codebase.
## Core Project Structure
- `/src/transformers`: This contains the core source code for the library
- `/models`: Code for individual models. Models inherit from base classes in the root `/src/transformers` directory.
- `/tests`: This contains the core test classes for the library. These are usually inherited rather than directly run.
- `/models`: Tests for individual models. Model tests inherit from common tests in the root `/tests` directory.
- `/docs`: This contains the documentation for the library, including guides, tutorials, and API references.
## Coding Conventions for Hugging Face Transformers
- PRs should be as brief as possible. Bugfix PRs in particular can often be only one or two lines long, and do not need large comments, docstrings or new functions in this case. Aim to minimize the size of the diff.
- When writing tests, they should be added to an existing file. The only exception is for PRs to add a new model, when a new test directory should be created for that model.
- Code style is enforced in the CI. You can install the style tools with `pip install -e .[quality]`. You can then run `make fixup` to apply style and consistency fixes to your code.
## Copying and inheritance
Many models in the codebase have similar code, but it is not shared by inheritance because we want each model file to be self-contained.
We use two mechanisms to keep this code in sync:
- "Copied from" syntax. Functions or entire classes can have a comment at the top like this: `# Copied from transformers.models.llama.modeling_llama.rotate_half` or `# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->MT5`
These comments are actively checked by the style tools, and copies will automatically be updated when the base code is updated. If you need to update a copied function, you should
either update the base function and use `make fixup` to propagate the change to all copies, or simply remove the `# Copied from` comment if that is inappropriate.
- "Modular" files. These files briefly define models by composing them using inheritance from other models. They are not meant to be used directly. Instead, the style tools
automatically generate a complete modeling file, like `modeling_bert.py`, from the modular file like `modular_bert.py`. If a model has a modular file, the modeling file
should never be edited directly! Instead, changes should be made in the modular file, and then you should run `make fixup` to update the modeling file automatically.
When adding new models, you should prefer `modular` style and inherit as many classes as possible from existing models.
## Testing
After making changes, you should usually run `make fixup` to ensure any copies and modular files are updated, and then test all affected models. This includes both
the model you made the changes in and any other models that were updated by `make fixup`. Tests can be run with `pytest tests/models/[name]/test_modeling_[name].py`
If your changes affect code in other classes like tokenizers or processors, you should run those tests instead, like `test_processing_[name].py` or `test_tokenization_[name].py`.
In order to run tests, you may need to install dependencies. You can do this with `pip install -e .[testing]`. You will probably also need to `pip install torch accelerate` if your environment does not already have them.

View File

@ -26,7 +26,7 @@ jobs:
strategy:
matrix:
file: ["quality", "consistency", "custom-tokenizers", "torch-light", "tf-light", "exotic-models", "torch-tf-light", "jax-light", "examples-torch", "examples-tf"]
file: ["quality", "consistency", "custom-tokenizers", "torch-light", "exotic-models", "examples-torch"]
continue-on-error: true
steps:

View File

@ -2,6 +2,10 @@ name: Build docker images (Nightly CI)
on:
workflow_call:
inputs:
job:
required: true
type: string
push:
branches:
- build_nightly_ci_docker_image*
@ -12,7 +16,8 @@ concurrency:
jobs:
latest-with-torch-nightly-docker:
name: "Nightly PyTorch + Stable TensorFlow"
name: "Nightly PyTorch"
if: inputs.job == 'latest-with-torch-nightly-docker' || inputs.job == ''
runs-on:
group: aws-general-8-plus
steps:
@ -41,6 +46,7 @@ jobs:
nightly-torch-deepspeed-docker:
name: "Nightly PyTorch + DeepSpeed"
if: inputs.job == 'nightly-torch-deepspeed-docker' || inputs.job == ''
runs-on:
group: aws-g4dn-2xlarge-cache
steps:

View File

@ -25,6 +25,12 @@ on:
required: false
default: run_models_gpu
type: string
runner_type:
required: false
type: string
report_repo_id:
required: false
type: string
env:
HF_HOME: /mnt/cache
@ -143,3 +149,15 @@ jobs:
with:
name: ${{ env.machine_type }}_${{ inputs.report_name_prefix }}_${{ env.matrix_folders }}_test_reports
path: /transformers/reports/${{ env.machine_type }}_${{ inputs.report_name_prefix }}_${{ matrix.folders }}_test_reports
collated_reports:
name: Collated Reports
if: ${{ always() }}
needs: run_models_gpu
uses: huggingface/transformers/.github/workflows/collated-reports.yml@main
with:
job: run_models_gpu
report_repo_id: ${{ inputs.report_repo_id }}
gpu_name: ${{ inputs.runner_type }}
machine_type: ${{ inputs.machine_type }}
secrets: inherit

View File

@ -29,7 +29,7 @@ jobs:
runs-on: ubuntu-22.04
name: Get PR number
# For security: only allow team members to run
if: ${{ github.event.issue.state == 'open' && contains(fromJSON('["ydshieh", "ArthurZucker", "zucchini-nlp", "qubvel", "molbap", "gante", "LysandreJik", "Cyrilvallez", "Rocketknight1", "SunMarc", "muellerzr", "eustlb", "MekkCyber", "manueldeprada", "vasqu", "ivarflakstad", "stevhliu", "ebezzam"]'), github.actor) && (startsWith(github.event.comment.body, 'run-slow') || startsWith(github.event.comment.body, 'run slow') || startsWith(github.event.comment.body, 'run_slow')) }}
if: ${{ github.event.issue.state == 'open' && contains(fromJSON('["ydshieh", "ArthurZucker", "zucchini-nlp", "qubvel", "molbap", "gante", "LysandreJik", "Cyrilvallez", "Rocketknight1", "SunMarc", "muellerzr", "eustlb", "MekkCyber", "manueldeprada", "vasqu", "ivarflakstad", "stevhliu", "ebezzam", "remi-or"]'), github.actor) && (startsWith(github.event.comment.body, 'run-slow') || startsWith(github.event.comment.body, 'run slow') || startsWith(github.event.comment.body, 'run_slow')) }}
outputs:
PR_NUMBER: ${{ steps.set_pr_number.outputs.PR_NUMBER }}
steps:

View File

@ -22,6 +22,8 @@ jobs:
build_nightly_torch_ci_images:
name: Build CI Docker Images with nightly torch
uses: ./.github/workflows/build-nightly-ci-docker-images.yml
with:
job: latest-with-torch-nightly-docker
secrets: inherit
setup:

View File

@ -21,7 +21,7 @@ jobs:
job: run_models_gpu
slack_report_channel: "#amd-hf-ci"
runner_scale_set: amd-mi355-ci
docker: huggingface/transformers-pytorch-amd-gpu
docker: huggingface/testing-rocm7.0-preview
ci_event: Scheduled CI (AMD) - mi355
report_repo_id: optimum-amd/transformers_daily_ci
secrets: inherit
@ -33,7 +33,7 @@ jobs:
job: run_pipelines_torch_gpu
slack_report_channel: "#amd-hf-ci"
runner_scale_set: amd-mi355-ci
docker: huggingface/transformers-pytorch-amd-gpu
docker: huggingface/testing-rocm7.0-preview
ci_event: Scheduled CI (AMD) - mi355
report_repo_id: optimum-amd/transformers_daily_ci
secrets: inherit
@ -45,7 +45,7 @@ jobs:
job: run_examples_gpu
slack_report_channel: "#amd-hf-ci"
runner_scale_set: amd-mi355-ci
docker: huggingface/transformers-pytorch-amd-gpu
docker: huggingface/testing-rocm7.0-preview
ci_event: Scheduled CI (AMD) - mi355
report_repo_id: optimum-amd/transformers_daily_ci
secrets: inherit
@ -57,7 +57,7 @@ jobs:
job: run_torch_cuda_extensions_gpu
slack_report_channel: "#amd-hf-ci"
runner_scale_set: amd-mi355-ci
docker: huggingface/transformers-pytorch-deepspeed-amd-gpu
docker: huggingface/testing-rocm7.0-preview
ci_event: Scheduled CI (AMD) - mi355
report_repo_id: optimum-amd/transformers_daily_ci
secrets: inherit

View File

@ -52,6 +52,7 @@ jobs:
slack_report_channel: "#transformers-ci-daily-models"
docker: huggingface/transformers-all-latest-gpu
ci_event: Daily CI
runner_type: "a10"
report_repo_id: hf-internal-testing/transformers_daily_ci
commit_sha: ${{ github.sha }}
secrets: inherit

View File

@ -31,6 +31,9 @@ on:
commit_sha:
required: false
type: string
runner_type:
required: false
type: string
models:
default: ""
required: false
@ -126,6 +129,8 @@ jobs:
runner_map: ${{ needs.setup.outputs.runner_map }}
docker: ${{ inputs.docker }}
commit_sha: ${{ inputs.commit_sha || github.sha }}
runner_type: ${{ inputs.runner_type }}
report_repo_id: ${{ inputs.report_repo_id }}
secrets: inherit
run_trainer_and_fsdp_gpu:

1
benchmark_v2/.gitignore vendored Normal file
View File

@ -0,0 +1 @@
benchmark_results/

98
benchmark_v2/README.md Normal file
View File

@ -0,0 +1,98 @@
# Benchmarking v2
A comprehensive benchmarking framework for transformer models that supports multiple execution modes (eager, compiled, kernelized), detailed performance metrics collection, and structured output format.
## Quick Start
### Running All Benchmarks
```bash
# Run all benchmarks with default settings
python run_benchmarks.py
# Specify output directory
python run_benchmarks.py --output-dir my_results
# Run with custom parameters
python run_benchmarks.py \
--warmup-iterations 5 \
--measurement-iterations 10 \
--num-tokens-to-generate 200
```
### Running Specific Benchmarks
```bash
# Include only specific benchmarks
python run_benchmarks.py --include llama
# Exclude specific benchmarks
python run_benchmarks.py --exclude old_benchmark
## Output Format
Results are saved as JSON files with the following structure:
```json
{
"model_name": "llama_2_7b",
"benchmark_scenarios": [
{
"scenario_name": "eager_variant",
"metadata": {
"timestamp": "2025-01-XX...",
"commit_id": "abc123...",
"hardware_info": {
"gpu_name": "NVIDIA A100",
"gpu_memory_total": 40960,
"cpu_count": 64
},
"config": {
"variant": "eager",
"warmup_iterations": 3,
"measurement_iterations": 5
}
},
"measurements": {
"latency": {
"mean": 2.45,
"median": 2.43,
"std": 0.12,
"min": 2.31,
"max": 2.67,
"p95": 2.61,
"p99": 2.65
},
"time_to_first_token": {
"mean": 0.15,
"std": 0.02
},
"tokens_per_second": {
"mean": 87.3,
"unit": "tokens/sec"
}
},
"gpu_metrics": {
"gpu_utilization_mean": 85.2,
"gpu_memory_used_mean": 12450
}
}
]
}
```
### Debug Mode
```bash
python run_benchmarks.py --log-level DEBUG
```
## Contributing
To add new benchmarks:
1. Create a new file in `benches/`
2. Implement the `ModelBenchmark` interface
3. Add a runner function (`run_<benchmark_name>` or `run_benchmark`)
4. run_benchmarks.py

View File

@ -0,0 +1 @@
# Benchmark implementations directory

View File

@ -0,0 +1,156 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
from typing import Dict, Any, List
from benchmark_framework import ModelBenchmark
import torch
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "1"
torch.set_float32_matmul_precision("high")
class LLaMABenchmark(ModelBenchmark):
"""Simplified LLaMA model benchmark implementation using the ModelBenchmark base class."""
def __init__(self, logger: logging.Logger):
super().__init__(logger)
self._default_prompt = "Why dogs are so cute?" # Custom prompt for LLaMA
def get_scenario_configs(self) -> List[Dict[str, Any]]:
"""
Get LLaMA-specific scenario configurations.
Returns:
List of scenario configuration dictionaries
"""
return [
# Eager variants
{"variant": "eager", "compile_mode": None, "use_cache": True, "description": "Eager execution with cache"},
# Compiled variants
{"variant": "compiled", "compile_mode": "max-autotune", "use_cache": True, "description": "Compiled with max autotune"},
# Kernelized variant (if available)
{"variant": "kernelized", "compile_mode": "max-autotune", "use_cache": True, "description": "Kernelized execution"},
]
def _is_kernelization_available(self) -> bool:
"""Check if kernelization is available for LLaMA."""
try:
from kernels import Mode, kernelize
return True
except ImportError:
self.logger.debug("Kernelization not available: kernels module not found")
return False
def get_default_generation_config(self) -> Dict[str, Any]:
"""Get LLaMA-specific generation configuration."""
return {
"do_sample": False,
"top_p": 1.0,
"temperature": 1.0,
"repetition_penalty": 1.0,
"max_new_tokens": None, # Will be set per scenario
}
def get_model_init_kwargs(self, config) -> Dict[str, Any]:
"""Get LLaMA-specific model initialization kwargs."""
from benchmark_framework import BenchmarkConfig
return {
"torch_dtype": getattr(torch, config.torch_dtype),
"attn_implementation": config.attn_implementation,
"use_cache": True,
}
def get_default_torch_dtype(self) -> str:
"""Get default torch dtype for LLaMA."""
return "float16" # LLaMA works well with float16
def get_default_device(self) -> str:
"""Get default device for LLaMA."""
return "cuda" # LLaMA prefers CUDA
def run_llama(logger, output_dir, **kwargs):
"""
Run LLaMA benchmark with the given configuration.
Args:
logger: Logger instance
output_dir: Output directory for results
**kwargs: Additional configuration options
Returns:
Path to output file if successful
"""
from benchmark_framework import BenchmarkRunner
# Extract parameters with defaults
model_id = kwargs.get('model_id', 'meta-llama/Llama-2-7b-hf')
warmup_iterations = kwargs.get('warmup_iterations', 3)
measurement_iterations = kwargs.get('measurement_iterations', 5)
num_tokens_to_generate = kwargs.get('num_tokens_to_generate', 100)
include_sdpa_variants = kwargs.get('include_sdpa_variants', True)
device = kwargs.get('device', 'cuda')
torch_dtype = kwargs.get('torch_dtype', 'float16')
batch_size = kwargs.get('batch_size', 1)
commit_id = kwargs.get('commit_id', None)
logger.info(f"Starting LLaMA benchmark for model: {model_id}")
logger.info(f"Configuration: warmup={warmup_iterations}, measurement={measurement_iterations}, tokens={num_tokens_to_generate}")
try:
# Create benchmark instance
benchmark = LLaMABenchmark(logger)
# Create scenarios
scenarios = benchmark.create_scenarios(
model_id=model_id,
warmup_iterations=warmup_iterations,
measurement_iterations=measurement_iterations,
num_tokens_to_generate=num_tokens_to_generate,
include_sdpa_variants=include_sdpa_variants,
device=device,
torch_dtype=torch_dtype,
batch_size=batch_size
)
logger.info(f"Created {len(scenarios)} benchmark scenarios")
# Create runner and execute benchmarks
runner = BenchmarkRunner(logger, output_dir)
results = runner.run_benchmark(benchmark, scenarios, commit_id=commit_id)
if not results:
logger.warning("No successful benchmark results")
return None
# Save results
model_name = model_id.split('/')[-1] # Extract model name from ID
output_file = runner.save_results(model_name, results)
logger.info(f"LLaMA benchmark completed successfully. Results saved to: {output_file}")
return output_file
except Exception as e:
logger.error(f"LLaMA benchmark failed: {e}")
import traceback
logger.debug(traceback.format_exc())
raise

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,6 @@
numpy>=1.21.0
psutil>=5.8.0
gpustat>=1.0.0
torch>=2.0.0
transformers>=4.30.0
datasets>=2.10.0

385
benchmark_v2/run_benchmarks.py Executable file
View File

@ -0,0 +1,385 @@
#!/usr/bin/env python3
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Top-level benchmarking script that automatically discovers and runs all benchmarks
in the ./benches directory, organizing outputs into model-specific subfolders.
"""
import argparse
import importlib.util
import logging
import os
import sys
import json
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Any, Optional
def setup_logging(log_level: str = "INFO", enable_file_logging: bool = False) -> logging.Logger:
"""Setup logging configuration."""
numeric_level = getattr(logging, log_level.upper(), None)
if not isinstance(numeric_level, int):
raise ValueError(f'Invalid log level: {log_level}')
handlers = [logging.StreamHandler(sys.stdout)]
if enable_file_logging:
handlers.append(
logging.FileHandler(f'benchmark_run_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log')
)
logging.basicConfig(
level=numeric_level,
format='[%(levelname)s - %(asctime)s] %(name)s: %(message)s',
handlers=handlers
)
return logging.getLogger(__name__)
def discover_benchmarks(benches_dir: str) -> List[Dict[str, Any]]:
"""
Discover all benchmark modules in the benches directory.
Returns:
List of dictionaries containing benchmark module info
"""
benchmarks = []
benches_path = Path(benches_dir)
if not benches_path.exists():
raise FileNotFoundError(f"Benches directory not found: {benches_dir}")
for py_file in benches_path.glob("*.py"):
if py_file.name.startswith("__"):
continue
module_name = py_file.stem
try:
# Import the module
spec = importlib.util.spec_from_file_location(module_name, py_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
# Check if it has a benchmark runner function
if hasattr(module, f'run_{module_name}'):
benchmarks.append({
'name': module_name,
'path': str(py_file),
'module': module,
'runner_function': getattr(module, f'run_{module_name}')
})
elif hasattr(module, 'run_benchmark'):
benchmarks.append({
'name': module_name,
'path': str(py_file),
'module': module,
'runner_function': getattr(module, 'run_benchmark')
})
else:
logging.warning(f"No runner function found in {py_file}")
except Exception as e:
logging.error(f"Failed to import {py_file}: {e}")
return benchmarks
def run_single_benchmark(
benchmark_info: Dict[str, Any],
output_dir: str,
logger: logging.Logger,
**kwargs
) -> Optional[str]:
"""
Run a single benchmark and return the output file path.
Args:
benchmark_info: Dictionary containing benchmark module info
output_dir: Base output directory
logger: Logger instance
**kwargs: Additional arguments to pass to the benchmark
Returns:
Path to the output file if successful, None otherwise
"""
benchmark_name = benchmark_info['name']
runner_func = benchmark_info['runner_function']
logger.info(f"Running benchmark: {benchmark_name}")
try:
# Check function signature to determine what arguments to pass
import inspect
sig = inspect.signature(runner_func)
# Prepare arguments based on function signature
func_kwargs = {
'logger': logger,
'output_dir': output_dir
}
# Add other kwargs if the function accepts them
for param_name in sig.parameters:
if param_name in kwargs:
func_kwargs[param_name] = kwargs[param_name]
# Filter kwargs to only include parameters the function accepts
# If function has **kwargs, include all provided kwargs
has_var_kwargs = any(param.kind == param.VAR_KEYWORD for param in sig.parameters.values())
if has_var_kwargs:
valid_kwargs = {**func_kwargs, **kwargs}
else:
valid_kwargs = {k: v for k, v in func_kwargs.items()
if k in sig.parameters}
# Run the benchmark
result = runner_func(**valid_kwargs)
if isinstance(result, str):
# Function returned a file path
return result
else:
logger.info(f"Benchmark {benchmark_name} completed successfully")
return "completed"
except Exception as e:
logger.error(f"Benchmark {benchmark_name} failed: {e}")
import traceback
logger.debug(traceback.format_exc())
return None
def generate_summary_report(
output_dir: str,
benchmark_results: Dict[str, Any],
logger: logging.Logger
) -> str:
"""Generate a summary report of all benchmark runs."""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
summary_file = os.path.join(output_dir, f"benchmark_summary_{timestamp}.json")
summary_data = {
"run_metadata": {
"timestamp": datetime.utcnow().isoformat(),
"total_benchmarks": len(benchmark_results),
"successful_benchmarks": len([r for r in benchmark_results.values() if r is not None]),
"failed_benchmarks": len([r for r in benchmark_results.values() if r is None])
},
"benchmark_results": benchmark_results,
"output_directory": output_dir
}
with open(summary_file, 'w') as f:
json.dump(summary_data, f, indent=2, default=str)
logger.info(f"Summary report saved to: {summary_file}")
return summary_file
def main():
"""Main entry point for the benchmarking script."""
parser = argparse.ArgumentParser(
description="Run all benchmarks in the ./benches directory"
)
parser.add_argument(
"--output-dir",
type=str,
default="benchmark_results",
help="Base output directory for benchmark results (default: benchmark_results)"
)
parser.add_argument(
"--benches-dir",
type=str,
default="./benches",
help="Directory containing benchmark implementations (default: ./benches)"
)
parser.add_argument(
"--log-level",
type=str,
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
default="INFO",
help="Logging level (default: INFO)"
)
parser.add_argument(
"--model-id",
type=str,
help="Specific model ID to benchmark (if supported by benchmarks)"
)
parser.add_argument(
"--warmup-iterations",
type=int,
default=3,
help="Number of warmup iterations (default: 3)"
)
parser.add_argument(
"--measurement-iterations",
type=int,
default=5,
help="Number of measurement iterations (default: 5)"
)
parser.add_argument(
"--num-tokens-to-generate",
type=int,
default=100,
help="Number of tokens to generate in benchmarks (default: 100)"
)
parser.add_argument(
"--include",
type=str,
nargs="*",
help="Only run benchmarks matching these names"
)
parser.add_argument(
"--exclude",
type=str,
nargs="*",
help="Exclude benchmarks matching these names"
)
parser.add_argument(
"--enable-mock",
action="store_true",
help="Enable mock benchmark (skipped by default)"
)
parser.add_argument(
"--enable-file-logging",
action="store_true",
help="Enable file logging (disabled by default)"
)
parser.add_argument(
"--commit-id",
type=str,
help="Git commit ID for metadata (if not provided, will auto-detect from git)"
)
args = parser.parse_args()
# Setup logging
logger = setup_logging(args.log_level, args.enable_file_logging)
logger.info("Starting benchmark discovery and execution")
logger.info(f"Output directory: {args.output_dir}")
logger.info(f"Benches directory: {args.benches_dir}")
# Create output directory
os.makedirs(args.output_dir, exist_ok=True)
try:
# Discover benchmarks
benchmarks = discover_benchmarks(args.benches_dir)
logger.info(f"Discovered {len(benchmarks)} benchmark(s): {[b['name'] for b in benchmarks]}")
if not benchmarks:
logger.warning("No benchmarks found!")
return 1
# Filter benchmarks based on include/exclude
filtered_benchmarks = benchmarks
if args.include:
filtered_benchmarks = [b for b in filtered_benchmarks
if any(pattern in b['name'] for pattern in args.include)]
logger.info(f"Filtered to include: {[b['name'] for b in filtered_benchmarks]}")
if args.exclude:
filtered_benchmarks = [b for b in filtered_benchmarks
if not any(pattern in b['name'] for pattern in args.exclude)]
logger.info(f"After exclusion: {[b['name'] for b in filtered_benchmarks]}")
if not filtered_benchmarks:
logger.warning("No benchmarks remaining after filtering!")
return 1
# Prepare common kwargs for benchmarks
benchmark_kwargs = {
'warmup_iterations': args.warmup_iterations,
'measurement_iterations': args.measurement_iterations,
'num_tokens_to_generate': args.num_tokens_to_generate
}
if args.model_id:
benchmark_kwargs['model_id'] = args.model_id
# Add enable_mock flag for mock benchmark
benchmark_kwargs['enable_mock'] = args.enable_mock
# Add commit_id if provided
if args.commit_id:
benchmark_kwargs['commit_id'] = args.commit_id
# Run benchmarks
benchmark_results = {}
successful_count = 0
for benchmark_info in filtered_benchmarks:
result = run_single_benchmark(
benchmark_info,
args.output_dir,
logger,
**benchmark_kwargs
)
benchmark_results[benchmark_info['name']] = result
if result is not None:
successful_count += 1
# Generate summary report
summary_file = generate_summary_report(args.output_dir, benchmark_results, logger)
# Final summary
total_benchmarks = len(filtered_benchmarks)
failed_count = total_benchmarks - successful_count
logger.info("=" * 60)
logger.info("BENCHMARK RUN SUMMARY")
logger.info("=" * 60)
logger.info(f"Total benchmarks: {total_benchmarks}")
logger.info(f"Successful: {successful_count}")
logger.info(f"Failed: {failed_count}")
logger.info(f"Output directory: {args.output_dir}")
logger.info(f"Summary report: {summary_file}")
if failed_count > 0:
logger.warning(f"{failed_count} benchmark(s) failed. Check logs for details.")
return 1
else:
logger.info("All benchmarks completed successfully!")
return 0
except Exception as e:
logger.error(f"Benchmark run failed: {e}")
import traceback
logger.debug(traceback.format_exc())
return 1
if __name__ == "__main__":
sys.exit(main())

View File

@ -2,7 +2,7 @@ FROM python:3.9-slim
ENV PYTHONDONTWRITEBYTECODE=1
ARG REF=main
USER root
RUN apt-get update && apt-get install -y libsndfile1-dev espeak-ng time git cmake wget xz-utils build-essential g++5 libprotobuf-dev protobuf-compiler
RUN apt-get update && apt-get install -y libsndfile1-dev espeak-ng time git cmake wget xz-utils build-essential g++5 libprotobuf-dev protobuf-compiler git-lfs curl
ENV UV_PYTHON=/usr/local/bin/python
RUN pip --no-cache-dir install uv && uv pip install --no-cache-dir -U pip setuptools
@ -15,12 +15,20 @@ RUN mv catch.hpp ../libs/
RUN cmake .. -DCMAKE_INSTALL_PREFIX=/usr/local
RUN make install -j 10
WORKDIR /
RUN uv pip install --no-cache --upgrade 'torch' --index-url https://download.pytorch.org/whl/cpu
RUN uv pip install --no-cache-dir --no-deps accelerate --extra-index-url https://download.pytorch.org/whl/cpu
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[ja,testing,sentencepiece,jieba,spacy,ftfy,rjieba]" unidic unidic-lite
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[ja,testing,sentencepiece,spacy,ftfy,rjieba]" unidic unidic-lite
# spacy is not used so not tested. Causes to failures. TODO fix later
RUN uv run python -m unidic download
# fetch test data and hub objects within CircleCI docker images to reduce even more connections
# we don't need a full clone of `transformers` to run `fetch_hub_objects_for_ci.py`
# the data are downloaded to the directory `/test_data` and during CircleCI's CI runtime, we need to move them to the root of `transformers`
RUN mkdir test_data && cd test_data && curl -O https://raw.githubusercontent.com/huggingface/transformers/${REF}/utils/fetch_hub_objects_for_ci.py && python3 fetch_hub_objects_for_ci.py
RUN uv pip uninstall transformers
RUN apt-get clean && rm -rf /var/lib/apt/lists/*

View File

@ -1,13 +0,0 @@
FROM python:3.9-slim
ENV PYTHONDONTWRITEBYTECODE=1
ARG REF=main
USER root
RUN apt-get update && apt-get install -y libsndfile1-dev espeak-ng time git
RUN apt-get install -y g++ cmake
ENV UV_PYTHON=/usr/local/bin/python
RUN pip --no-cache-dir install uv
RUN uv pip install --no-cache-dir -U pip setuptools albumentations seqeval
RUN uv pip install --upgrade --no-cache-dir "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[tf-cpu,sklearn,testing,sentencepiece,tf-speech,vision]"
RUN uv pip install --no-cache-dir "protobuf==3.20.3"
RUN uv pip uninstall transformers
RUN apt-get clean && rm -rf /var/lib/apt/lists/*

View File

@ -2,11 +2,18 @@ FROM python:3.9-slim
ENV PYTHONDONTWRITEBYTECODE=1
ARG REF=main
USER root
RUN apt-get update && apt-get install -y --no-install-recommends libsndfile1-dev espeak-ng time git g++ cmake pkg-config openssh-client git ffmpeg
RUN apt-get update && apt-get install -y --no-install-recommends libsndfile1-dev espeak-ng time git g++ cmake pkg-config openssh-client git-lfs ffmpeg curl
ENV UV_PYTHON=/usr/local/bin/python
RUN pip --no-cache-dir install uv && uv pip install --no-cache-dir -U pip setuptools
RUN uv pip install --no-cache-dir 'torch' 'torchaudio' 'torchvision' 'torchcodec' --index-url https://download.pytorch.org/whl/cpu
RUN uv pip install --no-deps timm accelerate --extra-index-url https://download.pytorch.org/whl/cpu
RUN uv pip install --no-cache-dir librosa "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[sklearn,sentencepiece,vision,testing]" seqeval albumentations jiwer
# fetch test data and hub objects within CircleCI docker images to reduce even more connections
# we don't need a full clone of `transformers` to run `fetch_hub_objects_for_ci.py`
# the data are downloaded to the directory `/test_data` and during CircleCI's CI runtime, we need to move them to the root of `transformers`
RUN mkdir test_data && cd test_data && curl -O https://raw.githubusercontent.com/huggingface/transformers/${REF}/utils/fetch_hub_objects_for_ci.py && python3 fetch_hub_objects_for_ci.py
RUN uv pip uninstall transformers
RUN apt-get clean && rm -rf /var/lib/apt/lists/*

View File

@ -2,16 +2,23 @@ FROM python:3.9-slim
ENV PYTHONDONTWRITEBYTECODE=1
ARG REF=main
USER root
RUN apt-get update && apt-get install -y libsndfile1-dev espeak-ng time git libgl1-mesa-glx libgl1 g++ tesseract-ocr
RUN apt-get update && apt-get install -y libsndfile1-dev espeak-ng time git libgl1 g++ tesseract-ocr git-lfs curl
ENV UV_PYTHON=/usr/local/bin/python
RUN pip --no-cache-dir install uv && uv pip install --no-cache-dir -U pip setuptools
RUN uv pip install --no-cache-dir 'torch' 'torchaudio' 'torchvision' --index-url https://download.pytorch.org/whl/cpu
RUN uv pip install --no-cache-dir --no-deps timm accelerate
RUN uv pip install -U --upgrade-strategy eager --no-cache-dir pytesseract python-Levenshtein opencv-python nltk
RUN uv pip install -U --no-cache-dir pytesseract python-Levenshtein opencv-python nltk
# RUN uv pip install --no-cache-dir natten==0.15.1+torch210cpu -f https://shi-labs.com/natten/wheels
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[testing, vision]" 'scikit-learn' 'torch-stft' 'nose' 'dataset'
# RUN git clone https://github.com/facebookresearch/detectron2.git
# RUN python3 -m pip install --no-cache-dir -e detectron2
RUN uv pip install 'git+https://github.com/facebookresearch/detectron2.git@92ae9f0b92aba5867824b4f12aa06a22a60a45d3' --no-build-isolation
# fetch test data and hub objects within CircleCI docker images to reduce even more connections
# we don't need a full clone of `transformers` to run `fetch_hub_objects_for_ci.py`
# the data are downloaded to the directory `/test_data` and during CircleCI's CI runtime, we need to move them to the root of `transformers`
RUN mkdir test_data && cd test_data && curl -O https://raw.githubusercontent.com/huggingface/transformers/${REF}/utils/fetch_hub_objects_for_ci.py && python3 fetch_hub_objects_for_ci.py
RUN uv pip uninstall transformers
RUN apt-get clean && rm -rf /var/lib/apt/lists/*

View File

@ -1,10 +0,0 @@
FROM python:3.9-slim
ENV PYTHONDONTWRITEBYTECODE=1
ARG REF=main
USER root
RUN apt-get update && apt-get install -y libsndfile1-dev espeak-ng time git g++ cmake
ENV UV_PYTHON=/usr/local/bin/python
RUN pip --no-cache-dir install uv && uv pip install --no-cache-dir -U pip setuptools
RUN uv pip install --no-cache-dir "scipy<1.13" "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[flax,testing,sentencepiece,flax-speech,vision]"
RUN uv pip uninstall transformers
RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean

View File

@ -1,10 +0,0 @@
FROM python:3.9-slim
ENV PYTHONDONTWRITEBYTECODE=1
ARG REF=main
USER root
RUN apt-get update && apt-get install -y libsndfile1-dev espeak-ng time git cmake g++
ENV UV_PYTHON=/usr/local/bin/python
RUN pip --no-cache-dir install uv && uv pip install --no-cache-dir -U pip setuptools
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[sklearn,tf-cpu,testing,sentencepiece,tf-speech,vision]"
RUN uv pip install --no-cache-dir "protobuf==3.20.3" tensorflow_probability
RUN apt-get clean && rm -rf /var/lib/apt/lists/*

View File

@ -2,10 +2,17 @@ FROM python:3.9-slim
ENV PYTHONDONTWRITEBYTECODE=1
ARG REF=main
USER root
RUN apt-get update && apt-get install -y --no-install-recommends libsndfile1-dev espeak-ng time git pkg-config openssh-client git ffmpeg
RUN apt-get update && apt-get install -y --no-install-recommends libsndfile1-dev espeak-ng time git pkg-config openssh-client git ffmpeg curl
ENV UV_PYTHON=/usr/local/bin/python
RUN pip --no-cache-dir install uv && uv pip install --no-cache-dir -U pip setuptools
RUN uv pip install --no-cache-dir 'torch' 'torchaudio' 'torchvision' 'torchcodec' --index-url https://download.pytorch.org/whl/cpu
RUN uv pip install --no-deps timm accelerate --extra-index-url https://download.pytorch.org/whl/cpu
RUN uv pip install --no-cache-dir librosa "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[sklearn,sentencepiece,vision,testing]"
# fetch test data and hub objects within CircleCI docker images to reduce even more connections
# we don't need a full clone of `transformers` to run `fetch_hub_objects_for_ci.py`
# the data are downloaded to the directory `/test_data` and during CircleCI's CI runtime, we need to move them to the root of `transformers`
RUN mkdir test_data && cd test_data && curl -O https://raw.githubusercontent.com/huggingface/transformers/${REF}/utils/fetch_hub_objects_for_ci.py && python3 fetch_hub_objects_for_ci.py
RUN uv pip uninstall transformers

View File

@ -1,12 +0,0 @@
FROM python:3.9-slim
ENV PYTHONDONTWRITEBYTECODE=1
ARG REF=main
USER root
RUN apt-get update && apt-get install -y --no-install-recommends libsndfile1-dev espeak-ng time git g++ pkg-config openssh-client git
RUN apt-get install -y cmake
ENV UV_PYTHON=/usr/local/bin/python
RUN pip --no-cache-dir install uv && uv pip install --no-cache-dir -U pip setuptools
RUN uv pip install --upgrade --no-cache-dir "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[tf-cpu,sklearn,testing,sentencepiece,tf-speech,vision]"
RUN uv pip install --no-cache-dir "protobuf==3.20.3"
RUN uv pip uninstall transformers
RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean

View File

@ -1,16 +0,0 @@
FROM python:3.9-slim
ENV PYTHONDONTWRITEBYTECODE=1
ARG REF=main
USER root
RUN apt-get update && apt-get install -y libsndfile1-dev espeak-ng time git g++ cmake pkg-config openssh-client git
ENV UV_PYTHON=/usr/local/bin/python
RUN pip --no-cache-dir install uv && uv pip install --no-cache-dir -U pip setuptools
RUN uv pip install --no-deps accelerate
RUN uv pip install --no-cache-dir 'torch' 'torchvision' 'torchaudio' --index-url https://download.pytorch.org/whl/cpu
RUN uv pip install --no-cache-dir "scipy<1.13" "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[flax,audio,sklearn,sentencepiece,vision,testing]"
# RUN pip install --no-cache-dir "scipy<1.13" "transformers[flax,testing,sentencepiece,flax-speech,vision]"
RUN uv pip uninstall transformers
RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean

View File

@ -2,10 +2,16 @@ FROM python:3.9-slim
ENV PYTHONDONTWRITEBYTECODE=1
ARG REF=main
USER root
RUN apt-get update && apt-get install -y --no-install-recommends libsndfile1-dev espeak-ng time git g++ cmake pkg-config openssh-client git git-lfs ffmpeg
RUN apt-get update && apt-get install -y --no-install-recommends libsndfile1-dev espeak-ng time git g++ cmake pkg-config openssh-client git-lfs ffmpeg curl
ENV UV_PYTHON=/usr/local/bin/python
RUN pip --no-cache-dir install uv && uv pip install --no-cache-dir -U pip setuptools
RUN uv pip install --no-cache-dir 'torch' 'torchaudio' 'torchvision' 'torchcodec' --index-url https://download.pytorch.org/whl/cpu
RUN uv pip install --no-deps timm accelerate --extra-index-url https://download.pytorch.org/whl/cpu
RUN uv pip install --no-cache-dir librosa "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[sklearn,sentencepiece,vision,testing,tiktoken,num2words,video]"
# fetch test data and hub objects within CircleCI docker images to reduce even more connections
# we don't need a full clone of `transformers` to run `fetch_hub_objects_for_ci.py`
# the data are downloaded to the directory `/test_data` and during CircleCI's CI runtime, we need to move them to the root of `transformers`
RUN mkdir test_data && cd test_data && curl -O https://raw.githubusercontent.com/huggingface/transformers/${REF}/utils/fetch_hub_objects_for_ci.py && python3 fetch_hub_objects_for_ci.py
RUN uv pip uninstall transformers

View File

@ -1,19 +0,0 @@
FROM python:3.9-slim
ENV PYTHONDONTWRITEBYTECODE=1
ARG REF=main
RUN echo ${REF}
USER root
RUN apt-get update && apt-get install -y --no-install-recommends libsndfile1-dev espeak-ng time git g++ cmake pkg-config openssh-client git git-lfs
ENV UV_PYTHON=/usr/local/bin/python
RUN pip --no-cache-dir install uv && uv pip install --no-cache-dir -U pip setuptools
RUN uv pip install --no-cache-dir --no-deps accelerate --extra-index-url https://download.pytorch.org/whl/cpu
RUN uv pip install --no-cache-dir 'torch' 'torchaudio' 'torchvision' --index-url https://download.pytorch.org/whl/cpu
RUN git lfs install
RUN uv pip install --no-cache-dir pypi-kenlm
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/transformers.git@${REF}#egg=transformers[tf-cpu,sklearn,sentencepiece,vision,testing]"
RUN uv pip install --no-cache-dir "protobuf==3.20.3" librosa
RUN uv pip uninstall transformers
RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean

View File

@ -20,8 +20,6 @@ WORKDIR /
ADD https://api.github.com/repos/huggingface/transformers/git/refs/heads/main version.json
RUN git clone https://github.com/huggingface/transformers && cd transformers && git checkout $REF
# On ROCm, torchcodec is required to decode audio files
# RUN python3 -m pip install --no-cache-dir torchcodec
# Install transformers
RUN python3 -m pip install --no-cache-dir -e ./transformers[dev-torch,testing,video,audio]
@ -37,3 +35,6 @@ RUN python3 -m pip uninstall py3nvml pynvml nvidia-ml-py apex -y
# `kernels` may causes many failing tests
RUN python3 -m pip uninstall -y kernels
# On ROCm, torchcodec is required to decode audio files and 0.4 or 0.6 fails
RUN python3 -m pip install --no-cache-dir "torchcodec==0.5"

View File

@ -39,7 +39,6 @@
| [كيفية ضبط نموذج بدقة على التلخيص](https://github.com/huggingface/notebooks/blob/main/examples/summarization.ipynb)| يوضح كيفية معالجة البيانات مسبقًا وضبط نموذج مُدرَّب مسبقًا بدقة على XSUM. | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/summarization.ipynb)| [![Open in AWS Studio](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/huggingface/notebooks/blob/main/examples/summarization.ipynb)|
| [كيفية تدريب نموذج لغة من البداية](https://github.com/huggingface/blog/blob/main/notebooks/01_how_to_train.ipynb)| تسليط الضوء على جميع الخطوات لتدريب نموذج Transformer بشكل فعال على بيانات مخصصة | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/01_how_to_train.ipynb)| [![Open in AWS Studio](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/huggingface/blog/blob/main/notebooks/01_how_to_train.ipynb)|
| [كيفية إنشاء نص](https://github.com/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb)| كيفية استخدام أساليب فك التشفير المختلفة لإنشاء اللغة باستخدام المحولات | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb)| [![Open in AWS Studio](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/huggingface/blog/blob/main/notebooks/02_how_to_generate.ipynb)|
| [كيفية إنشاء نص (مع قيود)](https://github.com/huggingface/blog/blob/main/notebooks/53_constrained_beam_search.ipynb)| كيفية توجيه إنشاء اللغة باستخدام القيود التي يوفرها المستخدم | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/53_constrained_beam_search.ipynb)| [![Open in AWS Studio](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/huggingface/blog/blob/main/notebooks/53_constrained_beam_search.ipynb)|
| [Reformer](https://github.com/huggingface/blog/blob/main/notebooks/03_reformer.ipynb)| كيف يدفع Reformer حدود النمذجة اللغوية | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patrickvonplaten/blog/blob/main/notebooks/03_reformer.ipynb)| [![Open in AWS Studio](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/patrickvonplaten/blog/blob/main/notebooks/03_reformer.ipynb)|
#### رؤية الكمبيوتر[[pytorch-cv]]

View File

@ -277,6 +277,8 @@
title: Keypoint detection
- local: tasks/knowledge_distillation_for_image_classification
title: Knowledge Distillation for Computer Vision
- local: tasks/keypoint_matching
title: Keypoint matching
title: Computer vision
- sections:
- local: tasks/image_captioning

View File

@ -225,28 +225,6 @@ outputs = model.generate(**inputs, assistant_model=assistant_model, tokenizer=to
tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
```
### Diverse beam search
[Diverse beam search](https://hf.co/papers/1610.02424) is a variant of beam search that produces more diverse output candidates to choose from. This strategy measures the dissimilarity of sequences and a penalty is applied if sequences are too similar. To avoid high computation costs, the number of beams is divided into groups.
Enable diverse beam search with the `num_beams`, `num_beam_groups` and `diversity_penalty` parameters (the `num_beams` parameter should be divisible by `num_beam_groups`).
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, infer_device
device = infer_device()
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to(device)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", dtype=torch.float16).to(device)
# explicitly set to 100 because Llama2 generation length is 4096
outputs = model.generate(**inputs, max_new_tokens=50, num_beams=6, num_beam_groups=3, diversity_penalty=1.0, do_sample=False)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
'Hugging Face is an open-source company 🤗\nWe are an open-source company. Our mission is to democratize AI and make it accessible to everyone. We believe that AI should be used for the benefit of humanity, not for the benefit of a'
```
## Custom generation methods

View File

@ -108,9 +108,6 @@ generation.
[[autodoc]] ForcedEOSTokenLogitsProcessor
- __call__
[[autodoc]] HammingDiversityLogitsProcessor
- __call__
[[autodoc]] InfNanRemoveLogitsProcessor
- __call__
@ -219,10 +216,6 @@ A [`Constraint`] can be used to force the generation to include specific tokens
- process
- finalize
[[autodoc]] BeamSearchScorer
- process
- finalize
[[autodoc]] ConstrainedBeamSearchScorer
- process
- finalize

View File

@ -102,7 +102,7 @@ You may want to consider offloading if you have a small GPU and you're getting o
Offloading is available for both [`DynamicCache`] and [`StaticCache`]. You can enable it by configuring `cache_implementation="offloaded"` for the dynamic version, or `cache_implementation="offloaded_static"` for the static version, in either [`GenerationConfig`] or [`~GenerationMixin.generate`].
Additionally, you can also instantiate your own [`DynamicCache`] or [`StaticCache`] with the `offloading=True` option, and pass this cache in `generate` or your model's `forward` (for example, `past_key_values=DynamicCache(config=model.config, offloading=True)` for a dynamic cache).
Note that the 2 [`Cache`] classes mentionned above have an additional option when instantiating them directly, `offload_only_non_sliding`.
Note that the 2 [`Cache`] classes mentioned above have an additional option when instantiating them directly, `offload_only_non_sliding`.
This additional argument decides if the layers using sliding window/chunk attention (if any), will be offloaded as well. Since
these layers are usually short anyway, it may be better to avoid offloading them, as offloading may incur a speed penalty. By default, this option is `False` for [`DynamicCache`], and `True` for [`StaticCache`].
@ -146,7 +146,7 @@ tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, dtype=torch.float16, device_map="auto")
prompt = ["okay "*1000 + "Fun fact: The most"]
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
beams = { "num_beams": 40, "num_beam_groups": 40, "num_return_sequences": 40, "diversity_penalty": 1.0, "max_new_tokens": 23, "early_stopping": True, }
beams = { "num_beams": 40, "num_return_sequences": 20, "max_new_tokens": 23, "early_stopping": True, }
out = resilient_generate(model, **inputs, **beams)
responses = tokenizer.batch_decode(out[:,-28:], skip_special_tokens=True)
```

View File

@ -188,3 +188,8 @@ error, it means NCCL was probably not loaded.
[[autodoc]] DeepseekV3ForSequenceClassification
- forward
## DeepseekV3ForTokenClassification
[[autodoc]] DeepseekV3ForTokenClassification
- forward

View File

@ -148,6 +148,14 @@ processed_outputs = processor.post_process_keypoint_matching(outputs, image_size
- post_process_keypoint_matching
- visualize_keypoint_matching
## EfficientLoFTRImageProcessorFast
[[autodoc]] EfficientLoFTRImageProcessorFast
- preprocess
- post_process_keypoint_matching
- visualize_keypoint_matching
<frameworkcontent>
<pt>
## EfficientLoFTRModel

View File

@ -273,3 +273,8 @@ visualizer("<img>What is shown in this image?")
[[autodoc]] Gemma3ForSequenceClassification
- forward
## Gemma3TextForSequenceClassification
[[autodoc]] Gemma3TextForSequenceClassification
- forward

View File

@ -50,7 +50,7 @@ The `generate()` method can be used to generate text using GPTSAN-Japanese model
>>> model = AutoModel.from_pretrained("Tanrei/GPTSAN-japanese").to(device)
>>> x_tok = tokenizer("は、", prefix_text="織田信長", return_tensors="pt")
>>> torch.manual_seed(0)
>>> gen_tok = model.generate(x_tok.input_ids.to(model.device), token_type_ids=x_tok.token_type_ids.to(mdoel.device), max_new_tokens=20)
>>> gen_tok = model.generate(x_tok.input_ids.to(model.device), token_type_ids=x_tok.token_type_ids.to(model.device), max_new_tokens=20)
>>> tokenizer.decode(gen_tok[0])
'織田信長は、2004年に『戦国BASARA』のために、豊臣秀吉'
```

View File

@ -104,6 +104,11 @@ If you're interested in submitting a resource to be included here, please feel f
[[autodoc]] ImageGPTImageProcessor
- preprocess
## ImageGPTImageProcessorFast
[[autodoc]] ImageGPTImageProcessorFast
- preprocess
## ImageGPTModel
[[autodoc]] ImageGPTModel

View File

@ -32,7 +32,7 @@ MetaCLIP 2 is a replication of the original CLIP model trained on 300+ languages
This model was contributed by [nielsr](https://huggingface.co/nielsr).
The original code can be found [here](https://github.com/facebookresearch/MetaCLIP).
You can find all the MetaCLIP 2 checkpoints under the [Meta](https://huggingface.co/facebook?search_models=metaclip-2) organization.
You can find all the MetaCLIP 2 checkpoints under the [Meta](https://huggingface.co/facebook/models?search=metaclip-2) organization.
> [!TIP]
> Click on the MetaCLIP 2 models in the right sidebar for more examples of how to apply MetaCLIP 2 to different image and language tasks.

View File

@ -124,7 +124,7 @@ Create a [`Pipeline`] object and select a task. By default, [`Pipeline`] downloa
<hfoptions id="pipeline-tasks">
<hfoption id="text generation">
Use [`~infer_device`] to automatically detect an available accelerator for inference.
Use [`infer_device`] to automatically detect an available accelerator for inference.
```py
from transformers import pipeline, infer_device
@ -144,7 +144,7 @@ pipeline("The secret to baking a good cake is ", max_length=50)
</hfoption>
<hfoption id="image segmentation">
Use [`~infer_device`] to automatically detect an available accelerator for inference.
Use [`infer_device`] to automatically detect an available accelerator for inference.
```py
from transformers import pipeline, infer_device
@ -171,7 +171,7 @@ segments[1]["label"]
</hfoption>
<hfoption id="automatic speech recognition">
Use [`~infer_device`] to automatically detect an available accelerator for inference.
Use [`infer_device`] to automatically detect an available accelerator for inference.
```py
from transformers import pipeline, infer_device

View File

@ -21,7 +21,7 @@ Transformer models can be efficiently deployed using libraries such as vLLM, Tex
> [!TIP]
> Responses API is now supported as an experimental API! Read more about it [here](#responses-api).
Apart from that you can also serve transformer models easily using the `transformers serve` CLI. This is ideal for experimentation purposes, or to run models locally for personal and private use.
You can also serve transformer models with the `transformers serve` CLI. With Continuous Batching, `serve` now delivers solid throughput and latency well suited for evaluation, experimentation, and moderate-load local or self-hosted deployments. While vLLM, SGLang, or other inference engines remain our recommendations for large-scale production, `serve` avoids the extra runtime and operational overhead, and is on track to gain more production-oriented features.
In this document, we dive into the different supported endpoints and modalities; we also cover the setup of several user interfaces that can be used on top of `transformers serve` in the following guides:
- [Jan (text and MCP user interface)](./jan.md)
@ -58,7 +58,7 @@ or by sending an HTTP request, like we'll see below.
## Chat Completions - text-based
See below for examples for text-based requests. Both LLMs and VLMs should handle
See below for examples for text-based requests. Both LLMs and VLMs should handle
<hfoptions id="chat-completion-http">
<hfoption id="curl">
@ -366,6 +366,40 @@ The `transformers serve` server is also an MCP client, so it can interact with M
<!-- TODO: example with a minimal python example, and explain that it is possible to pass a full generation config in the request -->
## Continuous Batching
Continuous Batching (CB) lets the server dynamically group and interleave requests so they can share forward passes on the GPU. Instead of processing each request sequentially, `serve` adds new requests as others progress (prefill) and drops finished ones during decode. The result is significantly higher GPU utilization and better throughput without sacrificing latency for most workloads.
Thanks to this, evaluation, experimentation, and moderate-load local/self-hosted use can now be handled comfortably by `transformers serve` without introducing an extra runtime to operate.
### Enable CB in serve
CB is opt-in and currently applies to chat completions.
```sh
transformers serve \
--continuous-batching
--attn_implementation sdpa_paged
```
### Performance tips
- Use an efficient attention backend when available:
```sh
transformers serve \
--continuous_batching \
--attn_implementation paged_attention
```
> [!TIP]
> If you choose `paged_attention`, you must install `flash-attn` separately: `pip install flash-attn --no-build-isolation`
- `--dtype {bfloat16|float16}` typically improve throughput and memory use vs. `float32`
- `--load_in_4bit`/`--load_in_8bit` can reduce memory footprint for LoRA setups
- `--force-model <repo_id>` avoids per-request model hints and helps produce stable, repeatable runs

View File

@ -0,0 +1,129 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Keypoint matching
Keypoint matching matches different points of interests that belong to same object appearing in two different images. Most modern keypoint matchers take images as input and output the following:
- **Keypoint coordinates (x,y):** one-to-one mapping of pixel coordinates between the first and the second image using two lists. Each keypoint at a given index in the first list is matched to the keypoint at the same index in the second list.
- **Matching scores:** Scores assigned to the keypoint matches.
In this tutorial, you will extract keypoint matches with the [`EfficientLoFTR`] model trained with the [MatchAnything framework](https://huggingface.co/zju-community/matchanything_eloftr), and refine the matches. This model is only 16M parameters and can be run on a CPU. You will use the [`AutoModelForKeypointMatching`] class.
```python
from transformers import AutoImageProcessor, AutoModelForKeypointMatching
import torch
processor = AutoImageProcessor.from_pretrained("zju-community/matchanything_eloftr")
model = AutoModelForKeypointMatching.from_pretrained("zju-community/matchanything_eloftr"))
```
Load two images that have the same object of interest. The second photo is taken a second apart, it's colors are edited, and it is further cropped and rotated.
<div style="display: flex; align-items: center;">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
alt="Bee"
style="height: 200px; object-fit: contain; margin-right: 10px;">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee_edited.jpg"
alt="Bee edited"
style="height: 200px; object-fit: contain;">
</div>
```python
from transformers.image_utils import load_image
image1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg")
image2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee_edited.jpg")
images = [image1, image2]
```
We can pass the images to the processor and infer.
```python
inputs = processor(images, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
```
We can postprocess the outputs. The threshold parameter is used to refine noise (lower confidence thresholds) in the output matches.
```python
image_sizes = [[(image.height, image.width) for image in images]]
outputs = processor.post_process_keypoint_matching(outputs, image_sizes, threshold=0.2)
print(outputs)
```
Here's the outputs.
```
[{'keypoints0': tensor([[4514, 550],
[4813, 683],
[1972, 1547],
...
[3916, 3408]], dtype=torch.int32),
'keypoints1': tensor([[2280, 463],
[2378, 613],
[2231, 887],
...
[1521, 2560]], dtype=torch.int32),
'matching_scores': tensor([0.2189, 0.2073, 0.2414, ...
])}]
```
We have trimmed the output but there's 401 matches!
```python
len(outputs[0]["keypoints0"])
# 401
```
We can visualize them using the processor's [`~EfficientLoFTRImageProcessor.visualize_keypoint_matching`] method.
```python
plot_images = processor.visualize_keypoint_matching(images, outputs)
plot_images
```
![Matched Image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/matched_bees.png)
Optionally, you can use the [`Pipeline`] API and set the task to `keypoint-matching`.
```python
from transformers import pipeline
image_1 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image_2 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee_edited.jpg"
pipe = pipeline("keypoint-matching", model="zju-community/matchanything_eloftr")
pipe([image_1, image_2])
```
The output looks like following.
```bash
[{'keypoint_image_0': {'x': 2444, 'y': 2869},
'keypoint_image_1': {'x': 837, 'y': 1500},
'score': 0.9756593704223633},
{'keypoint_image_0': {'x': 1248, 'y': 2819},
'keypoint_image_1': {'x': 862, 'y': 866},
'score': 0.9735618829727173},
{'keypoint_image_0': {'x': 1547, 'y': 3317},
'keypoint_image_1': {'x': 1436, 'y': 1500},
...
}
]
```

View File

@ -79,7 +79,7 @@ Index the images offline, and during inference, return the query text embeddings
Store the image and image embeddings by writing them to the dataset with [`~datasets.Dataset.map`] as shown below. Add an `embeddings` column that contains the indexed embeddings. ColPali embeddings take up a lot of storage, so remove them from the accelerator and store them in the CPU as NumPy vectors.
```python
ds_with_embeddings = dataset.map(lambda example: {'embeddings': model(**processor(images=example["image"]).to(devide), return_tensors="pt").embeddings.to(torch.float32).detach().cpu().numpy()})
ds_with_embeddings = dataset.map(lambda example: {'embeddings': model(**processor(images=example["image"]).to(device), return_tensors="pt").embeddings.to(torch.float32).detach().cpu().numpy()})
```
For online inference, create a function to search the image embeddings in batches and retrieve the k-most relevant images. The function below returns the indices in the dataset and their scores for a given indexed dataset, text embeddings, number of top results, and the batch size.

View File

@ -241,43 +241,6 @@ time."\n\nHe added: "I am very proud of the work I have been able to do in the l
'Das Haus ist wunderbar.'
```
### Diverse beam search decoding
多様なビームサーチデコーディング戦略は、ビームサーチ戦略の拡張であり、選択肢からより多様なビームシーケンスを生成できるようにします。この仕組みの詳細については、[Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence Models](https://huggingface.co/papers/1610.02424) をご参照ください。このアプローチには、`num_beams``num_beam_groups`、および `diversity_penalty` という3つの主要なパラメータがあります。多様性ペナルティは、出力がグループごとに異なることを保証し、ビームサーチは各グループ内で使用されます。
```python
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
>>> checkpoint = "google/pegasus-xsum"
>>> prompt = (
... "The Permaculture Design Principles are a set of universal design principles "
... "that can be applied to any location, climate and culture, and they allow us to design "
... "the most efficient and sustainable human habitation and food production systems. "
... "Permaculture is a design system that encompasses a wide variety of disciplines, such "
... "as ecology, landscape design, environmental science and energy conservation, and the "
... "Permaculture design principles are drawn from these various disciplines. Each individual "
... "design principle itself embodies a complete conceptual framework based on sound "
... "scientific principles. When we bring all these separate principles together, we can "
... "create a design system that both looks at whole systems, the parts that these systems "
... "consist of, and how those parts interact with each other to create a complex, dynamic, "
... "living system. Each design principle serves as a tool that allows us to integrate all "
... "the separate parts of a design, referred to as elements, into a functional, synergistic, "
... "whole system, where the elements harmoniously interact and work together in the most "
... "efficient way possible."
... )
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
>>> outputs = model.generate(**inputs, num_beams=5, num_beam_groups=5, max_new_tokens=30, diversity_penalty=1.0)
>>> tokenizer.decode(outputs[0], skip_special_tokens=True)
'The Design Principles are a set of universal design principles that can be applied to any location, climate and
culture, and they allow us to design the'
```
### Assisted Decoding
アシストデコーディングは、上記のデコーディング戦略を変更したもので、同じトークナイザー理想的にははるかに小さなモデルを使用して、いくつかの候補トークンを貪欲に生成するアシスタントモデルを使用します。その後、主要なモデルは候補トークンを1つの前向きパスで検証し、デコーディングプロセスを高速化します。現在、アシストデコーディングでは貪欲検索とサンプリングのみがサポートされており、バッチ入力はサポートされていません。アシストデコーディングの詳細については、[このブログ記事](https://huggingface.co/blog/assisted-generation) をご覧ください。

View File

@ -139,9 +139,6 @@ generation_output[:2]
[[autodoc]] ForcedEOSTokenLogitsProcessor
- __call__
[[autodoc]] HammingDiversityLogitsProcessor
- __call__
[[autodoc]] InfNanRemoveLogitsProcessor
- __call__
@ -303,32 +300,6 @@ generation_output[:2]
[[autodoc]] MaxTimeCriteria
- __call__
## Constraints
[`Constraint`] を使用すると、生成時に出力に特定のトークンまたはシーケンスが含まれるように強制できます。これは PyTorch 実装でのみ利用可能であることに注意してください。
[[autodoc]] Constraint
[[autodoc]] PhrasalConstraint
[[autodoc]] DisjunctiveConstraint
[[autodoc]] ConstraintListState
## BeamSearch
[[autodoc]] BeamScorer
- process
- finalize
[[autodoc]] BeamSearchScorer
- process
- finalize
[[autodoc]] ConstrainedBeamSearchScorer
- process
- finalize
## Streamers
[[autodoc]] TextStreamer

View File

@ -507,7 +507,7 @@
title: DeBERTa
- local: model_doc/deberta-v2
title: DeBERTa-v2
- local: in_translation
- local: model_doc/deepseek_v3
title: DeepSeek-V3
- local: in_translation
title: DialoGPT
@ -997,6 +997,8 @@
title: WavLM
- local: model_doc/whisper
title: Whisper
- local: model_doc/xclip
title: xclip
- local: in_translation
title: XLS-R
- local: in_translation

View File

@ -232,44 +232,6 @@ time."\n\nHe added: "I am very proud of the work I have been able to do in the l
'Das Haus ist wunderbar.'
```
### 다양한 빔 탐색 디코딩(Diverse beam search decoding)[[diverse-beam-search-decoding]]
다양한 빔 탐색(Decoding) 전략은 선택할 수 있는 더 다양한 빔 시퀀스 집합을 생성할 수 있게 해주는 빔 탐색 전략의 확장입니다. 이 방법은 어떻게 작동하는지 알아보려면, [다양한 빔 탐색: 신경 시퀀스 모델에서 다양한 솔루션 디코딩하기](https://huggingface.co/papers/1610.02424)를 참조하세요. 이 접근 방식은 세 가지 주요 매개변수를 가지고 있습니다: `num_beams`, `num_beam_groups`, 그리고 `diversity_penalty`. 다양성 패널티는 그룹 간에 출력이 서로 다르게 하기 위한 것이며, 각 그룹 내에서 빔 탐색이 사용됩니다.
```python
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
>>> checkpoint = "google/pegasus-xsum"
>>> prompt = (
... "The Permaculture Design Principles are a set of universal design principles "
... "that can be applied to any location, climate and culture, and they allow us to design "
... "the most efficient and sustainable human habitation and food production systems. "
... "Permaculture is a design system that encompasses a wide variety of disciplines, such "
... "as ecology, landscape design, environmental science and energy conservation, and the "
... "Permaculture design principles are drawn from these various disciplines. Each individual "
... "design principle itself embodies a complete conceptual framework based on sound "
... "scientific principles. When we bring all these separate principles together, we can "
... "create a design system that both looks at whole systems, the parts that these systems "
... "consist of, and how those parts interact with each other to create a complex, dynamic, "
... "living system. Each design principle serves as a tool that allows us to integrate all "
... "the separate parts of a design, referred to as elements, into a functional, synergistic, "
... "whole system, where the elements harmoniously interact and work together in the most "
... "efficient way possible."
... )
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
>>> outputs = model.generate(**inputs, num_beams=5, num_beam_groups=5, max_new_tokens=30, diversity_penalty=1.0)
>>> tokenizer.decode(outputs[0], skip_special_tokens=True)
'The Design Principles are a set of universal design principles that can be applied to any location, climate and
culture, and they allow us to design the'
```
이 가이드에서는 다양한 디코딩 전략을 가능하게 하는 주요 매개변수를 보여줍니다. [`generate`] 메서드에 대한 고급 매개변수가 존재하므로 [`generate`] 메서드의 동작을 더욱 세부적으로 제어할 수 있습니다. 사용 가능한 매개변수의 전체 목록은 [API 문서](./main_classes/text_generation)를 참조하세요.
### 추론 디코딩(Speculative Decoding)[[speculative-decoding]]
추론 디코딩(보조 디코딩(assisted decoding)으로도 알려짐)은 동일한 토크나이저를 사용하는 훨씬 작은 보조 모델을 활용하여 몇 가지 후보 토큰을 생성하는 상위 모델의 디코딩 전략을 수정한 것입니다. 주 모델은 단일 전방 통과로 후보 토큰을 검증함으로써 디코딩 과정을 가속화합니다. `do_sample=True`일 경우, [추론 디코딩 논문](https://huggingface.co/papers/2211.17192)에 소개된 토큰 검증과 재샘플링 방식이 사용됩니다.

View File

@ -131,9 +131,6 @@ generation_output[:2]
[[autodoc]] ForcedEOSTokenLogitsProcessor
- __call__
[[autodoc]] HammingDiversityLogitsProcessor
- __call__
[[autodoc]] InfNanRemoveLogitsProcessor
- __call__
@ -308,32 +305,6 @@ generation_output[:2]
[[autodoc]] EosTokenCriteria
- __call__
## Constraint [[transformers.Constraint]]
[`Constraint`]는 생성 출력에 특정 토큰이나 시퀀스를 강제로 포함시키는 데 사용됩니다. 이 기능은 PyTorch 구현에만 제공됩니다.
[[autodoc]] Constraint
[[autodoc]] PhrasalConstraint
[[autodoc]] DisjunctiveConstraint
[[autodoc]] ConstraintListState
## 빔 검색 (BeamSearch) [[transformers.BeamScorer]]
[[autodoc]] BeamScorer
- process
- finalize
[[autodoc]] BeamSearchScorer
- process
- finalize
[[autodoc]] ConstrainedBeamSearchScorer
- process
- finalize
## 스트리머 (Streamers) [[transformers.TextStreamer]]
[[autodoc]] TextStreamer

View File

@ -0,0 +1,184 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# DeepSeek-V3[[deepseek-v3]]
## 개요[[overview]]
DeepSeek-V3 모델은 [DeepSeek-V3 기술 보고서](https://huggingface.co/papers/2412.19437)에서 DeepSeek-AI 팀에 의해 제안되었습니다.
논문의 초록은 다음과 같습니다.
총 671B개의 파라미터를 가지며 토큰당 37B개가 활성화되는 강력한 Mixture-of-Experts(MoE) 언어 모델인 DeepSeek-V3를 소개합니다. 효율적인 추론과 비용 효율적인 훈련을 달성하기 위해, DeepSeek-V3는 DeepSeek-V2에서 철저히 검증된 Multi-head Latent Attention(MLA) 및 DeepSeekMoE 아키텍처를 채택했습니다. 나아가 DeepSeek-V3는 로드 밸런싱을 위한 보조 손실 없는 전략을 개척하고, 더 강력한 성능을 위해 다중 토큰 예측 훈련 목표를 설정합니다. 저희는 14.8조 개의 다양하고 고품질의 토큰으로 DeepSeek-V3를 사전 훈련했으며, 그 잠재력을 완전히 활용하기 위해 지도 파인튜닝 및 강화 학습 단계를 거쳤습니다. 종합적인 평가 결과, DeepSeek-V3는 다른 오픈 소스 모델들을 능가하며 선도적인 비공개 소스 모델들과 필적하는 성능을 달성했음을 보여줍니다. 뛰어난 성능에도 불구하고 DeepSeek-V3의 전체 훈련에는 278.8만 H800 GPU 시간만이 소요되었습니다. 또한, 훈련 과정이 매우 안정적입니다. 전체 훈련 과정 동안 복구 불가능한 손실 급증을 경험하거나 롤백을 수행한 적이 없습니다. 모델 체크포인트는 https://github.com/deepseek-ai/DeepSeek-V3 에서 확인할 수 있습니다.
## 한계 및 기여 요청![[limitations-and-call-for-contribution!]]
저희는 이 코드를 커뮤니티 기반으로 만들게 되어 매우 기쁘며, 여러분이 다음 사항들을 어떻게 최적화할 수 있는지 확인하고 싶습니다.
- 현재 구현은 "기본적인" 어텐션 계산을 사용합니다. 따라서 실제 Multi-head Latent Attention (MLA) 가 아닙니다.
- 현재 구현은 전문가들을 순회하는 루프를 사용합니다. 이는 교체되어야 하기에 `integrations/tensor_parallel``get_packed_weights`를 사용하는 것을 제안합니다.
- 현재 구현에서는 ROPE에 EleutherAI의 수식을 사용하지만, 원본 수식을 적용하면 더 효율적일 것입니다! (단, 기존 API는 그대로 준수해야 합니다)
- generation config 또는 config shape의 문제일 것으로 추정되는 문제로 인해 정적 캐시는 지원되지 않습니다.
### 사용 팁[[usage-tips]]
이 모델은 효율적인 추론과 비용 효율적인 훈련을 위해 Multi-head Latent Attention (MLA) 및 DeepSeekMoE 아키텍처를 사용합니다. 로드 밸런싱을 위한 보조 손실이 없는 전략과 다중 토큰 예측 훈련 목표를 채택합니다. 이 모델은 14.8조 개의 토큰으로 사전 훈련되고 지도 파인튜닝 및 강화 학습 단계를 거친 후 다양한 언어 작업에 사용될 수 있습니다.
`FP8`로 모델을 자동으로 실행할 수 있으며, 8개의 H100으로 구성된 2개 노드면 충분할 것입니다!
```python
# `run_deepseek_v1.py`
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
torch.manual_seed(30)
tokenizer = AutoTokenizer.from_pretrained("deepseek-r1")
chat = [
{"role": "user", "content": "안녕하세요, 어떻게 지내세요?"},
{"role": "assistant", "content": "저는 잘 지내요. 오늘 무엇을 도와드릴까요?"},
{"role": "user", "content": "채팅 템플릿이 어떻게 작동하는지 보여주고 싶어요!"},
]
model = AutoModelForCausalLM.from_pretrained("deepseek-r1", device_map="auto", torch_dtype=torch.bfloat16)
inputs = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
import time
start = time.time()
outputs = model.generate(inputs, max_new_tokens=50)
print(tokenizer.batch_decode(outputs))
print(time.time()-start)
```
생성된 결과는 다음과 같습니다.
``````
<Assistant><think>
좋아요, 사용자는 채팅 템플릿이 어떻게 작동하는지 보여주고 싶어 하는군요. 이게 무슨 의미인지 분석해 보겠습니다. 채팅 템플릿은 대화 데이터를 구조화하는 것인데, 특히 특정 입력 형식이 필요한 모델에 중요합니다. 아마도 OpenAI 같은 API에서 메시지가 역할(사용자, 어시스턴트, 시스템)과 함께 형식화되는 방식을 말하는 것일 수 있습니다.
먼저, 채팅 템플릿이 무엇인지 설명해야겠습니다. 이는 모델이 이해할 수 있는 구조화된 형식으로 대화 데이터를 포맷하는 과정입니다. 여기에는 보통 역할과 콘텐츠가 포함됩니다. 예를 들어, 사용자 메시지, 어시스턴트 응답, 시스템 메시지는 각각 고유한 역할 태그를 가집니다.
사용자는 예시를 원할 수 있습니다. 간단한 대화를 생각해 보죠. 사용자가 "안녕하세요, 잘 지내세요?"라고 말하고 어시스턴트가 "네, 잘 지내요. 오늘 무엇을 도와드릴까요?"라고 답합니다. 그런 다음 사용자가 채팅 템플릿을 보여주고 싶다고 이어갑니다. 따라서 예시에는 대화 기록과 새 메시지가 포함되어야 합니다.
Hugging Face의 Transformers와 같은 프레임워크에서는 Jinja2 템플릿을 사용하여 채팅 템플릿이 적용됩니다. 템플릿은 시스템 메시지를 결합한 다음, 적절한 태그와 함께 사용자와 어시스턴트 메시지를 반복하는 형식일 수 있습니다. 예를 들어, {% for message in messages %}를 사용하고 <|user|>, <|assistant|>와 같은 역할을 할당하는 것입니다.
메시지 배열 예시를 구성하고, 각 역할과 내용을 보여주어야겠습니다. 그런 다음 가상의 템플릿을 적용하여 모델이 사용하는 형식화된 문자열로 변환하는 과정을 보여줍니다. 또한, 모델마다 특수 토큰을 사용하거나 역할 레이블이 다른 것처럼 템플릿 요구 사항이 다르다는 점도 언급해야 합니다.
잠깐, 사용자가 "채팅 템플릿"을 보여주는 맥락에서 언급했습니다. 아마도 발표할 수 있는 실용적인 예시를 원하는 것일 수 있습니다. 따라서 코드 스니펫이나 구조화된 데이터 예시를 제공하는 것이 도움이 될 것입니다. 일반적인 메시지 배열과 템플릿이 적용된 결과물을 개략적으로 설명해 보겠습니다.
또한, 적절한 템플릿팅은 모델이 대화 흐름을 이해하도록 보장하며, 이는 일관된 응답을 생성하는 데 중요합니다. 왜 이것이 중요한지에 대한 메모, 예를 들어 컨텍스트 유지 및 역할별 처리의 중요성을 포함할 수도 있겠습니다.
흔한 실수나 피해야 할 점이 있는지 확인해 보겠습니다. 예를 들어, 태그를 제대로 닫지 않거나 역할이 일치하지 않는 경우입니다. 하지만 사용자가 묻지 않는 한 너무 자세할 수 있습니다. 먼저 긍정적인 예시에 집중합시다.
모든 것을 종합하면, 응답에는 예시 메시지 배열, 적용된 템플릿, 그리고 최종적으로 형식화된 문자열이 포함되어야 합니다. 자리 표시자로 꺾쇠괄호나 특수 토큰을 사용할 수 있습니다. 또한, 이것이 구조화된 데이터로 모델을 훈련하거나 파인튜닝하는 데 도움이 된다는 점도 언급해야 합니다.
이것이 확실한 접근 방식인 것 같습니다. 명확하게 만들기 위해 단계별로 구조화해 보겠습니다.
</think>
채팅 템플릿은 대화 데이터(예: 사용자/어시스턴트 상호작용)를 언어 모델이 이해할 수 있는 형식으로 구조화하는 방법입니다. 이는 특히 입력에서 역할(사용자, 어시스턴트, 시스템 등)과 메시지를 명시적으로 구분해야 하는 다중 턴 대화를 처리하도록 훈련된 모델에 중요합니다. 예시와 함께 자세히 살펴보겠습니다!
---
### **1단계: 원본 대화 기록**
다음과 같은 대화가 있다고 가정해 보겠습니다:
- **사용자**: "안녕하세요, 어떻게 지내세요?"
- **어시스턴트**: "저는 잘 지내요. 오늘 무엇을 도와드릴까요?"
- **사용자**: "채팅 템플릿이 어떻게 작동하는지 보여주고 싶어요!"
---
### **2단계: 구조화된 메시지**
Hugging Face Transformers나 OpenAI 같은 프레임워크에서는 대화가 종종 `role`과 `content`를 가진 딕셔너리 리스트로 형식화됩니다.
```python
messages = [
{"role": "user", "content": "안녕하세요, 어떻게 지내세요?"},
{"role": "assistant", "content": "저는 잘 지내요. 오늘 무엇을 도와드릴까요?"},
{"role": "user", "content": "채팅 템플릿이 어떻게 작동하는지 보여주고 싶어요!"},
]
```
---
### **3단계: 채팅 템플릿 적용**
**채팅 템플릿**은 이 구조화된 데이터를 모델에 맞는 단일 문자열로 변환합니다. 예를 들어, (Hugging Face에서 흔히 사용되는) Jinja 스타일 템플릿을 사용하면 다음과 같습니다.
```jinja
{% for message in messages %}
{% if message['role'] == 'user' %}
<|user|>{{ message['content'] }}<|end|>
{% elif message['role'] == 'assistant' %}
<|assistant|>{{ message['content'] }}<|end|>
{% endif %}
{% endfor %}
<|assistant|>
```
---
### **4단계: 최종 템플릿 결과물**
위 템플릿을 저희 `messages` 리스트에 적용하면 다음과 같은 결과가 생성됩니다:
```text
<|user|>안녕하세요, 어떻게 지내세요?<|end|>
<|assistant|>저는 잘 지내요. 오늘 무엇을 도와드릴까요?<|end|>
<|user|>채팅 템플릿이 어떻게 작동하는지 보여주고 싶어요!<|end|>
<|assistant|>
```
이는 모델에게 다음을 알려줍니다:
1. 대화 기록 (사용자/어시스턴트 턴).
2. 모델이 응답을 생성할 차례 (`<|assistant|>`가 끝에 있음).
---
### **주요 참고사항**:
- **역할 분리**: `<|user|>`와 `<|assistant|>` 같은 태그는 모델이 화자를 구별하는 데 도움이 됩니다.
- **특수 토큰**: 모델은 종종 메시지 경계를 표시하기 위해 `<|end|>`와 같은 고유한 토큰을 사용합니다.
- **유연성**: 템플릿은 모델마다 다릅니다 (예: OpenAI는 태그 대신 `{"role": "user", "content": "..."}` 형식을 사용합니다).
---
### **이것이 왜 중요한가**:
- **일관성**: 모델이 대화 구조를 이해하도록 보장합니다.
- **컨텍스트 보존**: 다중 턴 대화의 흐름을 유지합니다.
- **정렬**: 더 나은 성능을 위해 모델이 훈련된 형식과 일치시킵니다.
더 자세히 알아보거나 특정 프레임워크(예: OpenAI, Llama, Mistral)의 구현을 보고 싶으신가요? 알려주세요! 😊<end of sentence>
``````
다음을 사용하여 실행하세요
```bash
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0|1 --rdzv-id an_id --rdzv-backend c10d --rdzv-endpoint master_addr:master_port run_deepseek_r1.py
```
만약 다음과 같은
```bash
[rank0]: ncclInternalError: Internal check failed.
[rank0]: Last error:
[rank0]: Bootstrap : no socket interface found
```
오류가 발생한다면, NCCL이 로드되지 않았을 가능성이 높다는 의미입니다.
## DeepseekV3Config[[deepseekv3config]]
[[autodoc]] DeepseekV3Config
## DeepseekV3Model[[deepseekv3model]]
[[autodoc]] DeepseekV3Model
- forward
## DeepseekV3ForCausalLM[[deepseekv3forcausallm]]
[[autodoc]] DeepseekV3ForCausalLM
- forward

View File

@ -0,0 +1,84 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# X-CLIP[[x-clip]]
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
## 개요[[overview]]
X-CLIP 모델은 Bolin Ni, Houwen Peng, Minghao Chen, Songyang Zhang, Gaofeng Meng, Jianlong Fu, Shiming Xiang, Haibin Ling이 [Expanding Language-Image Pretrained Models for General Video Recognition](https://huggingface.co/papers/2208.02816)에서 제안했습니다.
X-CLIP은 비디오를 위해 [CLIP](clip)을 최소한으로 확장한 모델입니다. 이 모델은 텍스트 인코더, 교차 프레임 비전 인코더, 다중 프레임 통합 Transformer, 그리고 비디오별 프롬프트 생성기로 구성됩니다.
논문의 초록은 아래와 같습니다.
*대조적 언어-이미지 사전 학습은 웹 스케일 데이터로부터 시각-텍스트 공동 표현을 학습하는 데 큰 성공을 거두었으며, 다양한 이미지 작업에 대해 뛰어난 "제로샷(zero-shot)" 일반화 능력을 보여주었습니다. 그러나 이러한 새로운 언어-이미지 사전 학습 방법을 비디오 도메인으로 효과적으로 확장하는 방법은 아직 해결되지 않은 문제입니다. 본 연구에서는 새로운 모델을 처음부터 사전 학습하는 대신, 사전 학습된 언어-이미지 모델을 비디오 인식에 직접 적용하는 간단하면서도 효과적인 접근 방식을 제시합니다. 더 구체적으로, 시간 차원에서 프레임 간의 장기적인 의존성을 포착하기 위해 프레임 간 정보를 명시적으로 교환하는 교차 프레임 어텐션 메커니즘을 제안합니다. 이러한 모듈은 가벼울 뿐만 아니라, 사전 학습된 언어-이미지 모델에 쉽게 통합될 수 있습니다. 또한, 비디오 콘텐츠 정보를 활용하여 식별력 있는 텍스트 프롬프트를 생성하는 비디오별 프롬프팅 기법을 제안합니다. 광범위한 실험을 통해 우리의 접근 방식이 효과적이며 다양한 비디오 인식 시나리오에 일반화될 수 있음을 입증합니다. 특히, 완전 지도 학습 환경에서 우리 접근 방식은 Kinectics-400에서 87.1%의 top-1 정확도를 달성하면서도 Swin-L 및 ViViT-H에 비해 FLOPs를 12배 적게 사용합니다. 제로샷 실험에서는 두 가지 인기 있는 프로토콜 하에서 top-1 정확도 측면에서 현재 최첨단 방법들을 +7.6% 및 +14.9% 능가합니다. 퓨샷(few-shot) 시나리오에서는 레이블이 지정된 데이터가 극히 제한적일 때 이전 최고 방법들을 +32.1% 및 +23.1% 능가합니다.*
팁:
- X-CLIP의 사용법은 [CLIP](clip)과 동일합니다.
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/xclip_architecture.png"
alt="drawing" width="600"/>
<small> X-CLIP 아키텍처. <a href="https://huggingface.co/papers/2208.02816">원본 논문</a>에서 가져왔습니다. </small>
이 모델은 [nielsr](https://huggingface.co/nielsr)님이 기여했습니다.
원본 코드는 [여기](https://github.com/microsoft/VideoX/tree/master/X-CLIP)에서 찾을 수 있습니다.
## 리소스[[resources]]
X-CLIP을 시작하는 데 도움이 되는 공식 Hugging Face 및 커뮤니티(🌎로 표시) 리소스 목록입니다.
- X-CLIP 데모 노트북은 [여기](https://github.com/NielsRogge/Transformers-Tutorials/tree/master/X-CLIP)에서 찾을 수 있습니다.
여기에 포함할 리소스를 제출하는 데 관심이 있다면, 언제든지 Pull Request를 열어주세요. 검토 후 반영하겠습니다! 리소스는 기존 리소스를 복제하는 대신 새로운 것을 보여주는 것이 이상적입니다.
## XCLIPProcessor[[xclipprocessor]]
[[autodoc]] XCLIPProcessor
## XCLIPConfig[[xclipconfig]]
[[autodoc]] XCLIPConfig
- from_text_vision_configs
## XCLIPTextConfig[[xcliptextconfig]]
[[autodoc]] XCLIPTextConfig
## XCLIPVisionConfig[[xclipvisionconfig]]
[[autodoc]] XCLIPVisionConfig
## XCLIPModel[[xclipmodel]]
[[autodoc]] XCLIPModel
- forward
- get_text_features
- get_video_features
## XCLIPTextModel[[xcliptextmodel]]
[[autodoc]] XCLIPTextModel
- forward
## XCLIPVisionModel[[xclipvisionmodel]]
[[autodoc]] XCLIPVisionModel
- forward

View File

@ -133,9 +133,6 @@ generation_output[:2]
[[autodoc]] ForcedEOSTokenLogitsProcessor
- __call__
[[autodoc]] HammingDiversityLogitsProcessor
- __call__
[[autodoc]] InfNanRemoveLogitsProcessor
- __call__
@ -298,32 +295,6 @@ generation_output[:2]
[[autodoc]] MaxTimeCriteria
- __call__
## Constraints
可以使用[`Constraint`]来强制生成结果包含输出中的特定tokens或序列。请注意这仅适用于我们的PyTorch实现。
[[autodoc]] Constraint
[[autodoc]] PhrasalConstraint
[[autodoc]] DisjunctiveConstraint
[[autodoc]] ConstraintListState
## BeamSearch
[[autodoc]] BeamScorer
- process
- finalize
[[autodoc]] BeamSearchScorer
- process
- finalize
[[autodoc]] ConstrainedBeamSearchScorer
- process
- finalize
## Streamers
[[autodoc]] TextStreamer

View File

@ -15,9 +15,7 @@ limitations under the License.
# Examples
We host a wide range of example scripts for multiple learning frameworks. Simply choose your favorite: [TensorFlow](https://github.com/huggingface/transformers/tree/main/examples/tensorflow), [PyTorch](https://github.com/huggingface/transformers/tree/main/examples/pytorch) or [JAX/Flax](https://github.com/huggingface/transformers/tree/main/examples/flax).
We also have some [research projects](https://github.com/huggingface/transformers-research-projects/), as well as some [legacy examples](https://github.com/huggingface/transformers/tree/main/examples/legacy). Note that unlike the main examples these are not actively maintained, and may require specific older versions of dependencies in order to run.
We host a wide range of example scripts, in addition to [research projects](https://github.com/huggingface/transformers-research-projects/), as well as some [legacy examples](https://github.com/huggingface/transformers/tree/main/examples/legacy). Note that unlike the main examples these are not actively maintained, and may require specific older versions of dependencies in order to run.
While we strive to present as many use cases as possible, the example scripts are just that - examples. It is expected that they won't work out-of-the-box on your specific problem and that you will be required to change a few lines of code to adapt them to your needs. To help you with that, most of the examples fully expose the preprocessing of the data, allowing you to tweak and edit them as required.

View File

@ -1,83 +0,0 @@
<!---
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# JAX/Flax Examples
This folder contains actively maintained examples of 🤗 Transformers using the JAX/Flax backend. Porting models and examples to JAX/Flax is an ongoing effort, and more will be added in the coming months. In particular, these examples are all designed to run fast on Cloud TPUs, and we include step-by-step guides to getting started with Cloud TPU.
*NOTE*: Currently, there is no "Trainer" abstraction for JAX/Flax -- all examples contain an explicit training loop.
The following table lists all of our examples on how to use 🤗 Transformers with the JAX/Flax backend:
- with information about the model and dataset used,
- whether or not they leverage the [🤗 Datasets](https://github.com/huggingface/datasets) library,
- links to **Colab notebooks** to walk through the scripts and run them easily.
| Task | Example model | Example dataset | 🤗 Datasets | Colab
|---|---|---|:---:|:---:|
| [**`causal-language-modeling`**](https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling) | GPT2 | OSCAR | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/causal_language_modeling_flax.ipynb)
| [**`masked-language-modeling`**](https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling) | RoBERTa | OSCAR | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/masked_language_modeling_flax.ipynb)
| [**`text-classification`**](https://github.com/huggingface/transformers/tree/main/examples/flax/text-classification) | BERT | GLUE | ✅ | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification_flax.ipynb)
## Intro: JAX and Flax
[JAX](https://github.com/google/jax) is a numerical computation library that exposes a NumPy-like API with tracing capabilities. With JAX's `jit`, you can
trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU. JAX
supports additional transformations such as `grad` (for arbitrary gradients), `pmap` (for parallelizing computation on multiple devices), `remat` (for gradient checkpointing), `vmap` (automatic
efficient vectorization), and `pjit` (for automatically sharded model parallelism). All JAX transformations compose arbitrarily with each other -- e.g., efficiently
computing per-example gradients is simply `vmap(grad(f))`.
[Flax](https://github.com/google/flax) builds on top of JAX with an ergonomic
module abstraction using Python dataclasses that leads to concise and explicit code. Flax's "lifted" JAX transformations (e.g. `vmap`, `remat`) allow you to nest JAX transformation and modules in any way you wish. Flax is the most widely used JAX library, with [129 dependent projects](https://github.com/google/flax/network/dependents?package_id=UGFja2FnZS01MjEyMjA2MA%3D%3D) as of May 2021. It is also the library underlying all of the official Cloud TPU JAX examples.
## Running on Cloud TPU
All of our JAX/Flax models are designed to run efficiently on Google
Cloud TPUs. Here is [a guide for running JAX on Google Cloud TPU](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm).
Consider applying for the [Google TPU Research Cloud project](https://sites.research.google/trc/) for free TPU compute.
Each example README contains more details on the specific model and training
procedure.
## Running on single or multiple GPUs
All of our JAX/Flax examples also run efficiently on single and multiple GPUs. You can use the same instructions in the README to launch training on GPU.
Distributed training is supported out-of-the box and scripts will use all the GPUs that are detected.
You should follow this [guide for installing JAX on GPUs](https://github.com/google/jax/#pip-installation-gpu-cuda) since the installation depends on
your CUDA and CuDNN version.
## Supported models
Porting models from PyTorch to JAX/Flax is an ongoing effort.
Feel free to reach out if you are interested in contributing a model in JAX/Flax -- we'll
be adding a guide for porting models from PyTorch in the upcoming few weeks.
For a complete overview of models that are supported in JAX/Flax, please have a look at [this](https://huggingface.co/transformers/main/index.html#supported-frameworks) table.
Over 3000 pretrained checkpoints are supported in JAX/Flax as of May 2021.
Click [here](https://huggingface.co/models?filter=jax) to see the full list on the 🤗 hub.
## Upload the trained/fine-tuned model to the Hub
All the example scripts support automatic upload of your final model to the [Model Hub](https://huggingface.co/models) by adding a `--push_to_hub` argument. It will then create a repository with your username slash the name of the folder you are using as `output_dir`. For instance, `"sgugger/test-mrpc"` if your username is `sgugger` and you are working in the folder `~/tmp/test-mrpc`.
To specify a given repository name, use the `--hub_model_id` argument. You will need to specify the whole repository name (including your username), for instance `--hub_model_id sgugger/finetuned-bert-mrpc`. To upload to an organization you are a member of, just use the name of that organization instead of your username: `--hub_model_id huggingface/finetuned-bert-mrpc`.
A few notes on this integration:
- you will need to be logged in to the Hugging Face website locally for it to work, the easiest way to achieve this is to run `hf auth login` and then type your username and password when prompted. You can also pass along your authentication token with the `--hub_token` argument.
- the `output_dir` you pick will either need to be a new folder or a local clone of the distant repository you are using.

View File

@ -1,10 +0,0 @@
datasets >= 1.13.3
pytest<8.0.1
conllu
nltk
rouge-score
seqeval
tensorboard
evaluate >= 0.2.0
torch
accelerate

View File

@ -1,45 +0,0 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# tests directory-specific settings - this file is run automatically
# by pytest before any tests are run
import sys
import warnings
from os.path import abspath, dirname, join
# allow having multiple repository checkouts and not needing to remember to rerun
# `pip install -e '.[dev]'` when switching between checkouts and running tests.
git_repo_path = abspath(join(dirname(dirname(dirname(__file__))), "src"))
sys.path.insert(1, git_repo_path)
# silence FutureWarning warnings in tests since often we can't act on them until
# they become normal warnings - i.e. the tests still need to test the current functionality
warnings.simplefilter(action="ignore", category=FutureWarning)
def pytest_addoption(parser):
from transformers.testing_utils import pytest_addoption_shared
pytest_addoption_shared(parser)
def pytest_terminal_summary(terminalreporter):
from transformers.testing_utils import pytest_terminal_summary_main
make_reports = terminalreporter.config.getoption("--make-reports")
if make_reports:
pytest_terminal_summary_main(terminalreporter, id=make_reports)

View File

@ -1,68 +0,0 @@
# Image Captioning (vision-encoder-text-decoder model) training example
The following example showcases how to finetune a vision-encoder-text-decoder model for image captioning
using the JAX/Flax backend, leveraging 🤗 Transformers library's [FlaxVisionEncoderDecoderModel](https://huggingface.co/docs/transformers/model_doc/vision-encoder-decoder#transformers.FlaxVisionEncoderDecoderModel).
JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
Models written in JAX/Flax are **immutable** and updated in a purely functional
way which enables simple and efficient model parallelism.
`run_image_captioning_flax.py` is a lightweight example of how to download and preprocess a dataset from the 🤗 Datasets
library or use your own files (jsonlines or csv), then fine-tune one of the architectures above on it.
For custom datasets in `jsonlines` format please see: https://huggingface.co/docs/datasets/loading_datasets#json-files and you also will find examples of these below.
### Download COCO dataset (2017)
This example uses COCO dataset (2017) through a custom dataset script, which requires users to manually download the
COCO dataset before training.
```bash
mkdir data
cd data
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/zips/test2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
wget http://images.cocodataset.org/annotations/image_info_test2017.zip
cd ..
```
### Create a model from a vision encoder model and a text decoder model
Next, we create a [FlaxVisionEncoderDecoderModel](https://huggingface.co/docs/transformers/model_doc/visionencoderdecoder#transformers.FlaxVisionEncoderDecoderModel) instance from a pre-trained vision encoder ([ViT](https://huggingface.co/docs/transformers/model_doc/vit#transformers.FlaxViTModel)) and a pre-trained text decoder ([GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.FlaxGPT2Model)):
```bash
python3 create_model_from_encoder_decoder_models.py \
--output_dir model \
--encoder_model_name_or_path google/vit-base-patch16-224-in21k \
--decoder_model_name_or_path openai-community/gpt2
```
### Train the model
Finally, we can run the example script to train the model:
```bash
python3 run_image_captioning_flax.py \
--output_dir ./image-captioning-training-results \
--model_name_or_path model \
--dataset_name ydshieh/coco_dataset_script \
--dataset_config_name=2017 \
--data_dir $PWD/data \
--image_column image_path \
--caption_column caption \
--do_train --do_eval --predict_with_generate \
--num_train_epochs 1 \
--eval_steps 500 \
--learning_rate 3e-5 --warmup_steps 0 \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 32 \
--overwrite_output_dir \
--max_target_length 32 \
--num_beams 8 \
--preprocessing_num_workers 16 \
--logging_steps 10 \
--block_size 16384 \
--push_to_hub
```
This should finish in about 1h30 on Cloud TPU, with validation loss and ROUGE2 score of 2.0153 and 14.64 respectively
after 1 epoch. Training statistics can be accessed on [Models](https://huggingface.co/ydshieh/image-captioning-training-results/tensorboard).

View File

@ -1,115 +0,0 @@
#!/usr/bin/env python
# Copyright 2022 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Create a VisionEncoderDecoderModel instance from pretrained encoder/decoder models.
The cross-attention will be randomly initialized.
"""
from dataclasses import dataclass, field
from typing import Optional
from transformers import AutoConfig, AutoImageProcessor, AutoTokenizer, FlaxVisionEncoderDecoderModel, HfArgumentParser
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
"""
output_dir: str = field(
metadata={"help": "The output directory where the model will be written."},
)
encoder_model_name_or_path: str = field(
metadata={
"help": (
"The encoder model checkpoint for weights initialization. "
"Don't set if you want to train an encoder model from scratch."
)
},
)
decoder_model_name_or_path: str = field(
metadata={
"help": (
"The decoder model checkpoint for weights initialization. "
"Don't set if you want to train a decoder model from scratch."
)
},
)
encoder_config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained encoder config name or path if not the same as encoder_model_name"}
)
decoder_config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained decoder config name or path if not the same as decoder_model_name"}
)
def main():
parser = HfArgumentParser((ModelArguments,))
(model_args,) = parser.parse_args_into_dataclasses()
# Load pretrained model and tokenizer
# Use explicit specified encoder config
if model_args.encoder_config_name:
encoder_config = AutoConfig.from_pretrained(model_args.encoder_config_name)
# Use pretrained encoder model's config
else:
encoder_config = AutoConfig.from_pretrained(model_args.encoder_model_name_or_path)
# Use explicit specified decoder config
if model_args.decoder_config_name:
decoder_config = AutoConfig.from_pretrained(model_args.decoder_config_name)
# Use pretrained decoder model's config
else:
decoder_config = AutoConfig.from_pretrained(model_args.decoder_model_name_or_path)
# necessary for `from_encoder_decoder_pretrained` when `decoder_config` is passed
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
encoder_pretrained_model_name_or_path=model_args.encoder_model_name_or_path,
decoder_pretrained_model_name_or_path=model_args.decoder_model_name_or_path,
encoder_config=encoder_config,
decoder_config=decoder_config,
)
# GPT2 only has bos/eos tokens but not decoder_start/pad tokens
decoder_start_token_id = decoder_config.decoder_start_token_id
pad_token_id = decoder_config.pad_token_id
if decoder_start_token_id is None:
decoder_start_token_id = decoder_config.bos_token_id
if pad_token_id is None:
pad_token_id = decoder_config.eos_token_id
# This is necessary to make Flax's generate() work
model.config.eos_token_id = decoder_config.eos_token_id
model.config.decoder_start_token_id = decoder_start_token_id
model.config.pad_token_id = pad_token_id
image_processor = AutoImageProcessor.from_pretrained(model_args.encoder_model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_args.decoder_model_name_or_path)
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(model.config.pad_token_id)
model.save_pretrained(model_args.output_dir)
image_processor.save_pretrained(model_args.output_dir)
tokenizer.save_pretrained(model_args.output_dir)
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

View File

@ -1,568 +0,0 @@
<!---
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Language model training and inference examples
The following example showcases how to train a language model from scratch
using the JAX/Flax backend.
JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
Models written in JAX/Flax are **immutable** and updated in a purely functional
way which enables simple and efficient model parallelism.
## Masked language modeling
In the following, we demonstrate how to train a bi-directional transformer model
using masked language modeling objective as introduced in [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://huggingface.co/papers/1810.04805).
More specifically, we demonstrate how JAX/Flax can be leveraged
to pre-train [**`FacebookAI/roberta-base`**](https://huggingface.co/FacebookAI/roberta-base)
in Norwegian on a single TPUv3-8 pod.
The example script uses the 🤗 Datasets library. You can easily customize them to your needs if you need extra processing on your datasets.
To setup all relevant files for training, let's create a directory.
```bash
mkdir ./norwegian-roberta-base
```
### Train tokenizer
In the first step, we train a tokenizer to efficiently process the text input for the model. Similar to how it is shown in [How to train a new language model from scratch using Transformers and Tokenizers](https://huggingface.co/blog/how-to-train), we use a **`ByteLevelBPETokenizer`**.
The tokenizer is trained on the complete Norwegian dataset of OSCAR
and consequently saved in the cloned model directory.
This can take up to 10 minutes depending on your hardware ☕.
```python
from datasets import load_dataset
from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer
# load dataset
dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train")
# Instantiate tokenizer
tokenizer = ByteLevelBPETokenizer()
def batch_iterator(batch_size=1000):
for i in range(0, len(dataset), batch_size):
yield dataset[i: i + batch_size]["text"]
# Customized training
tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
"<s>",
"<pad>",
"</s>",
"<unk>",
"<mask>",
])
# Save files to disk
tokenizer.save("./norwegian-roberta-base/tokenizer.json")
```
### Create configuration
Next, we create the model's configuration file. This is as simple
as loading and storing [`**FacebookAI/roberta-base**`](https://huggingface.co/FacebookAI/roberta-base)
in the local model folder:
```python
from transformers import RobertaConfig
config = RobertaConfig.from_pretrained("FacebookAI/roberta-base", vocab_size=50265)
config.save_pretrained("./norwegian-roberta-base")
```
Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.
### Train model
Next we can run the example script to pretrain the model:
```bash
python run_mlm_flax.py \
--output_dir="./norwegian-roberta-base" \
--model_type="roberta" \
--config_name="./norwegian-roberta-base" \
--tokenizer_name="./norwegian-roberta-base" \
--dataset_name="oscar" \
--dataset_config_name="unshuffled_deduplicated_no" \
--max_seq_length="128" \
--weight_decay="0.01" \
--per_device_train_batch_size="128" \
--per_device_eval_batch_size="128" \
--learning_rate="3e-4" \
--warmup_steps="1000" \
--overwrite_output_dir \
--num_train_epochs="18" \
--adam_beta1="0.9" \
--adam_beta2="0.98" \
--logging_steps="500" \
--save_steps="2500" \
--eval_steps="2500" \
--push_to_hub
```
Training should converge at a loss and accuracy
of 1.78 and 0.64 respectively after 18 epochs on a single TPUv3-8.
This should take less than 18 hours.
Training statistics can be accessed on [tfhub.dev](https://tensorboard.dev/experiment/GdYmdak2TWeVz0DDRYOrrg).
For a step-by-step walkthrough of how to do masked language modeling in Flax, please have a
look at [this](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/masked_language_modeling_flax.ipynb) google colab.
## Causal language modeling
In the following, we demonstrate how to train an auto-regressive causal transformer model
in JAX/Flax.
More specifically, we pretrain a randomly initialized [**`openai-community/gpt2`**](https://huggingface.co/openai-community/gpt2) model in Norwegian on a single TPUv3-8.
to pre-train 124M [**`openai-community/gpt2`**](https://huggingface.co/openai-community/gpt2)
in Norwegian on a single TPUv3-8 pod.
The example script uses the 🤗 Datasets library. You can easily customize them to your needs if you need extra processing on your datasets.
To setup all relevant files for training, let's create a directory.
```bash
mkdir ./norwegian-gpt2
```
### Train tokenizer
In the first step, we train a tokenizer to efficiently process the text input for the model. Similar to how it is shown in [How to train a new language model from scratch using Transformers and Tokenizers](https://huggingface.co/blog/how-to-train), we use a **`ByteLevelBPETokenizer`**.
The tokenizer is trained on the complete Norwegian dataset of OSCAR
and consequently saved in the cloned model directory.
This can take up to 10 minutes depending on your hardware ☕.
```python
from datasets import load_dataset
from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer
# load dataset
dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train")
# Instantiate tokenizer
tokenizer = ByteLevelBPETokenizer()
def batch_iterator(batch_size=1000):
for i in range(0, len(dataset), batch_size):
yield dataset[i: i + batch_size]["text"]
# Customized training
tokenizer.train_from_iterator(batch_iterator(), vocab_size=50257, min_frequency=2, special_tokens=[
"<s>",
"<pad>",
"</s>",
"<unk>",
"<mask>",
])
# Save files to disk
tokenizer.save("./norwegian-gpt2/tokenizer.json")
```
### Create configuration
Next, we create the model's configuration file. This is as simple
as loading and storing [`**openai-community/gpt2**`](https://huggingface.co/openai-community/gpt2)
in the local model folder:
```python
from transformers import GPT2Config
config = GPT2Config.from_pretrained("openai-community/gpt2", resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0, vocab_size=50257)
config.save_pretrained("./norwegian-gpt2")
```
Great, we have set up our model repository. During training, we will now automatically
push the training logs and model weights to the repo.
### Train model
Finally, we can run the example script to pretrain the model:
```bash
python run_clm_flax.py \
--output_dir="./norwegian-gpt2" \
--model_type="gpt2" \
--config_name="./norwegian-gpt2" \
--tokenizer_name="./norwegian-gpt2" \
--dataset_name="oscar" \
--dataset_config_name="unshuffled_deduplicated_no" \
--do_train --do_eval \
--block_size="512" \
--per_device_train_batch_size="64" \
--per_device_eval_batch_size="64" \
--learning_rate="5e-3" --warmup_steps="1000" \
--adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
--overwrite_output_dir \
--num_train_epochs="20" \
--logging_steps="500" \
--save_steps="2500" \
--eval_steps="2500" \
--push_to_hub
```
Training should converge at a loss and perplexity
of 3.24 and 25.72 respectively after 20 epochs on a single TPUv3-8.
This should take less than ~21 hours.
Training statistics can be accessed on [tfhub.dev](https://tensorboard.dev/experiment/2zEhLwJ0Qp2FAkI3WVH9qA).
For a step-by-step walkthrough of how to do causal language modeling in Flax, please have a
look at [this](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/causal_language_modeling_flax.ipynb) google colab.
## T5-like span-masked language modeling
In the following, we demonstrate how to train a T5 model using the span-masked language model
objective as proposed in the [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://huggingface.co/papers/1910.10683).
More specifically, we demonstrate how JAX/Flax can be leveraged
to pre-train [**`google/t5-v1_1-base`**](https://huggingface.co/google/t5-v1_1-base)
in Norwegian on a single TPUv3-8 pod.
The example script uses the 🤗 Datasets library. You can easily customize them to your needs if you need extra processing on your datasets.
Let's start by creating a model repository to save the trained model and logs.
Here we call the model `"norwegian-t5-base"`, but you can change the model name as you like.
To setup all relevant files for training, let's create a directory.
```bash
cd ./norwegian-t5-base
```
### Train tokenizer
In the first step, we train a tokenizer to efficiently process the text input for the model.
We make use of the [tokenizers](https://github.com/huggingface/tokenizers) library to train
a sentencepiece unigram tokenizer as shown in [t5_tokenizer_model.py](https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling/t5_tokenizer_model.py)
which is heavily inspired from [yandex-research/DeDLOC's tokenizer model](https://github.com/yandex-research/DeDLOC/blob/5c994bc64e573702a9a79add3ecd68b38f14b548/sahajbert/tokenizer/tokenizer_model.py) .
The tokenizer is trained on the complete Norwegian dataset of OSCAR
and consequently saved in the cloned model directory.
This can take up to 120 minutes depending on your hardware ☕☕☕ .
```python
import datasets
from t5_tokenizer_model import SentencePieceUnigramTokenizer
vocab_size = 32_000
input_sentence_size = None
# Initialize a dataset
dataset = datasets.load_dataset("oscar", name="unshuffled_deduplicated_no", split="train")
tokenizer = SentencePieceUnigramTokenizer(unk_token="<unk>", eos_token="</s>", pad_token="<pad>")
# Build an iterator over this dataset
def batch_iterator(input_sentence_size=None):
if input_sentence_size is None:
input_sentence_size = len(dataset)
batch_length = 100
for i in range(0, input_sentence_size, batch_length):
yield dataset[i: i + batch_length]["text"]
# Train tokenizer
tokenizer.train_from_iterator(
iterator=batch_iterator(input_sentence_size=input_sentence_size),
vocab_size=vocab_size,
show_progress=True,
)
# Save files to disk
tokenizer.save("./norwegian-t5-base/tokenizer.json")
```
### Create configuration
Next, we create the model's configuration file. This is as simple
as loading and storing [`**google/t5-v1_1-base**`](https://huggingface.co/google/t5-v1_1-base)
in the local model folder:
```python
from transformers import T5Config
config = T5Config.from_pretrained("google/t5-v1_1-base", vocab_size=tokenizer.get_vocab_size())
config.save_pretrained("./norwegian-t5-base")
```
Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.
### Train model
Next we can run the example script to pretrain the model:
```bash
python run_t5_mlm_flax.py \
--output_dir="./norwegian-t5-base" \
--model_type="t5" \
--config_name="./norwegian-t5-base" \
--tokenizer_name="./norwegian-t5-base" \
--dataset_name="oscar" \
--dataset_config_name="unshuffled_deduplicated_no" \
--max_seq_length="512" \
--per_device_train_batch_size="32" \
--per_device_eval_batch_size="32" \
--adafactor \
--learning_rate="0.005" \
--weight_decay="0.001" \
--warmup_steps="2000" \
--overwrite_output_dir \
--logging_steps="500" \
--save_steps="10000" \
--eval_steps="2500" \
--push_to_hub
```
Training should converge at a loss and accuracy
of 2.36 and 57.0 respectively after 3 epochs on a single TPUv3-8.
This should take around 4.5 hours.
Training statistics can be accessed on directly on the 🤗 [hub](https://huggingface.co/patrickvonplaten/t5-base-norwegian/tensorboard)
## BART: Denoising language modeling
In the following, we demonstrate how to train a BART model
using denoising language modeling objective as introduced in [BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://huggingface.co/papers/1910.13461).
More specifically, we demonstrate how JAX/Flax can be leveraged
to pre-train [**`bart-base`**](https://huggingface.co/facebook/bart-base)
in Norwegian on a single TPUv3-8 pod.
The example script uses the 🤗 Datasets library. You can easily customize them to your needs if you need extra processing on your datasets.
To setup all relevant files for training, let's create a directory.
```bash
mkdir ./norwegian-bart-base
```
### Train tokenizer
In the first step, we train a tokenizer to efficiently process the text input for the model. Similar to how it is shown in [How to train a new language model from scratch using Transformers and Tokenizers](https://huggingface.co/blog/how-to-train), we use a **`ByteLevelBPETokenizer`**.
The tokenizer is trained on the complete Norwegian dataset of OSCAR
and consequently saved in the cloned model directory.
This can take up to 10 minutes depending on your hardware ☕.
```python
from datasets import load_dataset
from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer
# load dataset
dataset = load_dataset("oscar", "unshuffled_deduplicated_no", split="train")
# Instantiate tokenizer
tokenizer = ByteLevelBPETokenizer()
def batch_iterator(batch_size=1000):
for i in range(0, len(dataset), batch_size):
yield dataset[i: i + batch_size]["text"]
# Customized training
tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
"<s>",
"<pad>",
"</s>",
"<unk>",
"<mask>",
])
# Save files to disk
tokenizer.save("./norwegian-bart-base/tokenizer.json")
```
### Create configuration
Next, we create the model's configuration file. This is as simple
as loading and storing [`**facebook/bart-base**`](https://huggingface.co/facebook/bart-base)
in the local model folder:
```python
from transformers import BartConfig
config = BartConfig.from_pretrained("facebook/bart-base", vocab_size=50265)
config.save_pretrained("./norwegian-bart-base")
```
Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.
### Train model
Next we can run the example script to pretrain the model:
```bash
python run_bart_dlm_flax.py \
--output_dir="./norwegian-bart-base" \
--config_name="./norwegian-bart-base" \
--tokenizer_name="./norwegian-bart-base" \
--dataset_name="oscar" \
--dataset_config_name="unshuffled_deduplicated_no" \
--max_seq_length="1024" \
--per_device_train_batch_size="32" \
--per_device_eval_batch_size="32" \
--learning_rate="1e-4" \
--warmup_steps="2000" \
--overwrite_output_dir \
--logging_steps="500" \
--save_steps="2000" \
--eval_steps="2000" \
--push_to_hub
```
Training should converge at a loss and accuracy
of 1.36 and 0.77 respectively after 3 epochs on a single TPUv3-8.
This should take less than 6 hours.
Training statistics can be accessed on [tfhub.dev](https://tensorboard.dev/experiment/Maw62QlaSXWS0MOf2V2lbg/).
## Runtime evaluation
We also ran masked language modeling using PyTorch/XLA on a TPUv3-8, and PyTorch on 8 V100 GPUs. We report the
overall training time below.
For reproducibility, we state the training commands used for PyTorch/XLA and PyTorch further below.
| Task | [TPU v3-8 (Flax)](https://tensorboard.dev/experiment/GdYmdak2TWeVz0DDRYOrrg/) | [TPU v3-8 (Pytorch/XLA)](https://tensorboard.dev/experiment/7Jq1kcQQRAmy12KOdXek7A/)| [8 GPU (PyTorch)](https://tensorboard.dev/experiment/PJneV8FQRxa2unPw1QnVHA) |
|-------|-----------|------------|------------|
| MLM | 15h32m | 23h46m | 44h14m |
*All experiments are ran on Google Cloud Platform.
GPU experiments are ran without further optimizations besides JAX
transformations. GPU experiments are ran with full precision (fp32). "TPU v3-8"
are 8 TPU cores on 4 chips (each chips has 2 cores), while "8 GPU" are 8 GPU chips.
### Script to run MLM with PyTorch/XLA on TPUv3-8
For comparison one can run the same pre-training with PyTorch/XLA on TPU. To set up PyTorch/XLA on Cloud TPU VMs, please
refer to [this](https://cloud.google.com/tpu/docs/pytorch-xla-ug-tpu-vm) guide.
Having created the tokenizer and configuration in `norwegian-roberta-base`, we create the following symbolic links:
```bash
ln -s ~/transformers/examples/pytorch/language-modeling/run_mlm.py ./
ln -s ~/transformers/examples/pytorch/xla_spawn.py ./
```
, set the following environment variables:
```bash
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
unset LD_PRELOAD
export NUM_TPUS=8
export TOKENIZERS_PARALLELISM=0
export MODEL_DIR="./norwegian-roberta-base"
mkdir -p ${MODEL_DIR}
```
, and start training as follows:
```bash
python3 xla_spawn.py --num_cores ${NUM_TPUS} run_mlm.py --output_dir="./runs" \
--model_type="roberta" \
--config_name="${MODEL_DIR}" \
--tokenizer_name="${MODEL_DIR}" \
--dataset_name="oscar" \
--dataset_config_name="unshuffled_deduplicated_no" \
--max_seq_length="128" \
--weight_decay="0.01" \
--per_device_train_batch_size="128" \
--per_device_eval_batch_size="128" \
--learning_rate="3e-4" \
--warmup_steps="1000" \
--overwrite_output_dir \
--num_train_epochs="18" \
--adam_beta1="0.9" \
--adam_beta2="0.98" \
--do_train \
--do_eval \
--logging_steps="500" \
--eval_strategy="epoch" \
--report_to="tensorboard" \
--save_strategy="no"
```
### Script to compare pre-training with PyTorch on 8 GPU V100's
For comparison you can run the same pre-training with PyTorch on GPU. Note that we have to make use of `gradient_accumulation`
because the maximum batch size that fits on a single V100 GPU is 32 instead of 128.
Having created the tokenizer and configuration in `norwegian-roberta-base`, we create the following symbolic links:
```bash
ln -s ~/transformers/examples/pytorch/language-modeling/run_mlm.py ./
```
, set some environment variables:
```bash
export NUM_GPUS=8
export TOKENIZERS_PARALLELISM=0
export MODEL_DIR="./norwegian-roberta-base"
mkdir -p ${MODEL_DIR}
```
, and can start training as follows:
```bash
python3 -m torch.distributed.launch --nproc_per_node ${NUM_GPUS} run_mlm.py \
--output_dir="${MODEL_DIR}" \
--model_type="roberta" \
--config_name="${MODEL_DIR}" \
--tokenizer_name="${MODEL_DIR}" \
--dataset_name="oscar" \
--dataset_config_name="unshuffled_deduplicated_no" \
--max_seq_length="128" \
--weight_decay="0.01" \
--per_device_train_batch_size="32" \
--per_device_eval_batch_size="32" \
--gradient_accumulation="4" \
--learning_rate="3e-4" \
--warmup_steps="1000" \
--overwrite_output_dir \
--num_train_epochs="18" \
--adam_beta1="0.9" \
--adam_beta2="0.98" \
--do_train \
--do_eval \
--logging_steps="500" \
--eval_strategy="steps" \
--report_to="tensorboard" \
--save_strategy="no"
```
## Language model inference with bfloat16
The following example demonstrates performing inference with a language model using the JAX/Flax backend.
The example script run_bert_flax.py uses bert-base-uncased, and the model is loaded into `FlaxBertModel`.
The input data are randomly generated tokens, and the model is also jitted with JAX.
By default, it uses float32 precision for inference. To enable bfloat16, add the flag shown in the command below.
```bash
python3 run_bert_flax.py --precision bfloat16
> NOTE: For JAX Versions after v0.4.33 or later, users will need to set the below environment variables as a \
> temporary workaround to use Bfloat16 datatype. \
> This restriction is expected to be removed in future version
```bash
export XLA_FLAGS=--xla_cpu_use_thunk_runtime=false
```
bfloat16 gives better performance on GPUs and also Intel CPUs (Sapphire Rapids or later) with Advanced Matrix Extension (Intel AMX).
By changing the dtype for `FlaxBertModel `to `jax.numpy.bfloat16`, you get the performance benefits of the underlying hardware.
```python
import jax
model = FlaxBertModel.from_pretrained("bert-base-uncased", config=config, dtype=jax.numpy.bfloat16)
```
Switching from float32 to bfloat16 can increase the speed of an AWS c7i.4xlarge with Intel Sapphire Rapids by more than 2x.

View File

@ -1,5 +0,0 @@
datasets >= 1.1.3
jax>=0.2.8
jaxlib>=0.1.59
flax>=0.3.5
optax>=0.0.9

View File

@ -1,993 +0,0 @@
#!/usr/bin/env python
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Pretraining the library models for denoising language modeling on a text file or a dataset.
Here is the full list of checkpoints on the hub that can be pretrained by this script:
https://huggingface.co/models?filter=bart
"""
# You can also adapt this script on your own denoising language modeling task. Pointers for this are left as comments.
import json
import logging
import math
import os
import sys
import time
from dataclasses import asdict, dataclass, field
from enum import Enum
from itertools import chain
from pathlib import Path
from typing import Optional
import flax
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
from datasets import load_dataset
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import HfApi
from tqdm import tqdm
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
AutoTokenizer,
BartConfig,
BatchEncoding,
FlaxBartForConditionalGeneration,
HfArgumentParser,
PreTrainedTokenizerBase,
is_tensorboard_available,
set_seed,
)
from transformers.models.bart.modeling_flax_bart import shift_tokens_right
from transformers.utils import send_example_telemetry
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@dataclass
class TrainingArguments:
output_dir: str = field(
metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
)
overwrite_output_dir: bool = field(
default=False,
metadata={
"help": (
"Overwrite the content of the output directory. "
"Use this to continue training if output_dir points to a checkpoint directory."
)
},
)
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
per_device_train_batch_size: int = field(
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
)
per_device_eval_batch_size: int = field(
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
)
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
push_to_hub: bool = field(
default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
)
hub_model_id: str = field(
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
def __post_init__(self):
if self.output_dir is not None:
self.output_dir = os.path.expanduser(self.output_dir)
def to_dict(self):
"""
Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
the token values by removing their value.
"""
d = asdict(self)
for k, v in d.items():
if isinstance(v, Enum):
d[k] = v.value
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
d[k] = [x.value for x in v]
if k.endswith("_token"):
d[k] = f"<{k.upper()}>"
return d
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
"""
model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": (
"The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
)
},
)
model_type: Optional[str] = field(
default=None,
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
dtype: Optional[str] = field(
default="float32",
metadata={
"help": (
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
" `[float32, float16, bfloat16]`."
)
},
)
token: str = field(
default=None,
metadata={
"help": (
"The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
"generated when running `hf auth login` (stored in `~/.huggingface`)."
)
},
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
trust_remote_code: bool = field(
default=False,
metadata={
"help": (
"Whether to trust the execution of code from datasets/models defined on the Hub."
" This option should only be set to `True` for repositories you trust and in which you have read the"
" code, as it will execute code present on the Hub on your local machine."
)
},
)
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
validation_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
)
train_ref_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
)
validation_ref_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
validation_split_percentage: Optional[int] = field(
default=5,
metadata={
"help": "The percentage of the train set used as validation set in case there's no validation split"
},
)
max_seq_length: Optional[int] = field(
default=None,
metadata={
"help": (
"The maximum total input sequence length after tokenization and masking. Sequences longer than this"
" will be truncated. Default to the max input length of the model."
)
},
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
mlm_probability: float = field(
default=0.3, metadata={"help": "Ratio of tokens to mask for span masked language modeling loss"}
)
permute_sentence_ratio: float = field(
default=1.0, metadata={"help": "Ratio of sentences to be permuted in each document"}
)
poisson_lambda: float = field(
default=3.0, metadata={"help": "Mean of Poisson distribution used to generate span-lengths to be masked"}
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
else:
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
if extension not in ["csv", "json", "txt"]:
raise ValueError("train_file` should be a csv, json or text file.")
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
if extension not in ["csv", "json", "txt"]:
raise ValueError("`validation_file` should be a csv, json or text file.")
@flax.struct.dataclass
class FlaxDataCollatorForBartDenoisingLM:
"""
Data collator used for BART denoising language modeling. The code is largely copied from
`<https://github.com/morganmcg1/rotobart/blob/main/data_collator.py#L223>`__.
For more information on how BART denoising language modeling works, one can take a look
at the `official paper <https://huggingface.co/papers/1910.13461>`__
or the `official code for preprocessing <https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/denoising_dataset.py>`__ .
Args:
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
The tokenizer used for encoding the data
mask_ratio (:obj:`float`):
The probability with which to (randomly) mask tokens in the input
poisson_lambda (:obj:`float`):
Mean parameter of Poisson distribution used to generate span-lengths to be masked
permute_sentence_ratio (:obj:`float`):
Ratio of sentences to be permuted in each document
decoder_start_token_id: (:obj:`int):
The decoder start token id of the model
"""
tokenizer: PreTrainedTokenizerBase
decoder_start_token_id: int
mask_ratio: float = 0.3
poisson_lambda: float = 3.0
permute_sentence_ratio: float = 1.0
def __post_init__(self):
if self.tokenizer.mask_token is None or self.tokenizer.eos_token is None:
raise ValueError(
"This tokenizer does not have a mask token or eos token which is necessary for denoising"
" language modeling. "
)
def __call__(self, examples: list[dict[str, list[int]]]) -> BatchEncoding:
# convert list to dict and tensorize input
batch = BatchEncoding(
{k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()}
)
batch["labels"] = batch["input_ids"].copy()
batch["decoder_input_ids"] = shift_tokens_right(
batch["labels"], self.tokenizer.pad_token_id, self.decoder_start_token_id
)
# permuting sentences
do_permute = False
if self.permute_sentence_ratio > 0.0:
batch["input_ids"] = self.permute_sentences(batch["input_ids"])
do_permute = True
# masking span of tokens (text infilling in the paper)
if self.mask_ratio:
batch["input_ids"], batch["labels"] = self.span_mask_tokens(
batch["input_ids"], batch["labels"], do_permute
)
# ignore pad tokens
batch["attention_mask"] = (batch["input_ids"] != self.tokenizer.pad_token_id).astype(int)
batch["decoder_attention_mask"] = (batch["decoder_input_ids"] != self.tokenizer.pad_token_id).astype(int)
return batch
def permute_sentences(self, input_ids):
"""
Shuffle sentences in each document.
"""
results = input_ids.copy()
# find end locations of sentences
end_sentence_mask = input_ids == self.tokenizer.pad_token_id
sentence_ends = np.argwhere(end_sentence_mask)
sentence_ends[:, 1] += 1
example_has_multiple_sentences, num_sentences = np.unique(sentence_ends[:, 0], return_counts=True)
num_sentences_map = dict(zip(example_has_multiple_sentences, num_sentences))
num_to_permute = np.ceil(num_sentences * self.permute_sentence_ratio).astype(int)
num_to_permute_map = dict(zip(example_has_multiple_sentences, num_to_permute))
sentence_ends = np.split(sentence_ends[:, 1], np.unique(sentence_ends[:, 0], return_index=True)[1][1:])
sentence_ends_map = dict(zip(example_has_multiple_sentences, sentence_ends))
for i in range(input_ids.shape[0]):
if i not in example_has_multiple_sentences:
continue
substitutions = np.random.permutation(num_sentences_map[i])[: num_to_permute_map[i]]
ordering = np.arange(0, num_sentences_map[i])
ordering[substitutions] = substitutions[np.random.permutation(num_to_permute_map[i])]
# write shuffled sentences into results
index = 0
for j in ordering:
sentence = input_ids[i, (sentence_ends_map[i][j - 1] if j > 0 else 0) : sentence_ends_map[i][j]]
results[i, index : index + sentence.shape[0]] = sentence
index += sentence.shape[0]
return results
def span_mask_tokens(self, input_ids, labels, do_permute):
"""
Sampling text spans with span lengths drawn from a Poisson distribution and masking them.
"""
special_tokens_mask_labels = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
special_tokens_mask_inputs = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in input_ids.tolist()
]
special_tokens_mask_labels = np.array(special_tokens_mask_labels, dtype=bool)
special_tokens_mask_inputs = np.array(special_tokens_mask_inputs, dtype=bool)
# determine how many tokens we need to mask in total
is_token_mask = ~(input_ids == self.tokenizer.pad_token_id) & ~special_tokens_mask_inputs
num_tokens_to_mask = int(math.ceil(is_token_mask.astype(float).sum() * self.mask_ratio))
if num_tokens_to_mask == 0:
return input_ids, labels
# generate a sufficient number of span lengths
span_lengths = np.random.poisson(lam=self.poisson_lambda, size=(num_tokens_to_mask,))
while np.cumsum(span_lengths, 0)[-1] < num_tokens_to_mask:
span_lengths = np.concatenate(
[span_lengths, np.random.poisson(lam=self.poisson_lambda, size=(num_tokens_to_mask,))]
)
# remove all spans of length 0
# note that BART inserts additional mask tokens where length == 0,
# which we do not implement for now as it adds additional complexity
span_lengths = span_lengths[span_lengths > 0]
# trim to about num_tokens_to_mask tokens
cutoff_idx = np.argmin(np.abs(np.cumsum(span_lengths, 0) - num_tokens_to_mask)) + 1
span_lengths = span_lengths[:cutoff_idx]
# randomly choose starting positions for masking
token_indices = np.argwhere(is_token_mask == 1)
span_starts = np.random.permutation(token_indices.shape[0])[: span_lengths.shape[0]]
# prepare mask
masked_indices = np.array(token_indices[span_starts])
mask = np.full_like(input_ids, fill_value=False)
# mask starting positions
for mi in masked_indices:
mask[tuple(mi)] = True
span_lengths -= 1
# fill up spans
max_index = input_ids.shape[1] - 1
remaining = (span_lengths > 0) & (masked_indices[:, 1] < max_index)
while np.any(remaining):
masked_indices[remaining, 1] += 1
for mi in masked_indices:
mask[tuple(mi)] = True
span_lengths -= 1
remaining = (span_lengths > 0) & (masked_indices[:, 1] < max_index)
# place the mask tokens
mask[np.where(special_tokens_mask_inputs)] = False
input_ids[np.where(mask)] = self.tokenizer.mask_token_id
if not do_permute:
labels[np.where(mask == 0)] = -100
else:
labels[np.where(special_tokens_mask_labels)] = -100
# remove mask tokens that are not starts of spans
to_remove = (mask == 1) & np.roll((mask == 1), 1, 1)
new_input_ids = np.full_like(input_ids, fill_value=self.tokenizer.pad_token_id)
for i, example in enumerate(input_ids):
new_example = example[~to_remove[i]]
new_input_ids[i, : new_example.shape[0]] = new_example
return new_input_ids, labels
def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
"""Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
num_samples = len(samples_idx)
if drop_last:
samples_to_remove = num_samples % batch_size
if samples_to_remove != 0:
samples_idx = samples_idx[:-samples_to_remove]
sections_split = num_samples // batch_size
samples_idx = samples_idx.reshape((sections_split, batch_size))
else:
sections_split = math.ceil(num_samples / batch_size)
samples_idx = np.array_split(samples_idx, sections_split)
return samples_idx
def write_train_metric(summary_writer, train_metrics, train_time, step):
summary_writer.scalar("train_time", train_time, step)
train_metrics = get_metrics(train_metrics)
for key, vals in train_metrics.items():
tag = f"train_{key}"
for i, val in enumerate(vals):
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
def write_eval_metric(summary_writer, eval_metrics, step):
for metric_name, value in eval_metrics.items():
summary_writer.scalar(f"eval_{metric_name}", value, step)
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_bart_dlm", model_args, data_args, framework="flax")
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
level=logging.INFO,
datefmt="[%X]",
)
# Log on each process the small summary:
logger = logging.getLogger(__name__)
# Set the verbosity to info of the Transformers logger (on main process only):
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
set_seed(training_args.seed)
# Handle the repository creation
if training_args.push_to_hub:
# Retrieve of infer repo_name
repo_name = training_args.hub_model_id
if repo_name is None:
repo_name = Path(training_args.output_dir).absolute().name
# Create repo and retrieve repo_id
api = HfApi()
repo_id = api.create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
#
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
# 'text' is found. You can easily tweak this behavior (see below).
if data_args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
datasets = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
token=model_args.token,
num_proc=data_args.preprocessing_num_workers,
trust_remote_code=data_args.trust_remote_code,
)
if "validation" not in datasets:
datasets["validation"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"train[:{data_args.validation_split_percentage}%]",
cache_dir=model_args.cache_dir,
token=model_args.token,
num_proc=data_args.preprocessing_num_workers,
trust_remote_code=data_args.trust_remote_code,
)
datasets["train"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"train[{data_args.validation_split_percentage}%:]",
cache_dir=model_args.cache_dir,
token=model_args.token,
num_proc=data_args.preprocessing_num_workers,
trust_remote_code=data_args.trust_remote_code,
)
else:
data_files = {}
if data_args.train_file is not None:
data_files["train"] = data_args.train_file
extension = data_args.train_file.split(".")[-1]
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.validation_file.split(".")[-1]
if extension == "txt":
extension = "text"
datasets = load_dataset(
extension,
data_files=data_files,
cache_dir=model_args.cache_dir,
token=model_args.token,
num_proc=data_args.preprocessing_num_workers,
)
if "validation" not in datasets:
datasets["validation"] = load_dataset(
extension,
data_files=data_files,
split=f"train[:{data_args.validation_split_percentage}%]",
cache_dir=model_args.cache_dir,
token=model_args.token,
num_proc=data_args.preprocessing_num_workers,
)
datasets["train"] = load_dataset(
extension,
data_files=data_files,
split=f"train[{data_args.validation_split_percentage}%:]",
cache_dir=model_args.cache_dir,
token=model_args.token,
num_proc=data_args.preprocessing_num_workers,
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.
# Load pretrained model and tokenizer
if model_args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name,
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
token=model_args.token,
)
elif model_args.model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
token=model_args.token,
)
else:
raise ValueError(
"You are instantiating a new tokenizer from scratch. This is not supported by this script. "
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
)
if model_args.config_name:
config = BartConfig.from_pretrained(
model_args.config_name,
cache_dir=model_args.cache_dir,
vocab_size=len(tokenizer),
token=model_args.token,
)
elif model_args.model_name_or_path:
config = BartConfig.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
token=model_args.token,
)
else:
config = CONFIG_MAPPING[model_args.model_type]()
logger.warning("You are instantiating a new config instance from scratch.")
# Preprocessing the datasets.
# First we tokenize all the texts.
if training_args.do_train:
column_names = datasets["train"].column_names
else:
column_names = datasets["validation"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
# Use Punkt Sentence Tokenizer to divide a document into a list of sentences
nltk.download("punkt")
sentence_tokenizer = nltk.data.load("tokenizers/punkt/english.pickle")
def sentence_split_function(example):
sents = sentence_tokenizer.tokenize(example["text"])
# use pad token as end of sentence indicator
new_text = tokenizer.bos_token + f"{tokenizer.pad_token}".join(sents) + tokenizer.eos_token
return {"text": new_text}
split_datasets = datasets.map(
sentence_split_function,
batched=False,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)
# Tokenize every text, then concatenate them together before splitting them in smaller parts.
# Since we make sure that all sequences are of the same length, no attention_mask is needed.
def tokenize_function(examples):
return tokenizer(examples[text_column_name], add_special_tokens=False, return_attention_mask=False)
tokenized_datasets = split_datasets.map(
tokenize_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=text_column_name,
load_from_cache_file=not data_args.overwrite_cache,
)
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
# max_seq_length.
def group_texts(examples):
# Concatenate all texts.
concatenated_examples = {k: list(chain(*examples[k])) for k in examples}
total_length = len(concatenated_examples[list(examples.keys())[0]])
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
# customize this part to your needs.
if total_length >= max_seq_length:
total_length = (total_length // max_seq_length) * max_seq_length
# Split by chunks of max_len.
result = {
k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
for k, t in concatenated_examples.items()
}
return result
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
# might be slower to preprocess.
#
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
# https://huggingface.co/docs/datasets/process#map
tokenized_datasets = tokenized_datasets.map(
group_texts,
batched=True,
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
)
# Enable tensorboard only on the master node
has_tensorboard = is_tensorboard_available()
if has_tensorboard and jax.process_index() == 0:
try:
from flax.metrics.tensorboard import SummaryWriter
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
except ImportError as ie:
has_tensorboard = False
logger.warning(
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
)
else:
logger.warning(
"Unable to display metrics through TensorBoard because the package is not installed: "
"Please run pip install tensorboard to enable."
)
# Initialize our training
rng = jax.random.PRNGKey(training_args.seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())
if model_args.model_name_or_path:
model = FlaxBartForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
config=config,
seed=training_args.seed,
dtype=getattr(jnp, model_args.dtype),
token=model_args.token,
)
else:
config.vocab_size = len(tokenizer)
model = FlaxBartForConditionalGeneration(
config,
seed=training_args.seed,
dtype=getattr(jnp, model_args.dtype),
)
# Data collator
# This one will take care of randomly masking the tokens and permuting the sentences.
data_collator = FlaxDataCollatorForBartDenoisingLM(
tokenizer=tokenizer,
decoder_start_token_id=model.config.decoder_start_token_id,
mask_ratio=data_args.mlm_probability,
poisson_lambda=data_args.poisson_lambda,
permute_sentence_ratio=data_args.permute_sentence_ratio,
)
# Store some constant
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
eval_batch_size = per_device_eval_batch_size * jax.device_count()
num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
# Create learning rate schedule
warmup_fn = optax.linear_schedule(
init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
)
decay_fn = optax.linear_schedule(
init_value=training_args.learning_rate,
end_value=0,
transition_steps=num_train_steps - training_args.warmup_steps,
)
linear_decay_lr_schedule_fn = optax.join_schedules(
schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
)
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
if training_args.adafactor:
# We use the default parameters here to initialize adafactor,
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
optimizer = optax.adafactor(
learning_rate=linear_decay_lr_schedule_fn,
)
else:
optimizer = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=training_args.adam_beta1,
b2=training_args.adam_beta2,
weight_decay=training_args.weight_decay,
mask=decay_mask_fn,
)
# Setup train state
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
# Define gradient update step fn
def train_step(state, batch, dropout_rng):
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
def loss_fn(params):
labels = batch.pop("labels")
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
# compute loss, ignore padded input tokens and special tokens
label_mask = jnp.where(labels > 0, 1.0, 0.0)
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
# take average
loss = loss.sum()
num_labels = label_mask.sum()
return loss, num_labels
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, num_labels), grad = grad_fn(state.params)
num_labels = jax.lax.psum(num_labels, "batch")
# true loss = total loss / total samples
loss = jax.lax.psum(loss, "batch")
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
# true grad = total grad / total samples
grad = jax.lax.psum(grad, "batch")
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
new_state = state.apply_gradients(grads=grad)
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
return new_state, metrics, new_dropout_rng
# Create parallel version of the train step
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
# Define eval fn
def eval_step(params, batch):
labels = batch.pop("labels")
logits = model(**batch, params=params, train=False)[0]
# compute loss, ignore padded input tokens and special tokens
label_mask = jnp.where(labels > 0, 1.0, 0.0)
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
# compute accuracy
accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
# summarize metrics
metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
metrics = jax.lax.psum(metrics, axis_name="batch")
return metrics
p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
# Replicate the train state on each device
state = jax_utils.replicate(state)
train_time = 0
epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
for epoch in epochs:
# ======================== Training ================================
train_start = time.time()
train_metrics = []
# Create sampling rng
rng, input_rng = jax.random.split(rng)
# Generate an epoch by shuffling sampling indices from the train dataset
num_train_samples = len(tokenized_datasets["train"])
# Avoid using jax.numpy here in case of TPU training
train_samples_idx = np.random.permutation(np.arange(num_train_samples))
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
# Gather the indexes for creating the batch and do a training step
for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples)
# Model forward
model_inputs = shard(model_inputs.data)
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
train_metrics.append(train_metric)
cur_step = epoch * (num_train_samples // train_batch_size) + step
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
# Save metrics
train_metric = jax_utils.unreplicate(train_metric)
train_time += time.time() - train_start
if has_tensorboard and jax.process_index() == 0:
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate:"
f" {train_metric['learning_rate']})"
)
train_metrics = []
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
# ======================== Evaluating ==============================
num_eval_samples = len(tokenized_datasets["validation"])
# Avoid using jax.numpy here in case of TPU training
eval_samples_idx = np.arange(num_eval_samples)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
eval_metrics = []
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples)
# Model forward
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
)
eval_metrics.append(metrics)
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_util.tree_map(jnp.sum, eval_metrics)
eval_normalizer = eval_metrics.pop("normalizer")
eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)
# Update progress bar
epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
# Save metrics
if has_tensorboard and jax.process_index() == 0:
write_eval_metric(summary_writer, eval_metrics, cur_step)
if cur_step % training_args.save_steps == 0 and cur_step > 0:
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
model.save_pretrained(training_args.output_dir, params=params)
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
api.upload_folder(
commit_message=f"Saving weights and logs of step {cur_step}",
folder_path=training_args.output_dir,
repo_id=repo_id,
repo_type="model",
token=training_args.hub_token,
)
# Eval after training
if training_args.do_eval:
num_eval_samples = len(tokenized_datasets["validation"])
# Avoid using jax.numpy here in case of TPU training
eval_samples_idx = np.arange(num_eval_samples)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
eval_metrics = []
for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples)
# Model forward
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
)
eval_metrics.append(metrics)
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_util.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
eval_normalizer = eval_metrics.pop("normalizer")
eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)
try:
perplexity = math.exp(eval_metrics["loss"])
except OverflowError:
perplexity = float("inf")
eval_metrics["perplexity"] = perplexity
if jax.process_index() == 0:
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
path = os.path.join(training_args.output_dir, "eval_results.json")
with open(path, "w") as f:
json.dump(eval_metrics, f, indent=4, sort_keys=True)
if __name__ == "__main__":
main()

View File

@ -1,56 +0,0 @@
#!/usr/bin/env python3
import time
from argparse import ArgumentParser
import jax
import numpy as np
from transformers import BertConfig, FlaxBertModel
parser = ArgumentParser()
parser.add_argument("--precision", type=str, choices=["float32", "bfloat16"], default="float32")
args = parser.parse_args()
dtype = jax.numpy.float32
if args.precision == "bfloat16":
dtype = jax.numpy.bfloat16
VOCAB_SIZE = 30522
BS = 32
SEQ_LEN = 128
def get_input_data(batch_size=1, seq_length=384):
shape = (batch_size, seq_length)
input_ids = np.random.randint(1, VOCAB_SIZE, size=shape).astype(np.int32)
token_type_ids = np.ones(shape).astype(np.int32)
attention_mask = np.ones(shape).astype(np.int32)
return {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
inputs = get_input_data(BS, SEQ_LEN)
config = BertConfig.from_pretrained("bert-base-uncased", hidden_act="gelu_new")
model = FlaxBertModel.from_pretrained("bert-base-uncased", config=config, dtype=dtype)
@jax.jit
def func():
outputs = model(**inputs)
return outputs
(nwarmup, nbenchmark) = (5, 100)
# warmpup
for _ in range(nwarmup):
func()
# benchmark
start = time.time()
for _ in range(nbenchmark):
func()
end = time.time()
print(end - start)
print(f"Throughput: {((nbenchmark * BS) / (end - start)):.3f} examples/sec")

View File

@ -1,869 +0,0 @@
#!/usr/bin/env python
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Pre-training/Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=text-generation
"""
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
import json
import logging
import math
import os
import sys
import time
from dataclasses import asdict, dataclass, field
from enum import Enum
from itertools import chain
from pathlib import Path
from typing import Callable, Optional
import datasets
import jax
import jax.numpy as jnp
import numpy as np
import optax
from datasets import Dataset, load_dataset
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import HfApi
from tqdm import tqdm
import transformers
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
AutoConfig,
AutoTokenizer,
FlaxAutoModelForCausalLM,
HfArgumentParser,
is_tensorboard_available,
set_seed,
)
from transformers.testing_utils import CaptureLogger
from transformers.utils import send_example_telemetry
logger = logging.getLogger(__name__)
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@dataclass
class TrainingArguments:
output_dir: str = field(
metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
)
overwrite_output_dir: bool = field(
default=False,
metadata={
"help": (
"Overwrite the content of the output directory. "
"Use this to continue training if output_dir points to a checkpoint directory."
)
},
)
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
per_device_train_batch_size: int = field(
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
)
per_device_eval_batch_size: int = field(
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
)
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
push_to_hub: bool = field(
default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
)
hub_model_id: str = field(
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
def __post_init__(self):
if self.output_dir is not None:
self.output_dir = os.path.expanduser(self.output_dir)
def to_dict(self):
"""
Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
the token values by removing their value.
"""
d = asdict(self)
for k, v in d.items():
if isinstance(v, Enum):
d[k] = v.value
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
d[k] = [x.value for x in v]
if k.endswith("_token"):
d[k] = f"<{k.upper()}>"
return d
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
"""
model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": (
"The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
)
},
)
model_type: Optional[str] = field(
default=None,
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
dtype: Optional[str] = field(
default="float32",
metadata={
"help": (
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
" `[float32, float16, bfloat16]`."
)
},
)
token: str = field(
default=None,
metadata={
"help": (
"The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
"generated when running `hf auth login` (stored in `~/.huggingface`)."
)
},
)
trust_remote_code: bool = field(
default=False,
metadata={
"help": (
"Whether to trust the execution of code from datasets/models defined on the Hub."
" This option should only be set to `True` for repositories you trust and in which you have read the"
" code, as it will execute code present on the Hub on your local machine."
)
},
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
validation_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
)
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
)
},
)
validation_split_percentage: Optional[int] = field(
default=5,
metadata={
"help": "The percentage of the train set used as validation set in case there's no validation split"
},
)
block_size: Optional[int] = field(
default=None,
metadata={
"help": (
"Optional input sequence length after tokenization. "
"The training dataset will be truncated in block of this size for training. "
"Default to the model max input length for single sentence inputs (take into account special tokens)."
)
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
keep_linebreaks: bool = field(
default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
else:
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
if extension not in ["csv", "json", "txt"]:
raise ValueError("train_file` should be a csv, json or text file.")
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
if extension not in ["csv", "json", "txt"]:
raise ValueError("`validation_file` should be a csv, json or text file.")
class TrainState(train_state.TrainState):
dropout_rng: jnp.ndarray
def replicate(self):
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True):
"""
Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
"""
if shuffle:
batch_idx = jax.random.permutation(rng, len(dataset))
batch_idx = np.asarray(batch_idx)
else:
batch_idx = np.arange(len(dataset))
if drop_last:
steps_per_epoch = len(dataset) // batch_size
batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.
batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
else:
steps_per_epoch = math.ceil(len(dataset) / batch_size)
batch_idx = np.array_split(batch_idx, steps_per_epoch)
for idx in batch_idx:
batch = dataset[idx]
batch = {k: np.array(v) for k, v in batch.items()}
yield batch
def write_train_metric(summary_writer, train_metrics, train_time, step):
summary_writer.scalar("train_time", train_time, step)
train_metrics = get_metrics(train_metrics)
for key, vals in train_metrics.items():
tag = f"train_{key}"
for i, val in enumerate(vals):
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
def write_eval_metric(summary_writer, eval_metrics, step):
for metric_name, value in eval_metrics.items():
summary_writer.scalar(f"eval_{metric_name}", value, step)
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
return schedule_fn
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_clm", model_args, data_args, framework="flax")
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
# Setup logging, we only want one process per machine to log things on the screen.
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
if jax.process_index() == 0:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
# Set the verbosity to info of the Transformers logger (on main process only):
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
set_seed(training_args.seed)
# Handle the repository creation
if training_args.push_to_hub:
# Retrieve of infer repo_name
repo_name = training_args.hub_model_id
if repo_name is None:
repo_name = Path(training_args.output_dir).absolute().name
# Create repo and retrieve repo_id
api = HfApi()
repo_id = api.create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
#
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
# 'text' is found. You can easily tweak this behavior (see below).
#
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
if data_args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
dataset = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
keep_in_memory=False,
token=model_args.token,
num_proc=data_args.preprocessing_num_workers,
trust_remote_code=model_args.trust_remote_code,
)
if "validation" not in dataset:
dataset["validation"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"train[:{data_args.validation_split_percentage}%]",
cache_dir=model_args.cache_dir,
token=model_args.token,
num_proc=data_args.preprocessing_num_workers,
trust_remote_code=model_args.trust_remote_code,
)
dataset["train"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"train[{data_args.validation_split_percentage}%:]",
cache_dir=model_args.cache_dir,
token=model_args.token,
num_proc=data_args.preprocessing_num_workers,
trust_remote_code=model_args.trust_remote_code,
)
else:
data_files = {}
dataset_args = {}
if data_args.train_file is not None:
data_files["train"] = data_args.train_file
extension = data_args.train_file.split(".")[-1]
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.validation_file.split(".")[-1]
if extension == "txt":
extension = "text"
dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
dataset = load_dataset(
extension,
data_files=data_files,
cache_dir=model_args.cache_dir,
**dataset_args,
token=model_args.token,
num_proc=data_args.preprocessing_num_workers,
)
if "validation" not in dataset:
dataset["validation"] = load_dataset(
extension,
data_files=data_files,
split=f"train[:{data_args.validation_split_percentage}%]",
cache_dir=model_args.cache_dir,
**dataset_args,
token=model_args.token,
num_proc=data_args.preprocessing_num_workers,
)
dataset["train"] = load_dataset(
extension,
data_files=data_files,
split=f"train[{data_args.validation_split_percentage}%:]",
cache_dir=model_args.cache_dir,
**dataset_args,
token=model_args.token,
num_proc=data_args.preprocessing_num_workers,
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.
# Load pretrained model and tokenizer
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.
if model_args.config_name:
config = AutoConfig.from_pretrained(
model_args.config_name,
cache_dir=model_args.cache_dir,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
elif model_args.model_name_or_path:
config = AutoConfig.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
else:
config = CONFIG_MAPPING[model_args.model_type]()
logger.warning("You are instantiating a new config instance from scratch.")
if model_args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name,
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
elif model_args.model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
else:
raise ValueError(
"You are instantiating a new tokenizer from scratch. This is not supported by this script. "
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
)
if model_args.model_name_or_path:
model = FlaxAutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
config=config,
seed=training_args.seed,
dtype=getattr(jnp, model_args.dtype),
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
else:
model = FlaxAutoModelForCausalLM.from_config(
config,
seed=training_args.seed,
dtype=getattr(jnp, model_args.dtype),
trust_remote_code=model_args.trust_remote_code,
)
# Preprocessing the datasets.
# First we tokenize all the texts.
if training_args.do_train:
column_names = dataset["train"].column_names
else:
column_names = dataset["validation"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]
# since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
def tokenize_function(examples):
with CaptureLogger(tok_logger) as cl:
output = tokenizer(examples[text_column_name])
# clm input could be much much longer than block_size
if "Token indices sequence length is longer than the" in cl.out:
tok_logger.warning(
"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
" before being passed to the model."
)
return output
tokenized_datasets = dataset.map(
tokenize_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)
if data_args.block_size is None:
block_size = tokenizer.model_max_length
if block_size > config.max_position_embeddings:
logger.warning(
f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
f"Using block_size={min(1024, config.max_position_embeddings)} instead. You can change that default value by passing --block_size xxx."
)
block_size = min(1024, config.max_position_embeddings)
else:
if data_args.block_size > tokenizer.model_max_length:
logger.warning(
f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model "
f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
)
block_size = min(data_args.block_size, tokenizer.model_max_length)
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
def group_texts(examples):
# Concatenate all texts.
concatenated_examples = {k: list(chain(*examples[k])) for k in examples}
total_length = len(concatenated_examples[list(examples.keys())[0]])
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
# customize this part to your needs.
if total_length >= block_size:
total_length = (total_length // block_size) * block_size
# Split by chunks of max_len.
result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
result["labels"] = result["input_ids"].copy()
return result
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
# for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
# to preprocess.
#
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
# https://huggingface.co/docs/datasets/process#map
lm_datasets = tokenized_datasets.map(
group_texts,
batched=True,
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
)
if training_args.do_train:
if "train" not in tokenized_datasets:
raise ValueError("--do_train requires a train dataset")
train_dataset = lm_datasets["train"]
if data_args.max_train_samples is not None:
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
train_dataset = train_dataset.select(range(max_train_samples))
if training_args.do_eval:
if "validation" not in tokenized_datasets:
raise ValueError("--do_eval requires a validation dataset")
eval_dataset = lm_datasets["validation"]
if data_args.max_eval_samples is not None:
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
# Enable tensorboard only on the master node
has_tensorboard = is_tensorboard_available()
if has_tensorboard and jax.process_index() == 0:
try:
from flax.metrics.tensorboard import SummaryWriter
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
except ImportError as ie:
has_tensorboard = False
logger.warning(
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
)
else:
logger.warning(
"Unable to display metrics through TensorBoard because the package is not installed: "
"Please run pip install tensorboard to enable."
)
# Initialize our training
rng = jax.random.PRNGKey(training_args.seed)
rng, dropout_rng = jax.random.split(rng)
# Store some constant
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(train_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
len(train_dataset),
train_batch_size,
training_args.num_train_epochs,
training_args.warmup_steps,
training_args.learning_rate,
)
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
if training_args.adafactor:
# We use the default parameters here to initialize adafactor,
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
optimizer = optax.adafactor(
learning_rate=linear_decay_lr_schedule_fn,
)
else:
optimizer = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=training_args.adam_beta1,
b2=training_args.adam_beta2,
eps=training_args.adam_epsilon,
weight_decay=training_args.weight_decay,
mask=decay_mask_fn,
)
# Setup train state
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
def loss_fn(logits, labels):
shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:]
loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
return loss.mean()
# Define gradient update step fn
def train_step(state, batch):
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
def compute_loss(params):
labels = batch.pop("labels")
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
loss = loss_fn(logits, labels)
return loss
grad_fn = jax.value_and_grad(compute_loss)
loss, grad = grad_fn(state.params)
grad = jax.lax.pmean(grad, "batch")
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
metrics = jax.lax.pmean(metrics, axis_name="batch")
return new_state, metrics
# Define eval fn
def eval_step(params, batch):
labels = batch.pop("labels")
logits = model(**batch, params=params, train=False)[0]
loss = loss_fn(logits, labels)
# summarize metrics
metrics = {"loss": loss}
metrics = jax.lax.pmean(metrics, axis_name="batch")
return metrics
# Create parallel version of the train and eval step
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
p_eval_step = jax.pmap(eval_step, "batch")
# Replicate the train state on each device
state = state.replicate()
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {num_epochs}")
logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
logger.info(f" Total optimization steps = {total_train_steps}")
train_time = 0
train_metrics = []
epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
for epoch in epochs:
# ======================== Training ================================
train_start = time.time()
# Create sampling rng
rng, input_rng = jax.random.split(rng)
# Generate an epoch by shuffling sampling indices from the train dataset
train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
steps_per_epoch = len(train_dataset) // train_batch_size
# train
for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
batch = next(train_loader)
batch = shard(batch)
state, train_metric = p_train_step(state, batch)
train_metrics.append(train_metric)
cur_step = epoch * (len(train_dataset) // train_batch_size) + step
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
# Save metrics
train_metric = unreplicate(train_metric)
train_time += time.time() - train_start
if has_tensorboard and jax.process_index() == 0:
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
f" {train_metric['learning_rate'].mean()})"
)
train_metrics = []
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
# ======================== Evaluating ==============================
eval_metrics = []
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
# Model forward
batch = next(eval_loader)
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
state.params, batch, min_device_batch=per_device_eval_batch_size
)
eval_metrics.append(metrics)
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
try:
eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
except OverflowError:
eval_metrics["perplexity"] = float("inf")
# Print metrics and update progress bar
desc = (
f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity:"
f" {eval_metrics['perplexity']})"
)
epochs.write(desc)
epochs.desc = desc
# Save metrics
if has_tensorboard and jax.process_index() == 0:
write_eval_metric(summary_writer, eval_metrics, cur_step)
if cur_step % training_args.save_steps == 0 and cur_step > 0:
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(unreplicate(state.params))
model.save_pretrained(training_args.output_dir, params=params)
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
api.upload_folder(
commit_message=f"Saving weights and logs of step {cur_step}",
folder_path=training_args.output_dir,
repo_id=repo_id,
repo_type="model",
token=training_args.hub_token,
)
# Eval after training
if training_args.do_eval:
eval_metrics = []
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
# Model forward
batch = next(eval_loader)
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
state.params, batch, min_device_batch=per_device_eval_batch_size
)
eval_metrics.append(metrics)
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_util.tree_map(lambda x: jnp.mean(x).item(), eval_metrics)
try:
eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
except OverflowError:
eval_metrics["perplexity"] = float("inf")
if jax.process_index() == 0:
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
path = os.path.join(training_args.output_dir, "eval_results.json")
with open(path, "w") as f:
json.dump(eval_metrics, f, indent=4, sort_keys=True)
if __name__ == "__main__":
main()

View File

@ -1,924 +0,0 @@
#!/usr/bin/env python
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
text file or a dataset.
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=fill-mask
"""
import json
import logging
import math
import os
import sys
import time
from dataclasses import asdict, dataclass, field
from enum import Enum
from itertools import chain
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
from pathlib import Path
from typing import Optional
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from datasets import load_dataset
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import HfApi
from tqdm import tqdm
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
AutoConfig,
AutoTokenizer,
FlaxAutoModelForMaskedLM,
HfArgumentParser,
PreTrainedTokenizerBase,
TensorType,
is_tensorboard_available,
set_seed,
)
from transformers.utils import send_example_telemetry
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@dataclass
class TrainingArguments:
output_dir: str = field(
metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
)
overwrite_output_dir: bool = field(
default=False,
metadata={
"help": (
"Overwrite the content of the output directory. "
"Use this to continue training if output_dir points to a checkpoint directory."
)
},
)
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
per_device_train_batch_size: int = field(
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
)
per_device_eval_batch_size: int = field(
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
)
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
push_to_hub: bool = field(
default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
)
hub_model_id: str = field(
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
gradient_checkpointing: bool = field(
default=False,
metadata={
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
def __post_init__(self):
if self.output_dir is not None:
self.output_dir = os.path.expanduser(self.output_dir)
def to_dict(self):
"""
Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
the token values by removing their value.
"""
d = asdict(self)
for k, v in d.items():
if isinstance(v, Enum):
d[k] = v.value
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
d[k] = [x.value for x in v]
if k.endswith("_token"):
d[k] = f"<{k.upper()}>"
return d
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
"""
model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": (
"The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
)
},
)
model_type: Optional[str] = field(
default=None,
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
dtype: Optional[str] = field(
default="float32",
metadata={
"help": (
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
" `[float32, float16, bfloat16]`."
)
},
)
token: str = field(
default=None,
metadata={
"help": (
"The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
"generated when running `hf auth login` (stored in `~/.huggingface`)."
)
},
)
trust_remote_code: bool = field(
default=False,
metadata={
"help": (
"Whether to trust the execution of code from datasets/models defined on the Hub."
" This option should only be set to `True` for repositories you trust and in which you have read the"
" code, as it will execute code present on the Hub on your local machine."
)
},
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
validation_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
)
train_ref_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
)
validation_ref_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
validation_split_percentage: Optional[int] = field(
default=5,
metadata={
"help": "The percentage of the train set used as validation set in case there's no validation split"
},
)
max_seq_length: Optional[int] = field(
default=None,
metadata={
"help": (
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated. Default to the max input length of the model."
)
},
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
mlm_probability: float = field(
default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
)
pad_to_max_length: bool = field(
default=False,
metadata={
"help": (
"Whether to pad all samples to `max_seq_length`. "
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
)
},
)
line_by_line: bool = field(
default=False,
metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
else:
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
@flax.struct.dataclass
class FlaxDataCollatorForLanguageModeling:
"""
Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
are not all of the same length.
Args:
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
The tokenizer used for encoding the data.
mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
The probability with which to (randomly) mask tokens in the input.
.. note::
For best performance, this data collator should be used with a dataset having items that are dictionaries or
BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
:class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
argument :obj:`return_special_tokens_mask=True`.
"""
tokenizer: PreTrainedTokenizerBase
mlm_probability: float = 0.15
def __post_init__(self):
if self.tokenizer.mask_token is None:
raise ValueError(
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
"You should pass `mlm=False` to train on causal language modeling instead."
)
def __call__(self, examples: list[dict[str, np.ndarray]], pad_to_multiple_of: int) -> dict[str, np.ndarray]:
# Handle dict or lists with proper padding and conversion to tensor.
batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
# If special token mask has been preprocessed, pop it from the dict.
special_tokens_mask = batch.pop("special_tokens_mask", None)
batch["input_ids"], batch["labels"] = self.mask_tokens(
batch["input_ids"], special_tokens_mask=special_tokens_mask
)
return batch
def mask_tokens(
self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
) -> tuple[np.ndarray, np.ndarray]:
"""
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
"""
labels = inputs.copy()
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
probability_matrix = np.full(labels.shape, self.mlm_probability)
special_tokens_mask = special_tokens_mask.astype("bool")
probability_matrix[special_tokens_mask] = 0.0
masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
labels[~masked_indices] = -100 # We only compute loss on masked tokens
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
# 10% of the time, we replace masked input tokens with random word
indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
indices_random &= masked_indices & ~indices_replaced
random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
inputs[indices_random] = random_words[indices_random]
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return inputs, labels
def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray:
"""Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by
the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned."""
num_samples = len(samples_idx)
if drop_last:
samples_to_remove = num_samples % batch_size
if samples_to_remove != 0:
samples_idx = samples_idx[:-samples_to_remove]
sections_split = num_samples // batch_size
samples_idx = samples_idx.reshape((sections_split, batch_size))
else:
sections_split = math.ceil(num_samples / batch_size)
samples_idx = np.array_split(samples_idx, sections_split)
return samples_idx
def write_train_metric(summary_writer, train_metrics, train_time, step):
summary_writer.scalar("train_time", train_time, step)
train_metrics = get_metrics(train_metrics)
for key, vals in train_metrics.items():
tag = f"train_{key}"
for i, val in enumerate(vals):
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
def write_eval_metric(summary_writer, eval_metrics, step):
for metric_name, value in eval_metrics.items():
summary_writer.scalar(f"eval_{metric_name}", value, step)
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_mlm", model_args, data_args, framework="flax")
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
level=logging.INFO,
datefmt="[%X]",
)
# Log on each process the small summary:
logger = logging.getLogger(__name__)
# Set the verbosity to info of the Transformers logger (on main process only):
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
set_seed(training_args.seed)
# Handle the repository creation
if training_args.push_to_hub:
# Retrieve of infer repo_name
repo_name = training_args.hub_model_id
if repo_name is None:
repo_name = Path(training_args.output_dir).absolute().name
# Create repo and retrieve repo_id
api = HfApi()
repo_id = api.create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
#
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
# 'text' is found. You can easily tweak this behavior (see below).
#
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
if data_args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
datasets = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
token=model_args.token,
num_proc=data_args.preprocessing_num_workers,
trust_remote_code=model_args.trust_remote_code,
)
if "validation" not in datasets:
datasets["validation"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"train[:{data_args.validation_split_percentage}%]",
cache_dir=model_args.cache_dir,
token=model_args.token,
num_proc=data_args.preprocessing_num_workers,
trust_remote_code=model_args.trust_remote_code,
)
datasets["train"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=f"train[{data_args.validation_split_percentage}%:]",
cache_dir=model_args.cache_dir,
token=model_args.token,
num_proc=data_args.preprocessing_num_workers,
trust_remote_code=model_args.trust_remote_code,
)
else:
data_files = {}
if data_args.train_file is not None:
data_files["train"] = data_args.train_file
extension = data_args.train_file.split(".")[-1]
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.validation_file.split(".")[-1]
if extension == "txt":
extension = "text"
datasets = load_dataset(
extension,
data_files=data_files,
cache_dir=model_args.cache_dir,
token=model_args.token,
num_proc=data_args.preprocessing_num_workers,
)
if "validation" not in datasets:
datasets["validation"] = load_dataset(
extension,
data_files=data_files,
split=f"train[:{data_args.validation_split_percentage}%]",
cache_dir=model_args.cache_dir,
token=model_args.token,
num_proc=data_args.preprocessing_num_workers,
)
datasets["train"] = load_dataset(
extension,
data_files=data_files,
split=f"train[{data_args.validation_split_percentage}%:]",
cache_dir=model_args.cache_dir,
token=model_args.token,
num_proc=data_args.preprocessing_num_workers,
)
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
# https://huggingface.co/docs/datasets/loading_datasets.
# Load pretrained model and tokenizer
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.
if model_args.config_name:
config = AutoConfig.from_pretrained(
model_args.config_name,
cache_dir=model_args.cache_dir,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
elif model_args.model_name_or_path:
config = AutoConfig.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
else:
config = CONFIG_MAPPING[model_args.model_type]()
logger.warning("You are instantiating a new config instance from scratch.")
if model_args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name,
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
elif model_args.model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
else:
raise ValueError(
"You are instantiating a new tokenizer from scratch. This is not supported by this script. "
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
)
# Preprocessing the datasets.
# First we tokenize all the texts.
if training_args.do_train:
column_names = datasets["train"].column_names
else:
column_names = datasets["validation"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
if data_args.line_by_line:
# When using line_by_line, we just tokenize each nonempty line.
padding = "max_length" if data_args.pad_to_max_length else False
def tokenize_function(examples):
# Remove empty lines
examples = [line for line in examples if len(line) > 0 and not line.isspace()]
return tokenizer(
examples,
return_special_tokens_mask=True,
padding=padding,
truncation=True,
max_length=max_seq_length,
)
tokenized_datasets = datasets.map(
tokenize_function,
input_columns=[text_column_name],
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)
else:
# Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
# We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
# efficient when it receives the `special_tokens_mask`.
def tokenize_function(examples):
return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
tokenized_datasets = datasets.map(
tokenize_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=column_names,
load_from_cache_file=not data_args.overwrite_cache,
)
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
# max_seq_length.
def group_texts(examples):
# Concatenate all texts.
concatenated_examples = {k: list(chain(*examples[k])) for k in examples}
total_length = len(concatenated_examples[list(examples.keys())[0]])
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
# customize this part to your needs.
if total_length >= max_seq_length:
total_length = (total_length // max_seq_length) * max_seq_length
# Split by chunks of max_len.
result = {
k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
for k, t in concatenated_examples.items()
}
return result
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
# might be slower to preprocess.
#
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
# https://huggingface.co/docs/datasets/process#map
tokenized_datasets = tokenized_datasets.map(
group_texts,
batched=True,
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
)
# Enable tensorboard only on the master node
has_tensorboard = is_tensorboard_available()
if has_tensorboard and jax.process_index() == 0:
try:
from flax.metrics.tensorboard import SummaryWriter
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
except ImportError as ie:
has_tensorboard = False
logger.warning(
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
)
else:
logger.warning(
"Unable to display metrics through TensorBoard because the package is not installed: "
"Please run pip install tensorboard to enable."
)
# Data collator
# This one will take care of randomly masking the tokens.
data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
# Initialize our training
rng = jax.random.PRNGKey(training_args.seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())
if model_args.model_name_or_path:
model = FlaxAutoModelForMaskedLM.from_pretrained(
model_args.model_name_or_path,
config=config,
seed=training_args.seed,
dtype=getattr(jnp, model_args.dtype),
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
else:
model = FlaxAutoModelForMaskedLM.from_config(
config,
seed=training_args.seed,
dtype=getattr(jnp, model_args.dtype),
trust_remote_code=model_args.trust_remote_code,
)
if training_args.gradient_checkpointing:
model.enable_gradient_checkpointing()
# Store some constant
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
eval_batch_size = per_device_eval_batch_size * jax.device_count()
num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
# Create learning rate schedule
warmup_fn = optax.linear_schedule(
init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
)
decay_fn = optax.linear_schedule(
init_value=training_args.learning_rate,
end_value=0,
transition_steps=num_train_steps - training_args.warmup_steps,
)
linear_decay_lr_schedule_fn = optax.join_schedules(
schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
)
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
if training_args.adafactor:
# We use the default parameters here to initialize adafactor,
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
optimizer = optax.adafactor(
learning_rate=linear_decay_lr_schedule_fn,
)
else:
optimizer = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=training_args.adam_beta1,
b2=training_args.adam_beta2,
eps=training_args.adam_epsilon,
weight_decay=training_args.weight_decay,
mask=decay_mask_fn,
)
# Setup train state
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
# Define gradient update step fn
def train_step(state, batch, dropout_rng):
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
def loss_fn(params):
labels = batch.pop("labels")
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
# compute loss, ignore padded input tokens
label_mask = jnp.where(labels > 0, 1.0, 0.0)
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
# take average
loss = loss.sum()
num_labels = label_mask.sum()
return loss, num_labels
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, num_labels), grad = grad_fn(state.params)
num_labels = jax.lax.psum(num_labels, "batch")
# true loss = total loss / total samples
loss = jax.lax.psum(loss, "batch")
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
# true grad = total grad / total samples
grad = jax.lax.psum(grad, "batch")
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
new_state = state.apply_gradients(grads=grad)
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
return new_state, metrics, new_dropout_rng
# Create parallel version of the train step
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
# Define eval fn
def eval_step(params, batch):
labels = batch.pop("labels")
logits = model(**batch, params=params, train=False)[0]
# compute loss, ignore padded input tokens
label_mask = jnp.where(labels > 0, 1.0, 0.0)
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
# compute accuracy
accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
# summarize metrics
metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
metrics = jax.lax.psum(metrics, axis_name="batch")
return metrics
p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
# Replicate the train state on each device
state = jax_utils.replicate(state)
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
# ======================== Training ================================
train_start = time.time()
train_metrics = []
# Create sampling rng
rng, input_rng = jax.random.split(rng)
# Generate an epoch by shuffling sampling indices from the train dataset
num_train_samples = len(tokenized_datasets["train"])
# Avoid using jax.numpy here in case of TPU training
train_samples_idx = np.random.permutation(np.arange(num_train_samples))
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
# Gather the indexes for creating the batch and do a training step
for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples, pad_to_multiple_of=16)
# Model forward
model_inputs = shard(model_inputs.data)
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
train_metrics.append(train_metric)
cur_step = epoch * (num_train_samples // train_batch_size) + step
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
# Save metrics
train_metric = jax_utils.unreplicate(train_metric)
train_time += time.time() - train_start
if has_tensorboard and jax.process_index() == 0:
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate:"
f" {train_metric['learning_rate']})"
)
train_metrics = []
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
# ======================== Evaluating ==============================
num_eval_samples = len(tokenized_datasets["validation"])
# Avoid using jax.numpy here in case of TPU training
eval_samples_idx = np.arange(num_eval_samples)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
eval_metrics = []
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples, pad_to_multiple_of=16)
# Model forward
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
)
eval_metrics.append(metrics)
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_util.tree_map(jnp.sum, eval_metrics)
eval_normalizer = eval_metrics.pop("normalizer")
eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)
# Update progress bar
epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
# Save metrics
if has_tensorboard and jax.process_index() == 0:
write_eval_metric(summary_writer, eval_metrics, cur_step)
if cur_step % training_args.save_steps == 0 and cur_step > 0:
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
model.save_pretrained(training_args.output_dir, params=params)
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
api.upload_folder(
commit_message=f"Saving weights and logs of step {cur_step}",
folder_path=training_args.output_dir,
repo_id=repo_id,
repo_type="model",
token=training_args.hub_token,
)
# Eval after training
if training_args.do_eval:
num_eval_samples = len(tokenized_datasets["validation"])
# Avoid using jax.numpy here in case of TPU training
eval_samples_idx = np.arange(num_eval_samples)
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)
eval_metrics = []
for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
model_inputs = data_collator(samples, pad_to_multiple_of=16)
# Model forward
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
)
eval_metrics.append(metrics)
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_util.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
eval_normalizer = eval_metrics.pop("normalizer")
eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)
try:
perplexity = math.exp(eval_metrics["loss"])
except OverflowError:
perplexity = float("inf")
eval_metrics["perplexity"] = perplexity
if jax.process_index() == 0:
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
path = os.path.join(training_args.output_dir, "eval_results.json")
with open(path, "w") as f:
json.dump(eval_metrics, f, indent=4, sort_keys=True)
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

View File

@ -1,117 +0,0 @@
#!/usr/bin/env python3
import json
from collections.abc import Iterator
from typing import Union
from tokenizers import AddedToken, Regex, Tokenizer, decoders, normalizers, pre_tokenizers, trainers
from tokenizers.implementations.base_tokenizer import BaseTokenizer
from tokenizers.models import Unigram
from tokenizers.processors import TemplateProcessing
class SentencePieceUnigramTokenizer(BaseTokenizer):
"""
This class is a copy of `DeDLOC's tokenizer implementation <https://github.com/yandex-research/DeDLOC/blob/main/sahajbert/tokenizer/tokenizer_model.py>`__ .
Custom SentencePiece Unigram Tokenizer with NMT, NKFC, spaces and lower-casing characters normalization
Represents the Unigram algorithm, with the pretokenization used by SentencePiece
"""
def __init__(
self,
replacement: str = "",
add_prefix_space: bool = True,
unk_token: Union[str, AddedToken] = "<unk>",
eos_token: Union[str, AddedToken] = "</s>",
pad_token: Union[str, AddedToken] = "<pad>",
):
self.special_tokens = {
"pad": {"id": 0, "token": pad_token},
"eos": {"id": 1, "token": eos_token},
"unk": {"id": 2, "token": unk_token},
}
self.special_tokens_list = [None] * len(self.special_tokens)
for token_dict in self.special_tokens.values():
self.special_tokens_list[token_dict["id"]] = token_dict["token"]
tokenizer = Tokenizer(Unigram())
tokenizer.normalizer = normalizers.Sequence(
[
normalizers.Nmt(),
normalizers.NFKC(),
normalizers.Replace(Regex(" {2,}"), " "),
normalizers.Lowercase(),
]
)
tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
[
pre_tokenizers.Metaspace(
replacement=replacement, prepend_scheme="always" if add_prefix_space else "never"
),
pre_tokenizers.Digits(individual_digits=True),
pre_tokenizers.Punctuation(),
]
)
tokenizer.decoder = decoders.Metaspace(
replacement=replacement, prepend_scheme="always" if add_prefix_space else "never"
)
tokenizer.post_processor = TemplateProcessing(
single=f"$A {self.special_tokens['eos']['token']}",
special_tokens=[(self.special_tokens["eos"]["token"], self.special_tokens["eos"]["id"])],
)
parameters = {
"model": "SentencePieceUnigram",
"replacement": replacement,
"add_prefix_space": add_prefix_space,
}
super().__init__(tokenizer, parameters)
def train(
self,
files: Union[str, list[str]],
vocab_size: int = 8000,
show_progress: bool = True,
):
"""Train the model using the given files"""
trainer = trainers.UnigramTrainer(
vocab_size=vocab_size,
special_tokens=self.special_tokens_list,
show_progress=show_progress,
)
if isinstance(files, str):
files = [files]
self._tokenizer.train(files, trainer=trainer)
self.add_unk_id()
def train_from_iterator(
self,
iterator: Union[Iterator[str], Iterator[Iterator[str]]],
vocab_size: int = 8000,
show_progress: bool = True,
):
"""Train the model using the given iterator"""
trainer = trainers.UnigramTrainer(
vocab_size=vocab_size,
special_tokens=self.special_tokens_list,
show_progress=show_progress,
)
self._tokenizer.train_from_iterator(iterator, trainer=trainer)
self.add_unk_id()
def add_unk_id(self):
tokenizer_json = json.loads(self._tokenizer.to_str())
tokenizer_json["model"]["unk_id"] = self.special_tokens["unk"]["id"]
self._tokenizer = Tokenizer.from_str(json.dumps(tokenizer_json))

View File

@ -1,104 +0,0 @@
<!---
Copyright 2021 The Google Flax Team Authors and HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Question Answering examples
Based on the script [`run_qa.py`](https://github.com/huggingface/transformers/blob/main/examples/flax/question-answering/run_qa.py).
**Note:** This script only works with models that have a fast tokenizer (backed by the 🤗 Tokenizers library) as it
uses special features of those tokenizers. You can check if your favorite model has a fast tokenizer in
[this table](https://huggingface.co/transformers/index.html#supported-frameworks), if it doesn't you can still use the old version
of the script.
The following example fine-tunes BERT on SQuAD:
```bash
python run_qa.py \
--model_name_or_path google-bert/bert-base-uncased \
--dataset_name squad \
--do_train \
--do_eval \
--max_seq_length 384 \
--doc_stride 128 \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--per_device_train_batch_size 12 \
--output_dir ./bert-qa-squad \
--eval_steps 1000 \
--push_to_hub
```
Using the command above, the script will train for 2 epochs and run eval after each epoch.
Metrics and hyperparameters are stored in Tensorflow event files in `--output_dir`.
You can see the results by running `tensorboard` in that directory:
```bash
$ tensorboard --logdir .
```
or directly on the hub under *Training metrics*.
Training with the previously defined hyper-parameters yields the following results:
```bash
f1 = 88.62
exact_match = 81.34
```
sample Metrics - [tfhub.dev](https://tensorboard.dev/experiment/6gU75Hx8TGCnc6tr4ZgI9Q)
Here is an example training on 4 TITAN RTX GPUs and Bert Whole Word Masking uncased model to reach a F1 > 93 on SQuAD1.1:
```bash
export CUDA_VISIBLE_DEVICES=0,1,2,3
python run_qa.py \
--model_name_or_path google-bert/bert-large-uncased-whole-word-masking \
--dataset_name squad \
--do_train \
--do_eval \
--per_device_train_batch_size 6 \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--max_seq_length 384 \
--doc_stride 128 \
--output_dir ./wwm_uncased_finetuned_squad/ \
--eval_steps 1000 \
--push_to_hub
```
Training with the previously defined hyper-parameters yields the following results:
```bash
f1 = 93.31
exact_match = 87.04
```
### Usage notes
Note that when contexts are long they may be split into multiple training cases, not all of which may contain
the answer span.
As-is, the example script will train on SQuAD or any other question-answering dataset formatted the same way, and can handle user
inputs as well.
### Memory usage and data loading
One thing to note is that all data is loaded into memory in this script. Most question answering datasets are small
enough that this is not an issue, but if you have a very large dataset you will need to modify the script to handle
data streaming.

View File

@ -1,5 +0,0 @@
datasets >= 1.8.0
jax>=0.2.17
jaxlib>=0.1.68
flax>=0.3.5
optax>=0.0.8

File diff suppressed because it is too large Load Diff

View File

@ -1,443 +0,0 @@
# Copyright 2020 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Post-processing utilities for question answering.
"""
import collections
import json
import logging
import os
from typing import Optional
import numpy as np
from tqdm.auto import tqdm
logger = logging.getLogger(__name__)
def postprocess_qa_predictions(
examples,
features,
predictions: tuple[np.ndarray, np.ndarray],
version_2_with_negative: bool = False,
n_best_size: int = 20,
max_answer_length: int = 30,
null_score_diff_threshold: float = 0.0,
output_dir: Optional[str] = None,
prefix: Optional[str] = None,
log_level: Optional[int] = logging.WARNING,
):
"""
Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
original contexts. This is the base postprocessing functions for models that only return start and end logits.
Args:
examples: The non-preprocessed dataset (see the main script for more information).
features: The processed dataset (see the main script for more information).
predictions (:obj:`tuple[np.ndarray, np.ndarray]`):
The predictions of the model: two arrays containing the start logits and the end logits respectively. Its
first dimension must match the number of elements of :obj:`features`.
version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the underlying dataset contains examples with no answers.
n_best_size (:obj:`int`, `optional`, defaults to 20):
The total number of n-best predictions to generate when looking for an answer.
max_answer_length (:obj:`int`, `optional`, defaults to 30):
The maximum length of an answer that can be generated. This is needed because the start and end predictions
are not conditioned on one another.
null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0):
The threshold used to select the null answer: if the best answer has a score that is less than the score of
the null answer minus this threshold, the null answer is selected for this example (note that the score of
the null answer for an example giving several features is the minimum of the scores for the null answer on
each feature: all features must be aligned on the fact they `want` to predict a null answer).
Only useful when :obj:`version_2_with_negative` is :obj:`True`.
output_dir (:obj:`str`, `optional`):
If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if
:obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null
answers, are saved in `output_dir`.
prefix (:obj:`str`, `optional`):
If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
``logging`` log level (e.g., ``logging.WARNING``)
"""
if len(predictions) != 2:
raise ValueError("`predictions` should be a tuple with two elements (start_logits, end_logits).")
all_start_logits, all_end_logits = predictions
if len(predictions[0]) != len(features):
raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")
# Build a map example to its corresponding features.
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
features_per_example = collections.defaultdict(list)
for i, feature in enumerate(features):
features_per_example[example_id_to_index[feature["example_id"]]].append(i)
# The dictionaries we have to fill.
all_predictions = collections.OrderedDict()
all_nbest_json = collections.OrderedDict()
if version_2_with_negative:
scores_diff_json = collections.OrderedDict()
# Logging.
logger.setLevel(log_level)
logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
# Let's loop over all the examples!
for example_index, example in enumerate(tqdm(examples)):
# Those are the indices of the features associated to the current example.
feature_indices = features_per_example[example_index]
min_null_prediction = None
prelim_predictions = []
# Looping through all the features associated to the current example.
for feature_index in feature_indices:
# We grab the predictions of the model for this feature.
start_logits = all_start_logits[feature_index]
end_logits = all_end_logits[feature_index]
# This is what will allow us to map some the positions in our logits to span of texts in the original
# context.
offset_mapping = features[feature_index]["offset_mapping"]
# Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context
# available in the current feature.
token_is_max_context = features[feature_index].get("token_is_max_context", None)
# Update minimum null prediction.
feature_null_score = start_logits[0] + end_logits[0]
if min_null_prediction is None or min_null_prediction["score"] > feature_null_score:
min_null_prediction = {
"offsets": (0, 0),
"score": feature_null_score,
"start_logit": start_logits[0],
"end_logit": end_logits[0],
}
# Go through all possibilities for the `n_best_size` greater start and end logits.
start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
for start_index in start_indexes:
for end_index in end_indexes:
# Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
# to part of the input_ids that are not in the context.
if (
start_index >= len(offset_mapping)
or end_index >= len(offset_mapping)
or offset_mapping[start_index] is None
or len(offset_mapping[start_index]) < 2
or offset_mapping[end_index] is None
or len(offset_mapping[end_index]) < 2
):
continue
# Don't consider answers with a length that is either < 0 or > max_answer_length.
if end_index < start_index or end_index - start_index + 1 > max_answer_length:
continue
# Don't consider answer that don't have the maximum context available (if such information is
# provided).
if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
continue
prelim_predictions.append(
{
"offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
"score": start_logits[start_index] + end_logits[end_index],
"start_logit": start_logits[start_index],
"end_logit": end_logits[end_index],
}
)
if version_2_with_negative and min_null_prediction is not None:
# Add the minimum null prediction
prelim_predictions.append(min_null_prediction)
null_score = min_null_prediction["score"]
# Only keep the best `n_best_size` predictions.
predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]
# Add back the minimum null prediction if it was removed because of its low score.
if (
version_2_with_negative
and min_null_prediction is not None
and not any(p["offsets"] == (0, 0) for p in predictions)
):
predictions.append(min_null_prediction)
# Use the offsets to gather the answer text in the original context.
context = example["context"]
for pred in predictions:
offsets = pred.pop("offsets")
pred["text"] = context[offsets[0] : offsets[1]]
# In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
# failure.
if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""):
predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0})
# Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
# the LogSumExp trick).
scores = np.array([pred.pop("score") for pred in predictions])
exp_scores = np.exp(scores - np.max(scores))
probs = exp_scores / exp_scores.sum()
# Include the probabilities in our predictions.
for prob, pred in zip(probs, predictions):
pred["probability"] = prob
# Pick the best prediction. If the null answer is not possible, this is easy.
if not version_2_with_negative:
all_predictions[example["id"]] = predictions[0]["text"]
else:
# Otherwise we first need to find the best non-empty prediction.
i = 0
while predictions[i]["text"] == "":
i += 1
best_non_null_pred = predictions[i]
# Then we compare to the null prediction using the threshold.
score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"]
scores_diff_json[example["id"]] = float(score_diff) # To be JSON-serializable.
if score_diff > null_score_diff_threshold:
all_predictions[example["id"]] = ""
else:
all_predictions[example["id"]] = best_non_null_pred["text"]
# Make `predictions` JSON-serializable by casting np.float back to float.
all_nbest_json[example["id"]] = [
{k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
for pred in predictions
]
# If we have an output_dir, let's save all those dicts.
if output_dir is not None:
if not os.path.isdir(output_dir):
raise OSError(f"{output_dir} is not a directory.")
prediction_file = os.path.join(
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
)
nbest_file = os.path.join(
output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json"
)
if version_2_with_negative:
null_odds_file = os.path.join(
output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
)
logger.info(f"Saving predictions to {prediction_file}.")
with open(prediction_file, "w") as writer:
writer.write(json.dumps(all_predictions, indent=4) + "\n")
logger.info(f"Saving nbest_preds to {nbest_file}.")
with open(nbest_file, "w") as writer:
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
if version_2_with_negative:
logger.info(f"Saving null_odds to {null_odds_file}.")
with open(null_odds_file, "w") as writer:
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
return all_predictions
def postprocess_qa_predictions_with_beam_search(
examples,
features,
predictions: tuple[np.ndarray, np.ndarray],
version_2_with_negative: bool = False,
n_best_size: int = 20,
max_answer_length: int = 30,
start_n_top: int = 5,
end_n_top: int = 5,
output_dir: Optional[str] = None,
prefix: Optional[str] = None,
log_level: Optional[int] = logging.WARNING,
):
"""
Post-processes the predictions of a question-answering model with beam search to convert them to answers that are substrings of the
original contexts. This is the postprocessing functions for models that return start and end logits, indices, as well as
cls token predictions.
Args:
examples: The non-preprocessed dataset (see the main script for more information).
features: The processed dataset (see the main script for more information).
predictions (:obj:`tuple[np.ndarray, np.ndarray]`):
The predictions of the model: two arrays containing the start logits and the end logits respectively. Its
first dimension must match the number of elements of :obj:`features`.
version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the underlying dataset contains examples with no answers.
n_best_size (:obj:`int`, `optional`, defaults to 20):
The total number of n-best predictions to generate when looking for an answer.
max_answer_length (:obj:`int`, `optional`, defaults to 30):
The maximum length of an answer that can be generated. This is needed because the start and end predictions
are not conditioned on one another.
start_n_top (:obj:`int`, `optional`, defaults to 5):
The number of top start logits too keep when searching for the :obj:`n_best_size` predictions.
end_n_top (:obj:`int`, `optional`, defaults to 5):
The number of top end logits too keep when searching for the :obj:`n_best_size` predictions.
output_dir (:obj:`str`, `optional`):
If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if
:obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null
answers, are saved in `output_dir`.
prefix (:obj:`str`, `optional`):
If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
``logging`` log level (e.g., ``logging.WARNING``)
"""
if len(predictions) != 5:
raise ValueError("`predictions` should be a tuple with five elements.")
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions
if len(predictions[0]) != len(features):
raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")
# Build a map example to its corresponding features.
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
features_per_example = collections.defaultdict(list)
for i, feature in enumerate(features):
features_per_example[example_id_to_index[feature["example_id"]]].append(i)
# The dictionaries we have to fill.
all_predictions = collections.OrderedDict()
all_nbest_json = collections.OrderedDict()
scores_diff_json = collections.OrderedDict() if version_2_with_negative else None
# Logging.
logger.setLevel(log_level)
logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
# Let's loop over all the examples!
for example_index, example in enumerate(tqdm(examples)):
# Those are the indices of the features associated to the current example.
feature_indices = features_per_example[example_index]
min_null_score = None
prelim_predictions = []
# Looping through all the features associated to the current example.
for feature_index in feature_indices:
# We grab the predictions of the model for this feature.
start_log_prob = start_top_log_probs[feature_index]
start_indexes = start_top_index[feature_index]
end_log_prob = end_top_log_probs[feature_index]
end_indexes = end_top_index[feature_index]
feature_null_score = cls_logits[feature_index]
# This is what will allow us to map some the positions in our logits to span of texts in the original
# context.
offset_mapping = features[feature_index]["offset_mapping"]
# Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context
# available in the current feature.
token_is_max_context = features[feature_index].get("token_is_max_context", None)
# Update minimum null prediction
if min_null_score is None or feature_null_score < min_null_score:
min_null_score = feature_null_score
# Go through all possibilities for the `n_start_top`/`n_end_top` greater start and end logits.
for i in range(start_n_top):
for j in range(end_n_top):
start_index = int(start_indexes[i])
j_index = i * end_n_top + j
end_index = int(end_indexes[j_index])
# Don't consider out-of-scope answers (last part of the test should be unnecessary because of the
# p_mask but let's not take any risk)
if (
start_index >= len(offset_mapping)
or end_index >= len(offset_mapping)
or offset_mapping[start_index] is None
or len(offset_mapping[start_index]) < 2
or offset_mapping[end_index] is None
or len(offset_mapping[end_index]) < 2
):
continue
# Don't consider answers with a length negative or > max_answer_length.
if end_index < start_index or end_index - start_index + 1 > max_answer_length:
continue
# Don't consider answer that don't have the maximum context available (if such information is
# provided).
if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
continue
prelim_predictions.append(
{
"offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
"score": start_log_prob[i] + end_log_prob[j_index],
"start_log_prob": start_log_prob[i],
"end_log_prob": end_log_prob[j_index],
}
)
# Only keep the best `n_best_size` predictions.
predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]
# Use the offsets to gather the answer text in the original context.
context = example["context"]
for pred in predictions:
offsets = pred.pop("offsets")
pred["text"] = context[offsets[0] : offsets[1]]
# In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
# failure.
if len(predictions) == 0:
# Without predictions min_null_score is going to be None and None will cause an exception later
min_null_score = -2e-6
predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": min_null_score})
# Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
# the LogSumExp trick).
scores = np.array([pred.pop("score") for pred in predictions])
exp_scores = np.exp(scores - np.max(scores))
probs = exp_scores / exp_scores.sum()
# Include the probabilities in our predictions.
for prob, pred in zip(probs, predictions):
pred["probability"] = prob
# Pick the best prediction and set the probability for the null answer.
all_predictions[example["id"]] = predictions[0]["text"]
if version_2_with_negative:
scores_diff_json[example["id"]] = float(min_null_score)
# Make `predictions` JSON-serializable by casting np.float back to float.
all_nbest_json[example["id"]] = [
{k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
for pred in predictions
]
# If we have an output_dir, let's save all those dicts.
if output_dir is not None:
if not os.path.isdir(output_dir):
raise OSError(f"{output_dir} is not a directory.")
prediction_file = os.path.join(
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
)
nbest_file = os.path.join(
output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json"
)
if version_2_with_negative:
null_odds_file = os.path.join(
output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
)
logger.info(f"Saving predictions to {prediction_file}.")
with open(prediction_file, "w") as writer:
writer.write(json.dumps(all_predictions, indent=4) + "\n")
logger.info(f"Saving nbest_preds to {nbest_file}.")
with open(nbest_file, "w") as writer:
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
if version_2_with_negative:
logger.info(f"Saving null_odds to {null_odds_file}.")
with open(null_odds_file, "w") as writer:
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
return all_predictions, scores_diff_json

View File

@ -1,68 +0,0 @@
<!---
Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Automatic Speech Recognition - Flax Examples
## Sequence to Sequence
The script [`run_flax_speech_recognition_seq2seq.py`](https://github.com/huggingface/transformers/blob/main/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py)
can be used to fine-tune any [Flax Speech Sequence-to-Sequence Model](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.FlaxAutoModelForSpeechSeq2Seq)
for automatic speech recognition on one of the [official speech recognition datasets](https://huggingface.co/datasets?task_ids=task_ids:automatic-speech-recognition)
or a custom dataset. This includes the Whisper model from OpenAI, or a warm-started Speech-Encoder-Decoder Model,
an example for which is included below.
### Whisper Model
We can load all components of the Whisper model directly from the pretrained checkpoint, including the pretrained model
weights, feature extractor and tokenizer. We simply have to specify the id of fine-tuning dataset and the necessary
training hyperparameters.
The following example shows how to fine-tune the [Whisper small](https://huggingface.co/openai/whisper-small) checkpoint
on the Hindi subset of the [Common Voice 13](https://huggingface.co/datasets/mozilla-foundation/common_voice_13_0) dataset.
Note that before running this script you must accept the dataset's [terms of use](https://huggingface.co/datasets/mozilla-foundation/common_voice_13_0)
and register your Hugging Face Hub token on your device by running `huggingface-hub login`.
```bash
python run_flax_speech_recognition_seq2seq.py \
--model_name_or_path="openai/whisper-small" \
--dataset_name="mozilla-foundation/common_voice_13_0" \
--dataset_config_name="hi" \
--language="hindi" \
--train_split_name="train+validation" \
--eval_split_name="test" \
--output_dir="./whisper-small-hi-flax" \
--per_device_train_batch_size="16" \
--per_device_eval_batch_size="16" \
--num_train_epochs="10" \
--learning_rate="1e-4" \
--warmup_steps="500" \
--logging_steps="25" \
--generation_max_length="40" \
--preprocessing_num_workers="32" \
--dataloader_num_workers="32" \
--max_duration_in_seconds="30" \
--text_column_name="sentence" \
--overwrite_output_dir \
--do_train \
--do_eval \
--predict_with_generate \
--push_to_hub \
--use_auth_token
```
On a TPU v4-8, training should take approximately 25 minutes, with a final cross-entropy loss of 0.02 and word error
rate of **34%**. See the checkpoint [sanchit-gandhi/whisper-small-hi-flax](https://huggingface.co/sanchit-gandhi/whisper-small-hi-flax)
for an example training run.

View File

@ -1,8 +0,0 @@
datasets[audio]>=2.14.0
jax>=0.3.6
jaxlib>=0.3.6
flax>=0.4.1
optax>=0.0.8
torch>=1.9.0
jiwer
evaluate

View File

@ -1,877 +0,0 @@
#!/usr/bin/env python
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the Flax library models for sequence to sequence speech recognition.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
import logging
import os
import sys
import time
from dataclasses import field
from functools import partial
from pathlib import Path
from typing import Any, Callable, Optional, Union
import datasets
import evaluate
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from datasets import DatasetDict, load_dataset
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import HfApi
from torch.utils.data import DataLoader
from tqdm import tqdm
import transformers
from transformers import (
AutoConfig,
AutoFeatureExtractor,
AutoProcessor,
AutoTokenizer,
FlaxAutoModelForSpeechSeq2Seq,
HfArgumentParser,
Seq2SeqTrainingArguments,
is_tensorboard_available,
)
from transformers.file_utils import get_full_repo_name
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
check_min_version("4.56.0.dev0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/flax/speech-recognition/requirements.txt")
logger = logging.getLogger(__name__)
@flax.struct.dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
feature_extractor_name: Optional[str] = field(
default=None, metadata={"help": "feature extractor name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
)
use_fast_tokenizer: bool = field(
default=True,
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
use_auth_token: bool = field(
default=False,
metadata={
"help": "Will use the token generated when running `transformers login` (necessary to use this script "
"with private models)."
},
)
dtype: Optional[str] = field(
default="float32",
metadata={
"help": (
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
" `[float32, float16, bfloat16]`."
)
},
)
num_beams: Optional[int] = field(
default=None,
metadata={
"help": (
"Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
"which is used during evaluation."
)
},
)
@flax.struct.dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
dataset_name: str = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
trust_remote_code: bool = field(
default=False,
metadata={
"help": (
"Whether to trust the execution of code from datasets/models defined on the Hub."
" This option should only be set to `True` for repositories you trust and in which you have read the"
" code, as it will execute code present on the Hub on your local machine."
)
},
)
text_column: Optional[str] = field(
default=None,
metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
)
dataset_cache_dir: Optional[str] = field(
default=None, metadata={"help": "Path to cache directory for saving and loading datasets"}
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
},
)
audio_column_name: str = field(
default="audio",
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
)
text_column_name: str = field(
default="text",
metadata={"help": "The name of the dataset column containing the text data. Defaults to 'text'"},
)
max_duration_in_seconds: float = field(
default=20.0,
metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"},
)
min_duration_in_seconds: float = field(
default=0.0,
metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"},
)
max_label_length: float = field(
default=128,
metadata={"help": "Truncate transcriptions that are longer `max_eval_length` tokens."},
)
pad_input_to_multiple_of: Optional[int] = field(
default=None,
metadata={
"help": "If set will pad the input sequence to a multiple of the provided value. "
"This is important to avoid triggering recompilations on TPU. If unspecified, will default to padding the inputs to max length."
},
)
pad_target_to_multiple_of: Optional[int] = field(
default=None,
metadata={
"help": "If set will pad the target sequence to a multiple of the provided value. "
"This is important to avoid triggering recompilations on TPU. If unspecified, will default to padding the targets to max length."
},
)
preprocessing_only: bool = field(
default=False,
metadata={
"help": "Whether to only do data preprocessing and skip training. "
"This is especially useful when data preprocessing errors out in distributed training due to timeout. "
"In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
"so that the cached datasets can consequently be loaded in distributed training"
},
)
train_split_name: str = field(
default="train",
metadata={
"help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'"
},
)
eval_split_name: str = field(
default="validation",
metadata={
"help": "The name of the evaluation data set split to use (via the datasets library). Defaults to 'validation'"
},
)
do_lower_case: bool = field(
default=True,
metadata={"help": "Whether the target text should be lower cased."},
)
language: str = field(
default=None,
metadata={
"help": (
"Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
"only. For English speech recognition, it should be set to `None`."
)
},
)
task: str = field(
default="transcribe",
metadata={"help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."},
)
def shift_tokens_right(label_ids: np.array, decoder_start_token_id: int) -> np.ndarray:
"""
Shift label ids one token to the right.
"""
shifted_label_ids = np.zeros_like(label_ids)
shifted_label_ids[:, 1:] = label_ids[:, :-1]
shifted_label_ids[:, 0] = decoder_start_token_id
return shifted_label_ids
@flax.struct.dataclass
class FlaxDataCollatorSpeechSeq2SeqWithPadding:
"""
Data collator that will dynamically pad the inputs received.
Args:
processor ([`Wav2Vec2Processor`])
The processor used for processing the data.
decoder_start_token_id (:obj: `int`)
The begin-of-sentence of the decoder.
input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
See above for details.
max_input_length (:obj:`float`, `optional`):
Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
max_target_length (:obj:`int`, `optional`):
Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
pad_input_to_multiple_of (:obj:`int`, `optional`):
If set will pad the input sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
pad_target_to_multiple_of (:obj:`int`, `optional`):
If set will pad the target sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
"""
processor: Any
decoder_start_token_id: int
input_padding: Union[bool, str] = "longest"
target_padding: Union[bool, str] = "max_length"
max_input_length: Optional[float] = None
max_target_length: Optional[int] = None
pad_input_to_multiple_of: Optional[int] = None
pad_target_to_multiple_of: Optional[int] = None
def __call__(self, features: list[dict[str, Union[list[int], np.ndarray]]]) -> dict[str, np.ndarray]:
# split inputs and labels since they have to be of different lengths and need
# different padding methods
model_input_name = self.processor.model_input_names[0]
# dataloader returns a list of features which we convert to a dict
input_features = {model_input_name: [feature[model_input_name] for feature in features]}
label_features = {"input_ids": [feature["labels"] for feature in features]}
# reformat list to dict and set to pytorch format
batch = self.processor.feature_extractor.pad(
input_features,
max_length=self.max_input_length,
padding=self.input_padding,
pad_to_multiple_of=self.pad_input_to_multiple_of,
return_tensors="np",
)
labels_batch = self.processor.tokenizer.pad(
label_features,
max_length=self.max_target_length,
padding=self.target_padding,
pad_to_multiple_of=self.pad_target_to_multiple_of,
return_tensors="np",
)
# if bos token is appended in previous tokenization step,
# cut bos token here as it's append later anyways
labels = labels_batch["input_ids"]
if (labels[:, 0] == self.decoder_start_token_id).all().item():
labels = labels[:, 1:]
labels_batch.attention_mask = labels_batch.attention_mask[:, 1:]
decoder_input_ids = shift_tokens_right(labels, self.decoder_start_token_id)
# replace padding with -100 to ignore correctly when computing the loss
labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
labels = labels.filled(fill_value=-100)
batch["labels"] = labels
batch["decoder_input_ids"] = decoder_input_ids
return batch
class TrainState(train_state.TrainState):
dropout_rng: jnp.ndarray
def replicate(self):
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
summary_writer.scalar("train_time", train_time, step)
train_metrics = get_metrics(train_metrics)
for key, vals in train_metrics.items():
tag = f"train_{key}"
for i, val in enumerate(vals):
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
for metric_name, value in eval_metrics.items():
summary_writer.scalar(f"eval_{metric_name}", value, step)
def create_learning_rate_fn(
num_train_steps: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
return schedule_fn
def main():
# 1. Parse input arguments
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your JAX/Flax versions.
send_example_telemetry("run_speech_recognition_seq2seq", model_args, data_args, framework="flax")
# 2. Setup logging
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
# Set the verbosity to info of the Transformers logger.
# We only want one process per machine to log things on the screen.
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
if jax.process_index() == 0:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
logger.info("Training/evaluation parameters %s", training_args)
# Check the output dir is valid
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use `--overwrite_output_dir` to overcome."
)
# Handle the repository creation
if training_args.push_to_hub:
if training_args.hub_model_id is None:
repo_name = get_full_repo_name(
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
)
else:
repo_name = training_args.hub_model_id
# Create repo and retrieve repo_id
api = HfApi()
repo_id = api.create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
# 3. Load dataset
raw_datasets = DatasetDict()
if training_args.do_train:
raw_datasets["train"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=data_args.train_split_name,
cache_dir=data_args.dataset_cache_dir,
num_proc=data_args.preprocessing_num_workers,
token=True if model_args.use_auth_token else None,
trust_remote_code=data_args.trust_remote_code,
)
if training_args.do_eval:
raw_datasets["eval"] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=data_args.eval_split_name,
cache_dir=data_args.dataset_cache_dir,
num_proc=data_args.preprocessing_num_workers,
token=True if model_args.use_auth_token else None,
trust_remote_code=data_args.trust_remote_code,
)
if not training_args.do_train and not training_args.do_eval:
raise ValueError(
"Cannot not train and not do evaluation. At least one of training or evaluation has to be performed."
)
if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
raise ValueError(
f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. "
"Make sure to set `--audio_column_name` to the correct audio column - one of "
f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
)
if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
raise ValueError(
f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. "
"Make sure to set `--text_column_name` to the correct text column - one of "
f"{', '.join(next(iter(raw_datasets.values())).column_names)}."
)
# 5. Load pretrained model, tokenizer, and feature extractor
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
token=True if model_args.use_auth_token else None,
)
feature_extractor = AutoFeatureExtractor.from_pretrained(
model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
token=True if model_args.use_auth_token else None,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
use_fast=model_args.use_fast_tokenizer,
revision=model_args.model_revision,
token=True if model_args.use_auth_token else None,
)
model = FlaxAutoModelForSpeechSeq2Seq.from_pretrained(
model_args.model_name_or_path,
config=config,
dtype=getattr(jnp, model_args.dtype),
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
token=True if model_args.use_auth_token else None,
)
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
# 6. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
# so we just need to set the correct target sampling rate.
raw_datasets = raw_datasets.cast_column(
data_args.audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
)
# 7. Preprocessing the datasets.
# We need to read the audio files as arrays and tokenize the targets.
max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
max_label_length = (
data_args.max_label_length if data_args.max_label_length is not None else model.config.max_length
)
pad_input_to_multiple_of = data_args.pad_input_to_multiple_of
pad_target_to_multiple_of = data_args.pad_target_to_multiple_of
audio_column_name = data_args.audio_column_name
num_workers = data_args.preprocessing_num_workers
text_column_name = data_args.text_column_name
model_input_name = feature_extractor.model_input_names[0]
do_lower_case = data_args.do_lower_case
if training_args.do_train and data_args.max_train_samples is not None:
raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
if training_args.do_eval and data_args.max_eval_samples is not None:
raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
if data_args.language is not None:
# We only need to set the task id when the language is specified (i.e. in a multilingual setting)
tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
def prepare_dataset(batch):
# process audio
sample = batch[audio_column_name]
inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
# process audio length
batch[model_input_name] = inputs.get(model_input_name)[0]
batch["input_length"] = len(sample["array"])
# process targets
input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
batch["labels"] = tokenizer(input_str).input_ids
return batch
vectorized_datasets = raw_datasets.map(
prepare_dataset,
remove_columns=next(iter(raw_datasets.values())).column_names,
num_proc=num_workers,
desc="preprocess train and eval dataset",
)
# filter training data with inputs longer than max_input_length
def is_audio_in_length_range(length):
return min_input_length < length < max_input_length
vectorized_datasets = vectorized_datasets.filter(
is_audio_in_length_range,
num_proc=num_workers,
input_columns=["input_length"],
)
# for large datasets it is advised to run the preprocessing on a
# single machine first with `args.preprocessing_only` since there will mostly likely
# be a timeout when running the script in distributed mode.
# In a second step `args.preprocessing_only` can then be set to `False` to load the
# cached dataset
if data_args.preprocessing_only:
cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
logger.info(f"Data preprocessing finished. Files cached at {cache}.")
return
# 8. Load Metric
metric = evaluate.load("wer", cache_dir=model_args.cache_dir)
def compute_metrics(preds, labels):
# replace padded labels by the padding token
for idx in range(len(labels)):
labels[idx][labels[idx] == -100] = tokenizer.pad_token_id
pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True)
# we do not want to group tokens when computing the metrics
label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
wer = metric.compute(predictions=pred_str, references=label_str)
return {"wer": wer}
# 9. Save feature extractor, tokenizer and config
feature_extractor.save_pretrained(training_args.output_dir)
tokenizer.save_pretrained(training_args.output_dir)
config.save_pretrained(training_args.output_dir)
processor = AutoProcessor.from_pretrained(training_args.output_dir)
data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
processor=processor,
decoder_start_token_id=model.config.decoder_start_token_id,
input_padding="longest",
target_padding="longest",
max_target_length=max_label_length,
pad_input_to_multiple_of=pad_input_to_multiple_of,
pad_target_to_multiple_of=pad_target_to_multiple_of if pad_target_to_multiple_of else max_label_length,
)
# Enable tensorboard only on the master node
has_tensorboard = is_tensorboard_available()
if has_tensorboard and jax.process_index() == 0:
try:
from flax.metrics.tensorboard import SummaryWriter
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
except ImportError as ie:
has_tensorboard = False
logger.warning(
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
)
else:
logger.warning(
"Unable to display metrics through TensorBoard because the package is not installed: "
"Please run pip install tensorboard to enable."
)
# Initialize our training
rng = jax.random.PRNGKey(training_args.seed)
rng, dropout_rng = jax.random.split(rng)
# Store some constant
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(vectorized_datasets["train"]) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
total_train_steps,
training_args.warmup_steps,
training_args.learning_rate,
)
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layer_norm", "self_attn_layer_norm", "final_layer_norm", "encoder_attn_layer_norm"]
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=training_args.adam_beta1,
b2=training_args.adam_beta2,
eps=training_args.adam_epsilon,
weight_decay=training_args.weight_decay,
mask=decay_mask_fn,
)
# Setup train state
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
# label smoothed cross entropy
def loss_fn(logits, labels, label_smoothing_factor=0.0):
"""
The label smoothing implementation is adapted from Flax's official example:
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
"""
vocab_size = logits.shape[-1]
confidence = 1.0 - label_smoothing_factor
low_confidence = (1.0 - confidence) / (vocab_size - 1)
normalizing_constant = -(
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
)
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
loss = optax.softmax_cross_entropy(logits, soft_labels)
loss = loss - normalizing_constant
# ignore padded tokens from loss, i.e. where labels are not set to -100
padding_mask = labels >= 0
loss = loss * padding_mask
loss = loss.sum()
num_labels = padding_mask.sum()
return loss, num_labels
# Define gradient update step fn
def train_step(state, batch, label_smoothing_factor=0.0):
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
def compute_loss(params):
labels = batch.pop("labels")
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
loss, num_labels = loss_fn(logits, labels, label_smoothing_factor)
return loss, num_labels
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
(loss, num_labels), grad = grad_fn(state.params)
num_labels = jax.lax.psum(num_labels, "batch")
# true loss = total loss / total samples
loss = jax.lax.psum(loss, "batch")
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
# true grad = total grad / total samples
grad = jax.lax.psum(grad, "batch")
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
return new_state, metrics
# Define eval fn
def eval_step(params, batch, label_smoothing_factor=0.0):
labels = batch.pop("labels")
logits = model(**batch, params=params, train=False)[0]
loss, num_labels = loss_fn(logits, labels, label_smoothing_factor)
num_labels = jax.lax.psum(num_labels, "batch")
# true loss = total loss / total samples
loss = jax.lax.psum(loss, "batch")
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
metrics = {"loss": loss}
return metrics
# Define generation function
num_beams = model_args.num_beams if model_args.num_beams is not None else model.config.num_beams
gen_kwargs = {"max_length": max_label_length, "num_beams": num_beams}
def generate_step(params, batch):
model.params = params
output_ids = model.generate(batch[model_input_name], attention_mask=batch.get("attention_mask"), **gen_kwargs)
return output_ids.sequences
# Create parallel version of the train and eval step
p_train_step = jax.pmap(
partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch", donate_argnums=(0,)
)
p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
p_generate_step = jax.pmap(generate_step, "batch")
# Replicate the train state on each device
state = state.replicate()
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(vectorized_datasets['train'])}")
logger.info(f" Num Epochs = {num_epochs}")
logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
logger.info(f" Total optimization steps = {total_train_steps}")
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
# ======================== Training ================================
train_start = time.time()
train_metrics = []
# Generate an epoch by shuffling sampling indices from the train dataset and create a data loader
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
train_loader = DataLoader(
vectorized_datasets["train"],
batch_size=train_batch_size,
drop_last=True,
collate_fn=data_collator,
num_workers=training_args.dataloader_num_workers,
)
# train
for batch in tqdm(train_loader, desc="Training...", position=1, leave=False):
batch = shard(batch.data)
state, train_metric = p_train_step(state, batch)
train_metrics.append(train_metric)
train_time += time.time() - train_start
train_metric = unreplicate(train_metric)
epochs.write(
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
f" {train_metric['learning_rate']})"
)
# ======================== Evaluating ==============================
eval_metrics = []
eval_preds = []
eval_labels = []
eval_loader = DataLoader(
vectorized_datasets["eval"],
batch_size=eval_batch_size,
drop_last=False,
collate_fn=data_collator,
num_workers=training_args.dataloader_num_workers,
)
for batch in tqdm(eval_loader, desc="Evaluating...", position=2, leave=False):
# Model forward
labels = batch["labels"]
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
state.params, batch.data, min_device_batch=per_device_eval_batch_size
)
eval_metrics.append(metrics)
# generation
if training_args.predict_with_generate:
generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch.data)
eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
eval_labels.extend(labels)
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
# compute WER metric
wer_desc = ""
if training_args.predict_with_generate:
wer_metric = compute_metrics(eval_preds, eval_labels)
eval_metrics.update(wer_metric)
wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])
# Print metrics and update progress bar
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {wer_desc})"
epochs.write(desc)
epochs.desc = desc
# Save metrics
if has_tensorboard and jax.process_index() == 0:
cur_step = epoch * (len(vectorized_datasets["train"]) // train_batch_size)
write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
model.save_pretrained(training_args.output_dir, params=params)
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
api.upload_folder(
commit_message=f"Saving weights and logs of epoch {epoch}",
folder_path=training_args.output_dir,
repo_id=repo_id,
repo_type="model",
token=training_args.hub_token,
)
if __name__ == "__main__":
main()

View File

@ -1,35 +0,0 @@
# Summarization (Seq2Seq model) training examples
The following example showcases how to finetune a sequence-to-sequence model for summarization
using the JAX/Flax backend.
JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
Models written in JAX/Flax are **immutable** and updated in a purely functional
way which enables simple and efficient model parallelism.
`run_summarization_flax.py` is a lightweight example of how to download and preprocess a dataset from the 🤗 Datasets library or use your own files (jsonlines or csv), then fine-tune one of the architectures above on it.
For custom datasets in `jsonlines` format please see: https://huggingface.co/docs/datasets/loading_datasets#json-files and you also will find examples of these below.
### Train the model
Next we can run the example script to train the model:
```bash
python run_summarization_flax.py \
--output_dir ./bart-base-xsum \
--model_name_or_path facebook/bart-base \
--tokenizer_name facebook/bart-base \
--dataset_name="xsum" \
--do_train --do_eval --do_predict --predict_with_generate \
--num_train_epochs 6 \
--learning_rate 5e-5 --warmup_steps 0 \
--per_device_train_batch_size 64 \
--per_device_eval_batch_size 64 \
--overwrite_output_dir \
--max_source_length 512 --max_target_length 64 \
--push_to_hub
```
This should finish in 37min, with validation loss and ROUGE2 score of 1.7785 and 17.01 respectively after 6 epochs. training statistics can be accessed on [tfhub.dev](https://tensorboard.dev/experiment/OcPfOIgXRMSJqYB4RdK2tA/#scalars).
> Note that here we used default `generate` arguments, using arguments specific for `xsum` dataset should give better ROUGE scores.

View File

@ -1,6 +0,0 @@
datasets >= 1.1.3
jax>=0.2.8
jaxlib>=0.1.59
flax>=0.3.5
optax>=0.0.8
evaluate>=0.2.0

File diff suppressed because it is too large Load Diff

View File

@ -1,284 +0,0 @@
# Copyright 2021 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import logging
import os
import sys
from unittest.mock import patch
from transformers.testing_utils import TestCasePlus, get_gpu_count, slow
SRC_DIRS = [
os.path.join(os.path.dirname(__file__), dirname)
for dirname in [
"text-classification",
"language-modeling",
"summarization",
"token-classification",
"question-answering",
"speech-recognition",
]
]
sys.path.extend(SRC_DIRS)
if SRC_DIRS is not None:
import run_clm_flax
import run_flax_glue
import run_flax_ner
import run_flax_speech_recognition_seq2seq
import run_mlm_flax
import run_qa
import run_summarization_flax
import run_t5_mlm_flax
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
def get_setup_file():
parser = argparse.ArgumentParser()
parser.add_argument("-f")
args = parser.parse_args()
return args.f
def get_results(output_dir, split="eval"):
path = os.path.join(output_dir, f"{split}_results.json")
if os.path.exists(path):
with open(path) as f:
return json.load(f)
raise ValueError(f"can't find {path}")
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class ExamplesTests(TestCasePlus):
def test_run_glue(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_glue.py
--model_name_or_path distilbert/distilbert-base-uncased
--output_dir {tmp_dir}
--train_file ./tests/fixtures/tests_samples/MRPC/train.csv
--validation_file ./tests/fixtures/tests_samples/MRPC/dev.csv
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
--learning_rate=1e-4
--eval_steps=2
--warmup_steps=2
--seed=42
--max_seq_length=128
""".split()
with patch.object(sys, "argv", testargs):
run_flax_glue.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
@slow
def test_run_clm(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_clm_flax.py
--model_name_or_path distilbert/distilgpt2
--train_file ./tests/fixtures/sample_text.txt
--validation_file ./tests/fixtures/sample_text.txt
--do_train
--do_eval
--block_size 128
--per_device_train_batch_size 4
--per_device_eval_batch_size 4
--num_train_epochs 2
--logging_steps 2 --eval_steps 2
--output_dir {tmp_dir}
--overwrite_output_dir
""".split()
with patch.object(sys, "argv", testargs):
run_clm_flax.main()
result = get_results(tmp_dir)
self.assertLess(result["eval_perplexity"], 100)
@slow
def test_run_summarization(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_summarization.py
--model_name_or_path google-t5/t5-small
--train_file tests/fixtures/tests_samples/xsum/sample.json
--validation_file tests/fixtures/tests_samples/xsum/sample.json
--test_file tests/fixtures/tests_samples/xsum/sample.json
--output_dir {tmp_dir}
--overwrite_output_dir
--num_train_epochs=3
--warmup_steps=8
--do_train
--do_eval
--do_predict
--learning_rate=2e-4
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
--predict_with_generate
""".split()
with patch.object(sys, "argv", testargs):
run_summarization_flax.main()
result = get_results(tmp_dir, split="test")
self.assertGreaterEqual(result["test_rouge1"], 10)
self.assertGreaterEqual(result["test_rouge2"], 2)
self.assertGreaterEqual(result["test_rougeL"], 7)
self.assertGreaterEqual(result["test_rougeLsum"], 7)
@slow
def test_run_mlm(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_mlm.py
--model_name_or_path distilbert/distilroberta-base
--train_file ./tests/fixtures/sample_text.txt
--validation_file ./tests/fixtures/sample_text.txt
--output_dir {tmp_dir}
--overwrite_output_dir
--max_seq_length 128
--per_device_train_batch_size 4
--per_device_eval_batch_size 4
--logging_steps 2 --eval_steps 2
--do_train
--do_eval
--num_train_epochs=1
""".split()
with patch.object(sys, "argv", testargs):
run_mlm_flax.main()
result = get_results(tmp_dir)
self.assertLess(result["eval_perplexity"], 42)
@slow
def test_run_t5_mlm(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_t5_mlm_flax.py
--model_name_or_path google-t5/t5-small
--train_file ./tests/fixtures/sample_text.txt
--validation_file ./tests/fixtures/sample_text.txt
--do_train
--do_eval
--max_seq_length 128
--per_device_train_batch_size 4
--per_device_eval_batch_size 4
--num_train_epochs 2
--logging_steps 2 --eval_steps 2
--output_dir {tmp_dir}
--overwrite_output_dir
""".split()
with patch.object(sys, "argv", testargs):
run_t5_mlm_flax.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_accuracy"], 0.42)
@slow
def test_run_ner(self):
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
epochs = 7 if get_gpu_count() > 1 else 2
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_flax_ner.py
--model_name_or_path google-bert/bert-base-uncased
--train_file tests/fixtures/tests_samples/conll/sample.json
--validation_file tests/fixtures/tests_samples/conll/sample.json
--output_dir {tmp_dir}
--overwrite_output_dir
--do_train
--do_eval
--warmup_steps=2
--learning_rate=2e-4
--logging_steps 2 --eval_steps 2
--per_device_train_batch_size=2
--per_device_eval_batch_size=2
--num_train_epochs={epochs}
--seed 7
""".split()
with patch.object(sys, "argv", testargs):
run_flax_ner.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_accuracy"], 0.75)
self.assertGreaterEqual(result["eval_f1"], 0.3)
@slow
def test_run_qa(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_qa.py
--model_name_or_path google-bert/bert-base-uncased
--version_2_with_negative
--train_file tests/fixtures/tests_samples/SQUAD/sample.json
--validation_file tests/fixtures/tests_samples/SQUAD/sample.json
--output_dir {tmp_dir}
--overwrite_output_dir
--num_train_epochs=3
--warmup_steps=2
--do_train
--do_eval
--logging_steps 2 --eval_steps 2
--learning_rate=2e-4
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
""".split()
with patch.object(sys, "argv", testargs):
run_qa.main()
result = get_results(tmp_dir)
self.assertGreaterEqual(result["eval_f1"], 30)
self.assertGreaterEqual(result["eval_exact"], 30)
@slow
def test_run_flax_speech_recognition_seq2seq(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_flax_speech_recognition_seq2seq.py
--model_name_or_path openai/whisper-tiny.en
--dataset_name hf-internal-testing/librispeech_asr_dummy
--dataset_config clean
--train_split_name validation
--eval_split_name validation
--output_dir {tmp_dir}
--overwrite_output_dir
--num_train_epochs=2
--max_train_samples 10
--max_eval_samples 10
--warmup_steps=8
--do_train
--do_eval
--learning_rate=2e-4
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
--predict_with_generate
""".split()
with patch.object(sys, "argv", testargs):
run_flax_speech_recognition_seq2seq.main()
result = get_results(tmp_dir, split="eval")
self.assertLessEqual(result["eval_wer"], 0.05)

View File

@ -1,108 +0,0 @@
<!---
Copyright 2021 The Google Flax Team Authors and HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Text classification examples
## GLUE tasks
Based on the script [`run_flax_glue.py`](https://github.com/huggingface/transformers/blob/main/examples/flax/text-classification/run_flax_glue.py).
Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models) and can also be used for a
dataset hosted on our [hub](https://huggingface.co/datasets) or your own data in a csv or a JSON file (the script might need some tweaks in that case,
refer to the comments inside for help).
GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them:
```bash
export TASK_NAME=mrpc
python run_flax_glue.py \
--model_name_or_path google-bert/bert-base-cased \
--task_name ${TASK_NAME} \
--max_seq_length 128 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--per_device_train_batch_size 4 \
--eval_steps 100 \
--output_dir ./$TASK_NAME/ \
--push_to_hub
```
where task name can be one of cola, mnli, mnli_mismatched, mnli_matched, mrpc, qnli, qqp, rte, sst2, stsb, wnli.
Using the command above, the script will train for 3 epochs and run eval after each epoch.
Metrics and hyperparameters are stored in Tensorflow event files in `--output_dir`.
You can see the results by running `tensorboard` in that directory:
```bash
$ tensorboard --logdir .
```
or directly on the hub under *Training metrics*.
### Accuracy Evaluation
We train five replicas and report mean accuracy and stdev on the dev set below.
We use the settings as in the command above (with an exception for MRPC and
WNLI which are tiny and where we used 5 epochs instead of 3), and we use a total
train batch size of 32 (we train on 8 Cloud v3 TPUs, so a per-device batch size of 4),
On the task other than MRPC and WNLI we train for 3 these epochs because this is the standard,
but looking at the training curves of some of them (e.g., SST-2, STS-b), it appears the models
are undertrained and we could get better results when training longer.
In the Tensorboard results linked below, the random seed of each model is equal to the ID of the run. So in order to reproduce run 1, run the command above with `--seed=1`. The best run used random seed 3, which is the default in the script. The results of all runs are in [this Google Sheet](https://docs.google.com/spreadsheets/d/1p3XzReMO75m_XdEJvPue-PIq_PN-96J2IJpJW1yS-10/edit?usp=sharing).
| Task | Metric | Acc (best run) | Acc (avg/5runs) | Stdev | Metrics |
|-------|------------------------------|----------------|-----------------|-----------|--------------------------------------------------------------------------|
| CoLA | Matthews corr | 60.57 | 59.04 | 1.06 | [tfhub.dev](https://tensorboard.dev/experiment/lfr2adVpRtmLDALKrElkzg/) |
| SST-2 | Accuracy | 92.66 | 92.23 | 0.57 | [tfhub.dev](https://tensorboard.dev/experiment/jYvfv2trRHKMjoWnXVwrZA/) |
| MRPC | F1/Accuracy | 89.90/85.78 | 88.97/84.36 | 0.72/1.09 | [tfhub.dev](https://tensorboard.dev/experiment/bo3W3DEoRw2Q7YXjWrJkfg/) |
| STS-B | Pearson/Spearman corr. | 89.04/88.70 | 88.94/88.63 | 0.07/0.07 | [tfhub.dev](https://tensorboard.dev/experiment/fxVwbLD7QpKhbot0r9rn2w/) |
| QQP | Accuracy/F1 | 90.81/87.58 | 90.76/87.51 | 0.05/0.06 | [tfhub.dev](https://tensorboard.dev/experiment/di089Rc9TZmsnKRMrYNLsA/) |
| MNLI | Matched acc. | 84.10 | 83.80 | 0.16 | [tfhub.dev](https://tensorboard.dev/experiment/JgNCGHDJSRaW6HBx6YQFYQ/) |
| QNLI | Accuracy | 91.01 | 90.82 | 0.17 | [tfhub.dev](https://tensorboard.dev/experiment/Bq7cMGJnQMSggYgL8qNGeQ/) |
| RTE | Accuracy | 66.06 | 64.76 | 1.04 | [tfhub.dev](https://tensorboard.dev/experiment/66Eq24bhRjqN6CEhgDSGqQ/) |
| WNLI | Accuracy | 46.48 | 37.01 | 6.83 | [tfhub.dev](https://tensorboard.dev/experiment/TAqcnddqTkWvVEeGaWwIdQ/) |
Some of these results are significantly different from the ones reported on the test set of GLUE benchmark on the
website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the website.
### Runtime evaluation
We also ran each task once on a single V100 GPU, 8 V100 GPUs, and 8 Cloud v3 TPUs and report the
overall training time below. For comparison we ran Pytorch's [run_glue.py](https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue.py) on a single GPU (last column).
| Task | TPU v3-8 | 8 GPU | [1 GPU](https://tensorboard.dev/experiment/mkPS4Zh8TnGe1HB6Yzwj4Q) | 1 GPU (Pytorch) |
|-------|-----------|------------|------------|-----------------|
| CoLA | 1m 42s | 1m 26s | 3m 9s | 4m 6s |
| SST-2 | 5m 12s | 6m 28s | 22m 33s | 34m 37s |
| MRPC | 1m 29s | 1m 14s | 2m 20s | 2m 56s |
| STS-B | 1m 30s | 1m 12s | 2m 16s | 2m 48s |
| QQP | 22m 50s | 31m 48s | 1h 59m 41s | 2h 54m |
| MNLI | 25m 03s | 33m 55s | 2h 9m 37s | 3h 7m 6s |
| QNLI | 7m30s | 9m 40s | 34m 40s | 49m 8s |
| RTE | 1m 20s | 55s | 1m 10s | 1m 16s |
| WNLI | 1m 11s | 48s | 39s | 36s |
|-------|
| **TOTAL** | 1h 03m | 1h 28m | 5h 16m | 6h 37m |
*All experiments are ran on Google Cloud Platform.
GPU experiments are ran without further optimizations besides JAX
transformations. GPU experiments are ran with full precision (fp32). "TPU v3-8"
are 8 TPU cores on 4 chips (each chips has 2 cores), while "8 GPU" are 8 GPU chips.

View File

@ -1,5 +0,0 @@
datasets >= 1.1.3
jax>=0.2.8
jaxlib>=0.1.59
flax>=0.3.5
optax>=0.0.8

View File

@ -1,697 +0,0 @@
#!/usr/bin/env python
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetuning a 🤗 Flax Transformers model for sequence classification on GLUE."""
import json
import logging
import math
import os
import random
import sys
import time
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Optional
import datasets
import evaluate
import jax
import jax.numpy as jnp
import numpy as np
import optax
from datasets import load_dataset
from flax import struct, traverse_util
from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import HfApi
from tqdm import tqdm
import transformers
from transformers import (
AutoConfig,
AutoTokenizer,
FlaxAutoModelForSequenceClassification,
HfArgumentParser,
PretrainedConfig,
TrainingArguments,
is_tensorboard_available,
)
from transformers.utils import check_min_version, send_example_telemetry
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.56.0.dev0")
Array = Any
Dataset = datasets.arrow_dataset.Dataset
PRNGKey = Any
task_to_keys = {
"cola": ("sentence", None),
"mnli": ("premise", "hypothesis"),
"mrpc": ("sentence1", "sentence2"),
"qnli": ("question", "sentence"),
"qqp": ("question1", "question2"),
"rte": ("sentence1", "sentence2"),
"sst2": ("sentence", None),
"stsb": ("sentence1", "sentence2"),
"wnli": ("sentence1", "sentence2"),
}
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
use_slow_tokenizer: Optional[bool] = field(
default=False,
metadata={"help": "If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library)."},
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
token: str = field(
default=None,
metadata={
"help": (
"The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
"generated when running `hf auth login` (stored in `~/.huggingface`)."
)
},
)
use_auth_token: bool = field(
default=None,
metadata={
"help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
},
)
trust_remote_code: bool = field(
default=False,
metadata={
"help": (
"Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
"should only be set to `True` for repositories you trust and in which you have read the code, as it will "
"execute code present on the Hub on your local machine."
)
},
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
task_name: Optional[str] = field(
default=None, metadata={"help": f"The name of the glue task to train on. choices {list(task_to_keys.keys())}"}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_file: Optional[str] = field(
default=None, metadata={"help": "The input training data file (a csv or JSON file)."}
)
validation_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."},
)
test_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."},
)
text_column_name: Optional[str] = field(
default=None, metadata={"help": "The column name of text to input in the file (a csv or JSON file)."}
)
label_column_name: Optional[str] = field(
default=None, metadata={"help": "The column name of label to input in the file (a csv or JSON file)."}
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_seq_length: int = field(
default=None,
metadata={
"help": (
"The maximum total input sequence length after tokenization. If set, sequences longer "
"than this will be truncated, sequences shorter will be padded."
)
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
)
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
)
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
)
},
)
def __post_init__(self):
if self.task_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
else:
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
self.task_name = self.task_name.lower() if isinstance(self.task_name, str) else self.task_name
def create_train_state(
model: FlaxAutoModelForSequenceClassification,
learning_rate_fn: Callable[[int], float],
is_regression: bool,
num_labels: int,
weight_decay: float,
) -> train_state.TrainState:
"""Create initial training state."""
class TrainState(train_state.TrainState):
"""Train state with an Optax optimizer.
The two functions below differ depending on whether the task is classification
or regression.
Args:
logits_fn: Applied to last layer to obtain the logits.
loss_fn: Function to compute the loss.
"""
logits_fn: Callable = struct.field(pytree_node=False)
loss_fn: Callable = struct.field(pytree_node=False)
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
tx = optax.adamw(
learning_rate=learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=weight_decay, mask=decay_mask_fn
)
if is_regression:
def mse_loss(logits, labels):
return jnp.mean((logits[..., 0] - labels) ** 2)
return TrainState.create(
apply_fn=model.__call__,
params=model.params,
tx=tx,
logits_fn=lambda logits: logits[..., 0],
loss_fn=mse_loss,
)
else: # Classification.
def cross_entropy_loss(logits, labels):
xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels))
return jnp.mean(xentropy)
return TrainState.create(
apply_fn=model.__call__,
params=model.params,
tx=tx,
logits_fn=lambda logits: logits.argmax(-1),
loss_fn=cross_entropy_loss,
)
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
return schedule_fn
def glue_train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
"""Returns shuffled batches of size `batch_size` from truncated `train dataset`, sharded over all local devices."""
steps_per_epoch = len(dataset) // batch_size
perms = jax.random.permutation(rng, len(dataset))
perms = perms[: steps_per_epoch * batch_size] # Skip incomplete batch.
perms = perms.reshape((steps_per_epoch, batch_size))
for perm in perms:
batch = dataset[perm]
batch = {k: np.array(v) for k, v in batch.items()}
batch = shard(batch)
yield batch
def glue_eval_data_collator(dataset: Dataset, batch_size: int):
"""Returns batches of size `batch_size` from `eval dataset`. Sharding handled by `pad_shard_unpad` in the eval loop."""
batch_idx = np.arange(len(dataset))
steps_per_epoch = math.ceil(len(dataset) / batch_size)
batch_idx = np.array_split(batch_idx, steps_per_epoch)
for idx in batch_idx:
batch = dataset[idx]
batch = {k: np.array(v) for k, v in batch.items()}
yield batch
def main():
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if model_args.use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
FutureWarning,
)
if model_args.token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
model_args.token = model_args.use_auth_token
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_glue", model_args, data_args, framework="flax")
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
# Setup logging, we only want one process per machine to log things on the screen.
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
if jax.process_index() == 0:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
# Handle the repository creation
if training_args.push_to_hub:
# Retrieve of infer repo_name
repo_name = training_args.hub_model_id
if repo_name is None:
repo_name = Path(training_args.output_dir).absolute().name
# Create repo and retrieve repo_id
api = HfApi()
repo_id = api.create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
# or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
# For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
# sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named
# label if at least two columns are provided.
# If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
# single column. You can easily tweak this behavior (see below)
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
# download the dataset.
if data_args.task_name is not None:
# Downloading and loading a dataset from the hub.
raw_datasets = load_dataset(
"glue",
data_args.task_name,
token=model_args.token,
)
else:
# Loading the dataset from local csv or json file.
data_files = {}
if data_args.train_file is not None:
data_files["train"] = data_args.train_file
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = (data_args.train_file if data_args.train_file is not None else data_args.valid_file).split(".")[-1]
raw_datasets = load_dataset(
extension,
data_files=data_files,
token=model_args.token,
)
# See more about loading any type of standard or custom dataset at
# https://huggingface.co/docs/datasets/loading_datasets.
# Labels
if data_args.task_name is not None:
is_regression = data_args.task_name == "stsb"
if not is_regression:
label_list = raw_datasets["train"].features["label"].names
num_labels = len(label_list)
else:
num_labels = 1
else:
# Trying to have good defaults here, don't hesitate to tweak to your needs.
is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
if is_regression:
num_labels = 1
else:
# A useful fast method:
# https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset.unique
label_list = raw_datasets["train"].unique("label")
label_list.sort() # Let's sort it for determinism
num_labels = len(label_list)
# Load pretrained model and tokenizer
config = AutoConfig.from_pretrained(
model_args.model_name_or_path,
num_labels=num_labels,
finetuning_task=data_args.task_name,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=not model_args.use_slow_tokenizer,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
model = FlaxAutoModelForSequenceClassification.from_pretrained(
model_args.model_name_or_path,
config=config,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
# Preprocessing the datasets
if data_args.task_name is not None:
sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
else:
# Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
sentence1_key, sentence2_key = "sentence1", "sentence2"
else:
if len(non_label_column_names) >= 2:
sentence1_key, sentence2_key = non_label_column_names[:2]
else:
sentence1_key, sentence2_key = non_label_column_names[0], None
# Some models have set the order of the labels to use, so let's make sure we do use it.
label_to_id = None
if (
model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
and data_args.task_name is not None
and not is_regression
):
# Some have all caps in their config, some don't.
label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
if sorted(label_name_to_id.keys()) == sorted(label_list):
logger.info(
f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
"Using it!"
)
label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
else:
logger.warning(
"Your model seems to have been trained with labels, but they don't match the dataset: "
f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
"\nIgnoring the model labels as a result.",
)
elif data_args.task_name is None:
label_to_id = {v: i for i, v in enumerate(label_list)}
def preprocess_function(examples):
# Tokenize the texts
texts = (
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
)
result = tokenizer(*texts, padding="max_length", max_length=data_args.max_seq_length, truncation=True)
if "label" in examples:
if label_to_id is not None:
# Map labels to IDs (not necessary for GLUE tasks)
result["labels"] = [label_to_id[l] for l in examples["label"]]
else:
# In all cases, rename the column to labels because the model will expect that.
result["labels"] = examples["label"]
return result
processed_datasets = raw_datasets.map(
preprocess_function, batched=True, remove_columns=raw_datasets["train"].column_names
)
train_dataset = processed_datasets["train"]
eval_dataset = processed_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"]
# Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 3):
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
# Define a summary writer
has_tensorboard = is_tensorboard_available()
if has_tensorboard and jax.process_index() == 0:
try:
from flax.metrics.tensorboard import SummaryWriter
summary_writer = SummaryWriter(training_args.output_dir)
summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)})
except ImportError as ie:
has_tensorboard = False
logger.warning(
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
)
else:
logger.warning(
"Unable to display metrics through TensorBoard because the package is not installed: "
"Please run pip install tensorboard to enable."
)
def write_train_metric(summary_writer, train_metrics, train_time, step):
summary_writer.scalar("train_time", train_time, step)
train_metrics = get_metrics(train_metrics)
for key, vals in train_metrics.items():
tag = f"train_{key}"
for i, val in enumerate(vals):
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
def write_eval_metric(summary_writer, eval_metrics, step):
for metric_name, value in eval_metrics.items():
summary_writer.scalar(f"eval_{metric_name}", value, step)
num_epochs = int(training_args.num_train_epochs)
rng = jax.random.PRNGKey(training_args.seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())
train_batch_size = int(training_args.per_device_train_batch_size) * jax.local_device_count()
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
eval_batch_size = per_device_eval_batch_size * jax.device_count()
learning_rate_fn = create_learning_rate_fn(
len(train_dataset),
train_batch_size,
training_args.num_train_epochs,
training_args.warmup_steps,
training_args.learning_rate,
)
state = create_train_state(
model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=training_args.weight_decay
)
# define step functions
def train_step(
state: train_state.TrainState, batch: dict[str, Array], dropout_rng: PRNGKey
) -> tuple[train_state.TrainState, float]:
"""Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`."""
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
targets = batch.pop("labels")
def loss_fn(params):
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
loss = state.loss_fn(logits, targets)
return loss
grad_fn = jax.value_and_grad(loss_fn)
loss, grad = grad_fn(state.params)
grad = jax.lax.pmean(grad, "batch")
new_state = state.apply_gradients(grads=grad)
metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
return new_state, metrics, new_dropout_rng
p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
def eval_step(state, batch):
logits = state.apply_fn(**batch, params=state.params, train=False)[0]
return state.logits_fn(logits)
p_eval_step = jax.pmap(eval_step, axis_name="batch")
if data_args.task_name is not None:
metric = evaluate.load("glue", data_args.task_name, cache_dir=model_args.cache_dir)
else:
metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)
logger.info(f"===== Starting training ({num_epochs} epochs) =====")
train_time = 0
# make sure weights are replicated on each device
state = replicate(state)
steps_per_epoch = len(train_dataset) // train_batch_size
total_steps = steps_per_epoch * num_epochs
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (0/{num_epochs})", position=0)
for epoch in epochs:
train_start = time.time()
train_metrics = []
# Create sampling rng
rng, input_rng = jax.random.split(rng)
# train
train_loader = glue_train_data_collator(input_rng, train_dataset, train_batch_size)
for step, batch in enumerate(
tqdm(
train_loader,
total=steps_per_epoch,
desc="Training...",
position=1,
),
):
state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs)
train_metrics.append(train_metric)
cur_step = (epoch * steps_per_epoch) + (step + 1)
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
# Save metrics
train_metric = unreplicate(train_metric)
train_time += time.time() - train_start
if has_tensorboard and jax.process_index() == 0:
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate:"
f" {train_metric['learning_rate']})"
)
train_metrics = []
if (cur_step % training_args.eval_steps == 0 or cur_step % steps_per_epoch == 0) and cur_step > 0:
# evaluate
eval_loader = glue_eval_data_collator(eval_dataset, eval_batch_size)
for batch in tqdm(
eval_loader,
total=math.ceil(len(eval_dataset) / eval_batch_size),
desc="Evaluating ...",
position=2,
):
labels = batch.pop("labels")
predictions = pad_shard_unpad(p_eval_step)(
state, batch, min_device_batch=per_device_eval_batch_size
)
metric.add_batch(predictions=np.array(predictions), references=labels)
eval_metric = metric.compute()
logger.info(f"Step... ({cur_step}/{total_steps} | Eval metrics: {eval_metric})")
if has_tensorboard and jax.process_index() == 0:
write_eval_metric(summary_writer, eval_metric, cur_step)
if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps):
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(unreplicate(state.params))
model.save_pretrained(training_args.output_dir, params=params)
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
api.upload_folder(
commit_message=f"Saving weights and logs of epoch {epoch}",
folder_path=training_args.output_dir,
repo_id=repo_id,
repo_type="model",
token=training_args.hub_token,
)
epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
# save the eval metrics in json
if jax.process_index() == 0:
eval_metric = {f"eval_{metric_name}": value for metric_name, value in eval_metric.items()}
path = os.path.join(training_args.output_dir, "eval_results.json")
with open(path, "w") as f:
json.dump(eval_metric, f, indent=4, sort_keys=True)
if __name__ == "__main__":
main()

View File

@ -1,49 +0,0 @@
<!---
Copyright 2021 The Google Flax Team Authors and HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Token classification examples
Fine-tuning the library models for token classification task such as Named Entity Recognition (NER), Parts-of-speech tagging (POS) or phrase extraction (CHUNKS). The main script run_flax_ner.py leverages the 🤗 Datasets library. You can easily customize it to your needs if you need extra processing on your datasets.
It will either run on a datasets hosted on our hub or with your own text files for training and validation, you might just need to add some tweaks in the data preprocessing.
The following example fine-tunes BERT on CoNLL-2003:
```bash
python run_flax_ner.py \
--model_name_or_path google-bert/bert-base-cased \
--dataset_name conll2003 \
--max_seq_length 128 \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--per_device_train_batch_size 4 \
--output_dir ./bert-ner-conll2003 \
--eval_steps 300 \
--push_to_hub
```
Using the command above, the script will train for 3 epochs and run eval after each epoch.
Metrics and hyperparameters are stored in Tensorflow event files in `--output_dir`.
You can see the results by running `tensorboard` in that directory:
```bash
$ tensorboard --logdir .
```
or directly on the hub under *Training metrics*.
sample Metrics - [tfhub.dev](https://tensorboard.dev/experiment/u52qsBIpQSKEEXEJd2LVYA)

View File

@ -1,6 +0,0 @@
datasets >= 1.8.0
jax>=0.2.8
jaxlib>=0.1.59
flax>=0.3.5
optax>=0.0.8
seqeval

View File

@ -1,832 +0,0 @@
#!/usr/bin/env python
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fine-tuning a 🤗 Flax Transformers model on token classification tasks (NER, POS, CHUNKS)"""
import json
import logging
import math
import os
import random
import sys
import time
import warnings
from dataclasses import asdict, dataclass, field
from enum import Enum
from itertools import chain
from pathlib import Path
from typing import Any, Callable, Optional
import datasets
import evaluate
import jax
import jax.numpy as jnp
import numpy as np
import optax
from datasets import ClassLabel, load_dataset
from flax import struct, traverse_util
from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import HfApi
from tqdm import tqdm
import transformers
from transformers import (
AutoConfig,
AutoTokenizer,
FlaxAutoModelForTokenClassification,
HfArgumentParser,
is_tensorboard_available,
)
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.56.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/token-classification/requirements.txt")
Array = Any
Dataset = datasets.arrow_dataset.Dataset
PRNGKey = Any
@dataclass
class TrainingArguments:
output_dir: str = field(
metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
)
overwrite_output_dir: bool = field(
default=False,
metadata={
"help": (
"Overwrite the content of the output directory. "
"Use this to continue training if output_dir points to a checkpoint directory."
)
},
)
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
per_device_train_batch_size: int = field(
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
)
per_device_eval_batch_size: int = field(
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
)
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
push_to_hub: bool = field(
default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
)
hub_model_id: str = field(
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
def __post_init__(self):
if self.output_dir is not None:
self.output_dir = os.path.expanduser(self.output_dir)
def to_dict(self):
"""
Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
the token values by removing their value.
"""
d = asdict(self)
for k, v in d.items():
if isinstance(v, Enum):
d[k] = v.value
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
d[k] = [x.value for x in v]
if k.endswith("_token"):
d[k] = f"<{k.upper()}>"
return d
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: str = field(
metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
tokenizer_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None,
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
token: str = field(
default=None,
metadata={
"help": (
"The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
"generated when running `hf auth login` (stored in `~/.huggingface`)."
)
},
)
use_auth_token: bool = field(
default=None,
metadata={
"help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
},
)
trust_remote_code: bool = field(
default=False,
metadata={
"help": (
"Whether to trust the execution of code from datasets/models defined on the Hub."
" This option should only be set to `True` for repositories you trust and in which you have read the"
" code, as it will execute code present on the Hub on your local machine."
)
},
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
task_name: Optional[str] = field(default="ner", metadata={"help": "The name of the task (ner, pos...)."})
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_file: Optional[str] = field(
default=None, metadata={"help": "The input training data file (a csv or JSON file)."}
)
validation_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."},
)
test_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."},
)
text_column_name: Optional[str] = field(
default=None, metadata={"help": "The column name of text to input in the file (a csv or JSON file)."}
)
label_column_name: Optional[str] = field(
default=None, metadata={"help": "The column name of label to input in the file (a csv or JSON file)."}
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
max_seq_length: int = field(
default=None,
metadata={
"help": (
"The maximum total input sequence length after tokenization. If set, sequences longer "
"than this will be truncated, sequences shorter will be padded."
)
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
)
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
)
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
)
},
)
label_all_tokens: bool = field(
default=False,
metadata={
"help": (
"Whether to put the label for one word on all tokens of generated by that word or just on the "
"one (in which case the other tokens will have a padding index)."
)
},
)
return_entity_level_metrics: bool = field(
default=False,
metadata={"help": "Whether to return all the entity levels during evaluation or just the overall ones."},
)
def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
else:
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
self.task_name = self.task_name.lower()
def create_train_state(
model: FlaxAutoModelForTokenClassification,
learning_rate_fn: Callable[[int], float],
num_labels: int,
training_args: TrainingArguments,
) -> train_state.TrainState:
"""Create initial training state."""
class TrainState(train_state.TrainState):
"""Train state with an Optax optimizer.
The two functions below differ depending on whether the task is classification
or regression.
Args:
logits_fn: Applied to last layer to obtain the logits.
loss_fn: Function to compute the loss.
"""
logits_fn: Callable = struct.field(pytree_node=False)
loss_fn: Callable = struct.field(pytree_node=False)
# We use Optax's "masking" functionality to not apply weight decay
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
# mask boolean with the same structure as the parameters.
# The mask is True for parameters that should be decayed.
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
# find out all LayerNorm parameters
layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
layer_norm_named_params = {
layer[-2:]
for layer_norm_name in layer_norm_candidates
for layer in flat_params
if layer_norm_name in "".join(layer).lower()
}
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
tx = optax.adamw(
learning_rate=learning_rate_fn,
b1=training_args.adam_beta1,
b2=training_args.adam_beta2,
eps=training_args.adam_epsilon,
weight_decay=training_args.weight_decay,
mask=decay_mask_fn,
)
def cross_entropy_loss(logits, labels):
xentropy = optax.softmax_cross_entropy(logits, onehot(labels, num_classes=num_labels))
return jnp.mean(xentropy)
return TrainState.create(
apply_fn=model.__call__,
params=model.params,
tx=tx,
logits_fn=lambda logits: logits.argmax(-1),
loss_fn=cross_entropy_loss,
)
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
return schedule_fn
def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
"""Returns shuffled batches of size `batch_size` from truncated `train dataset`, sharded over all local devices."""
steps_per_epoch = len(dataset) // batch_size
perms = jax.random.permutation(rng, len(dataset))
perms = perms[: steps_per_epoch * batch_size] # Skip incomplete batch.
perms = perms.reshape((steps_per_epoch, batch_size))
for perm in perms:
batch = dataset[perm]
batch = {k: np.array(v) for k, v in batch.items()}
batch = shard(batch)
yield batch
def eval_data_collator(dataset: Dataset, batch_size: int):
"""Returns batches of size `batch_size` from `eval dataset`. Sharding handled by `pad_shard_unpad` in the eval loop."""
batch_idx = np.arange(len(dataset))
steps_per_epoch = math.ceil(len(dataset) / batch_size)
batch_idx = np.array_split(batch_idx, steps_per_epoch)
for idx in batch_idx:
batch = dataset[idx]
batch = {k: np.array(v) for k, v in batch.items()}
yield batch
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if model_args.use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
FutureWarning,
)
if model_args.token is not None:
raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
model_args.token = model_args.use_auth_token
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_ner", model_args, data_args, framework="flax")
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
# Setup logging, we only want one process per machine to log things on the screen.
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
if jax.process_index() == 0:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
# Handle the repository creation
if training_args.push_to_hub:
# Retrieve of infer repo_name
repo_name = training_args.hub_model_id
if repo_name is None:
repo_name = Path(training_args.output_dir).absolute().name
# Create repo and retrieve repo_id
api = HfApi()
repo_id = api.create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets for token classification task available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
#
# For CSV/JSON files, this script will use the column called 'tokens' or the first column if no column called
# 'tokens' is found. You can easily tweak this behavior (see below).
#
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
# download the dataset.
if data_args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
raw_datasets = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
else:
# Loading the dataset from local csv or json file.
data_files = {}
if data_args.train_file is not None:
data_files["train"] = data_args.train_file
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = (data_args.train_file if data_args.train_file is not None else data_args.valid_file).split(".")[-1]
raw_datasets = load_dataset(
extension,
data_files=data_files,
cache_dir=model_args.cache_dir,
token=model_args.token,
)
# See more about loading any type of standard or custom dataset at
# https://huggingface.co/docs/datasets/loading_datasets.
if raw_datasets["train"] is not None:
column_names = raw_datasets["train"].column_names
features = raw_datasets["train"].features
else:
column_names = raw_datasets["validation"].column_names
features = raw_datasets["validation"].features
if data_args.text_column_name is not None:
text_column_name = data_args.text_column_name
elif "tokens" in column_names:
text_column_name = "tokens"
else:
text_column_name = column_names[0]
if data_args.label_column_name is not None:
label_column_name = data_args.label_column_name
elif f"{data_args.task_name}_tags" in column_names:
label_column_name = f"{data_args.task_name}_tags"
else:
label_column_name = column_names[1]
# In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
# unique labels.
def get_label_list(labels):
unique_labels = set()
for label in labels:
unique_labels = unique_labels | set(label)
label_list = list(unique_labels)
label_list.sort()
return label_list
if isinstance(features[label_column_name].feature, ClassLabel):
label_list = features[label_column_name].feature.names
# No need to convert the labels since they are already ints.
label_to_id = {i: i for i in range(len(label_list))}
else:
label_list = get_label_list(raw_datasets["train"][label_column_name])
label_to_id = {l: i for i, l in enumerate(label_list)}
num_labels = len(label_list)
# Load pretrained model and tokenizer
config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path,
num_labels=num_labels,
label2id=label_to_id,
id2label={i: l for l, i in label_to_id.items()},
finetuning_task=data_args.task_name,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
tokenizer_name_or_path = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path
if config.model_type in {"gpt2", "roberta"}:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
add_prefix_space=True,
)
else:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
model = FlaxAutoModelForTokenClassification.from_pretrained(
model_args.model_name_or_path,
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
# Preprocessing the datasets
# Tokenize all texts and align the labels with them.
def tokenize_and_align_labels(examples):
tokenized_inputs = tokenizer(
examples[text_column_name],
max_length=data_args.max_seq_length,
padding="max_length",
truncation=True,
# We use this argument because the texts in our dataset are lists of words (with a label for each word).
is_split_into_words=True,
)
labels = []
for i, label in enumerate(examples[label_column_name]):
word_ids = tokenized_inputs.word_ids(batch_index=i)
previous_word_idx = None
label_ids = []
for word_idx in word_ids:
# Special tokens have a word id that is None. We set the label to -100 so they are automatically
# ignored in the loss function.
if word_idx is None:
label_ids.append(-100)
# We set the label for the first token of each word.
elif word_idx != previous_word_idx:
label_ids.append(label_to_id[label[word_idx]])
# For the other tokens in a word, we set the label to either the current label or -100, depending on
# the label_all_tokens flag.
else:
label_ids.append(label_to_id[label[word_idx]] if data_args.label_all_tokens else -100)
previous_word_idx = word_idx
labels.append(label_ids)
tokenized_inputs["labels"] = labels
return tokenized_inputs
processed_raw_datasets = raw_datasets.map(
tokenize_and_align_labels,
batched=True,
num_proc=data_args.preprocessing_num_workers,
load_from_cache_file=not data_args.overwrite_cache,
remove_columns=raw_datasets["train"].column_names,
desc="Running tokenizer on dataset",
)
train_dataset = processed_raw_datasets["train"]
eval_dataset = processed_raw_datasets["validation"]
# Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 3):
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
# Define a summary writer
has_tensorboard = is_tensorboard_available()
if has_tensorboard and jax.process_index() == 0:
try:
from flax.metrics.tensorboard import SummaryWriter
summary_writer = SummaryWriter(training_args.output_dir)
summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)})
except ImportError as ie:
has_tensorboard = False
logger.warning(
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
)
else:
logger.warning(
"Unable to display metrics through TensorBoard because the package is not installed: "
"Please run pip install tensorboard to enable."
)
def write_train_metric(summary_writer, train_metrics, train_time, step):
summary_writer.scalar("train_time", train_time, step)
train_metrics = get_metrics(train_metrics)
for key, vals in train_metrics.items():
tag = f"train_{key}"
for i, val in enumerate(vals):
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
def write_eval_metric(summary_writer, eval_metrics, step):
for metric_name, value in eval_metrics.items():
summary_writer.scalar(f"eval_{metric_name}", value, step)
num_epochs = int(training_args.num_train_epochs)
rng = jax.random.PRNGKey(training_args.seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())
train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count()
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count()
learning_rate_fn = create_learning_rate_fn(
len(train_dataset),
train_batch_size,
training_args.num_train_epochs,
training_args.warmup_steps,
training_args.learning_rate,
)
state = create_train_state(model, learning_rate_fn, num_labels=num_labels, training_args=training_args)
# define step functions
def train_step(
state: train_state.TrainState, batch: dict[str, Array], dropout_rng: PRNGKey
) -> tuple[train_state.TrainState, float]:
"""Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`."""
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
targets = batch.pop("labels")
def loss_fn(params):
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
loss = state.loss_fn(logits, targets)
return loss
grad_fn = jax.value_and_grad(loss_fn)
loss, grad = grad_fn(state.params)
grad = jax.lax.pmean(grad, "batch")
new_state = state.apply_gradients(grads=grad)
metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch")
return new_state, metrics, new_dropout_rng
p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
def eval_step(state, batch):
logits = state.apply_fn(**batch, params=state.params, train=False)[0]
return state.logits_fn(logits)
p_eval_step = jax.pmap(eval_step, axis_name="batch")
metric = evaluate.load("seqeval", cache_dir=model_args.cache_dir)
def get_labels(y_pred, y_true):
# Transform predictions and references tensos to numpy arrays
# Remove ignored index (special tokens)
true_predictions = [
[label_list[p] for (p, l) in zip(pred, gold_label) if l != -100]
for pred, gold_label in zip(y_pred, y_true)
]
true_labels = [
[label_list[l] for (p, l) in zip(pred, gold_label) if l != -100]
for pred, gold_label in zip(y_pred, y_true)
]
return true_predictions, true_labels
def compute_metrics():
results = metric.compute()
if data_args.return_entity_level_metrics:
# Unpack nested dictionaries
final_results = {}
for key, value in results.items():
if isinstance(value, dict):
for n, v in value.items():
final_results[f"{key}_{n}"] = v
else:
final_results[key] = value
return final_results
else:
return {
"precision": results["overall_precision"],
"recall": results["overall_recall"],
"f1": results["overall_f1"],
"accuracy": results["overall_accuracy"],
}
logger.info(f"===== Starting training ({num_epochs} epochs) =====")
train_time = 0
# make sure weights are replicated on each device
state = replicate(state)
train_time = 0
step_per_epoch = len(train_dataset) // train_batch_size
total_steps = step_per_epoch * num_epochs
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
train_start = time.time()
train_metrics = []
# Create sampling rng
rng, input_rng = jax.random.split(rng)
# train
for step, batch in enumerate(
tqdm(
train_data_collator(input_rng, train_dataset, train_batch_size),
total=step_per_epoch,
desc="Training...",
position=1,
)
):
state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs)
train_metrics.append(train_metric)
cur_step = (epoch * step_per_epoch) + (step + 1)
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
# Save metrics
train_metric = unreplicate(train_metric)
train_time += time.time() - train_start
if has_tensorboard and jax.process_index() == 0:
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate:"
f" {train_metric['learning_rate']})"
)
train_metrics = []
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
eval_metrics = {}
# evaluate
for batch in tqdm(
eval_data_collator(eval_dataset, eval_batch_size),
total=math.ceil(len(eval_dataset) / eval_batch_size),
desc="Evaluating ...",
position=2,
):
labels = batch.pop("labels")
predictions = pad_shard_unpad(p_eval_step)(
state, batch, min_device_batch=per_device_eval_batch_size
)
predictions = np.array(predictions)
labels[np.array(chain(*batch["attention_mask"])) == 0] = -100
preds, refs = get_labels(predictions, labels)
metric.add_batch(
predictions=preds,
references=refs,
)
eval_metrics = compute_metrics()
if data_args.return_entity_level_metrics:
logger.info(f"Step... ({cur_step}/{total_steps} | Validation metrics: {eval_metrics}")
else:
logger.info(
f"Step... ({cur_step}/{total_steps} | Validation f1: {eval_metrics['f1']}, Validation Acc:"
f" {eval_metrics['accuracy']})"
)
if has_tensorboard and jax.process_index() == 0:
write_eval_metric(summary_writer, eval_metrics, cur_step)
if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps):
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(unreplicate(state.params))
model.save_pretrained(training_args.output_dir, params=params)
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
api.upload_folder(
commit_message=f"Saving weights and logs of step {cur_step}",
folder_path=training_args.output_dir,
repo_id=repo_id,
repo_type="model",
token=training_args.hub_token,
)
epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
# Eval after training
if training_args.do_eval:
eval_metrics = {}
eval_loader = eval_data_collator(eval_dataset, eval_batch_size)
for batch in tqdm(eval_loader, total=len(eval_dataset) // eval_batch_size, desc="Evaluating ...", position=2):
labels = batch.pop("labels")
predictions = pad_shard_unpad(p_eval_step)(state, batch, min_device_batch=per_device_eval_batch_size)
predictions = np.array(predictions)
labels[np.array(chain(*batch["attention_mask"])) == 0] = -100
preds, refs = get_labels(predictions, labels)
metric.add_batch(predictions=preds, references=refs)
eval_metrics = compute_metrics()
if jax.process_index() == 0:
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
path = os.path.join(training_args.output_dir, "eval_results.json")
with open(path, "w") as f:
json.dump(eval_metrics, f, indent=4, sort_keys=True)
if __name__ == "__main__":
main()

View File

@ -1,70 +0,0 @@
<!---
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Image Classification training examples
The following example showcases how to train/fine-tune `ViT` for image-classification using the JAX/Flax backend.
JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
Models written in JAX/Flax are **immutable** and updated in a purely functional
way which enables simple and efficient model parallelism.
In this example we will train/fine-tune the model on the [imagenette](https://github.com/fastai/imagenette) dataset.
## Prepare the dataset
We will use the [imagenette](https://github.com/fastai/imagenette) dataset to train/fine-tune our model. Imagenette is a subset of 10 easily classified classes from Imagenet (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute).
### Download and extract the data.
```bash
wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz
tar -xvzf imagenette2.tgz
```
This will create a `imagenette2` dir with two subdirectories `train` and `val` each with multiple subdirectories per class. The training script expects the following directory structure
```bash
root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/[...]/asd932_.png
```
## Train the model
Next we can run the example script to fine-tune the model:
```bash
python run_image_classification.py \
--output_dir ./vit-base-patch16-imagenette \
--model_name_or_path google/vit-base-patch16-224-in21k \
--train_dir="imagenette2/train" \
--validation_dir="imagenette2/val" \
--num_train_epochs 5 \
--learning_rate 1e-3 \
--per_device_train_batch_size 128 --per_device_eval_batch_size 128 \
--overwrite_output_dir \
--preprocessing_num_workers 32 \
--push_to_hub
```
This should finish in ~7mins with 99% validation accuracy.

View File

@ -1,8 +0,0 @@
jax>=0.2.8
jaxlib>=0.1.59
flax>=0.3.5
optax>=0.0.8
-f https://download.pytorch.org/whl/torch_stable.html
torch==2.7.1
-f https://download.pytorch.org/whl/torch_stable.html
torchvision==0.12.0+cpu

View File

@ -1,590 +0,0 @@
#!/usr/bin/env python
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Pre-training/Fine-tuning ViT for image classification .
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=vit
"""
import logging
import os
import sys
import time
from dataclasses import asdict, dataclass, field
from enum import Enum
from pathlib import Path
from typing import Callable, Optional
import jax
import jax.numpy as jnp
import optax
# for dataset and preprocessing
import torch
import torchvision
from flax import jax_utils
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import HfApi
from torchvision import transforms
from tqdm import tqdm
import transformers
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
AutoConfig,
FlaxAutoModelForImageClassification,
HfArgumentParser,
is_tensorboard_available,
set_seed,
)
from transformers.utils import send_example_telemetry
logger = logging.getLogger(__name__)
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@dataclass
class TrainingArguments:
output_dir: str = field(
metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
)
overwrite_output_dir: bool = field(
default=False,
metadata={
"help": (
"Overwrite the content of the output directory. "
"Use this to continue training if output_dir points to a checkpoint directory."
)
},
)
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
per_device_train_batch_size: int = field(
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
)
per_device_eval_batch_size: int = field(
default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
)
learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"})
adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"})
adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
push_to_hub: bool = field(
default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
)
hub_model_id: str = field(
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
def __post_init__(self):
if self.output_dir is not None:
self.output_dir = os.path.expanduser(self.output_dir)
def to_dict(self):
"""
Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates
the token values by removing their value.
"""
d = asdict(self)
for k, v in d.items():
if isinstance(v, Enum):
d[k] = v.value
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
d[k] = [x.value for x in v]
if k.endswith("_token"):
d[k] = f"<{k.upper()}>"
return d
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
"""
model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": (
"The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
)
},
)
model_type: Optional[str] = field(
default=None,
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
dtype: Optional[str] = field(
default="float32",
metadata={
"help": (
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
" `[float32, float16, bfloat16]`."
)
},
)
token: str = field(
default=None,
metadata={
"help": (
"The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
"generated when running `hf auth login` (stored in `~/.huggingface`)."
)
},
)
trust_remote_code: bool = field(
default=False,
metadata={
"help": (
"Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
"should only be set to `True` for repositories you trust and in which you have read the code, as it will "
"execute code present on the Hub on your local machine."
)
},
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
train_dir: str = field(
metadata={"help": "Path to the root training directory which contains one subdirectory per class."}
)
validation_dir: str = field(
metadata={"help": "Path to the root validation directory which contains one subdirectory per class."},
)
image_size: Optional[int] = field(default=224, metadata={"help": " The size (resolution) of each image."})
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
)
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
)
},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
class TrainState(train_state.TrainState):
dropout_rng: jnp.ndarray
def replicate(self):
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
summary_writer.scalar("train_time", train_time, step)
train_metrics = get_metrics(train_metrics)
for key, vals in train_metrics.items():
tag = f"train_{key}"
for i, val in enumerate(vals):
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
for metric_name, value in eval_metrics.items():
summary_writer.scalar(f"eval_{metric_name}", value, step)
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
decay_fn = optax.linear_schedule(
init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
)
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
return schedule_fn
def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
send_example_telemetry("run_image_classification", model_args, data_args, framework="flax")
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
# Setup logging, we only want one process per machine to log things on the screen.
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
if jax.process_index() == 0:
transformers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
# Set the verbosity to info of the Transformers logger (on main process only):
logger.info(f"Training/evaluation parameters {training_args}")
# set seed for random transforms and torch dataloaders
set_seed(training_args.seed)
# Handle the repository creation
if training_args.push_to_hub:
# Retrieve of infer repo_name
repo_name = training_args.hub_model_id
if repo_name is None:
repo_name = Path(training_args.output_dir).absolute().name
# Create repo and retrieve repo_id
api = HfApi()
repo_id = api.create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
# Initialize datasets and pre-processing transforms
# We use torchvision here for faster pre-processing
# Note that here we are using some default pre-processing, for maximum accuracy
# one should tune this part and carefully select what transformations to use.
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
train_dataset = torchvision.datasets.ImageFolder(
data_args.train_dir,
transforms.Compose(
[
transforms.RandomResizedCrop(data_args.image_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]
),
)
eval_dataset = torchvision.datasets.ImageFolder(
data_args.validation_dir,
transforms.Compose(
[
transforms.Resize(data_args.image_size),
transforms.CenterCrop(data_args.image_size),
transforms.ToTensor(),
normalize,
]
),
)
# Load pretrained model and tokenizer
if model_args.config_name:
config = AutoConfig.from_pretrained(
model_args.config_name,
num_labels=len(train_dataset.classes),
image_size=data_args.image_size,
cache_dir=model_args.cache_dir,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
elif model_args.model_name_or_path:
config = AutoConfig.from_pretrained(
model_args.model_name_or_path,
num_labels=len(train_dataset.classes),
image_size=data_args.image_size,
cache_dir=model_args.cache_dir,
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
else:
config = CONFIG_MAPPING[model_args.model_type]()
logger.warning("You are instantiating a new config instance from scratch.")
if model_args.model_name_or_path:
model = FlaxAutoModelForImageClassification.from_pretrained(
model_args.model_name_or_path,
config=config,
seed=training_args.seed,
dtype=getattr(jnp, model_args.dtype),
token=model_args.token,
trust_remote_code=model_args.trust_remote_code,
)
else:
model = FlaxAutoModelForImageClassification.from_config(
config,
seed=training_args.seed,
dtype=getattr(jnp, model_args.dtype),
trust_remote_code=model_args.trust_remote_code,
)
# Store some constant
num_epochs = int(training_args.num_train_epochs)
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
eval_batch_size = per_device_eval_batch_size * jax.device_count()
steps_per_epoch = len(train_dataset) // train_batch_size
total_train_steps = steps_per_epoch * num_epochs
def collate_fn(examples):
pixel_values = torch.stack([example[0] for example in examples])
labels = torch.tensor([example[1] for example in examples])
batch = {"pixel_values": pixel_values, "labels": labels}
batch = {k: v.numpy() for k, v in batch.items()}
return batch
# Create data loaders
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=train_batch_size,
shuffle=True,
num_workers=data_args.preprocessing_num_workers,
persistent_workers=True,
drop_last=True,
collate_fn=collate_fn,
)
eval_loader = torch.utils.data.DataLoader(
eval_dataset,
batch_size=eval_batch_size,
shuffle=False,
num_workers=data_args.preprocessing_num_workers,
persistent_workers=True,
drop_last=False,
collate_fn=collate_fn,
)
# Enable tensorboard only on the master node
has_tensorboard = is_tensorboard_available()
if has_tensorboard and jax.process_index() == 0:
try:
from flax.metrics.tensorboard import SummaryWriter
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
except ImportError as ie:
has_tensorboard = False
logger.warning(
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
)
else:
logger.warning(
"Unable to display metrics through TensorBoard because the package is not installed: "
"Please run pip install tensorboard to enable."
)
# Initialize our training
rng = jax.random.PRNGKey(training_args.seed)
rng, dropout_rng = jax.random.split(rng)
# Create learning rate schedule
linear_decay_lr_schedule_fn = create_learning_rate_fn(
len(train_dataset),
train_batch_size,
training_args.num_train_epochs,
training_args.warmup_steps,
training_args.learning_rate,
)
# create adam optimizer
adamw = optax.adamw(
learning_rate=linear_decay_lr_schedule_fn,
b1=training_args.adam_beta1,
b2=training_args.adam_beta2,
eps=training_args.adam_epsilon,
weight_decay=training_args.weight_decay,
)
# Setup train state
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
def loss_fn(logits, labels):
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
return loss.mean()
# Define gradient update step fn
def train_step(state, batch):
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
def compute_loss(params):
labels = batch.pop("labels")
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
loss = loss_fn(logits, labels)
return loss
grad_fn = jax.value_and_grad(compute_loss)
loss, grad = grad_fn(state.params)
grad = jax.lax.pmean(grad, "batch")
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
metrics = jax.lax.pmean(metrics, axis_name="batch")
return new_state, metrics
# Define eval fn
def eval_step(params, batch):
labels = batch.pop("labels")
logits = model(**batch, params=params, train=False)[0]
loss = loss_fn(logits, labels)
# summarize metrics
accuracy = (jnp.argmax(logits, axis=-1) == labels).mean()
metrics = {"loss": loss, "accuracy": accuracy}
metrics = jax.lax.pmean(metrics, axis_name="batch")
return metrics
# Create parallel version of the train and eval step
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
p_eval_step = jax.pmap(eval_step, "batch")
# Replicate the train state on each device
state = state.replicate()
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {num_epochs}")
logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
logger.info(f" Total optimization steps = {total_train_steps}")
train_time = 0
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
# ======================== Training ================================
train_start = time.time()
# Create sampling rng
rng, input_rng = jax.random.split(rng)
train_metrics = []
steps_per_epoch = len(train_dataset) // train_batch_size
train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
# train
for batch in train_loader:
batch = shard(batch)
state, train_metric = p_train_step(state, batch)
train_metrics.append(train_metric)
train_step_progress_bar.update(1)
train_time += time.time() - train_start
train_metric = unreplicate(train_metric)
train_step_progress_bar.close()
epochs.write(
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
f" {train_metric['learning_rate']})"
)
# ======================== Evaluating ==============================
eval_metrics = []
eval_steps = len(eval_dataset) // eval_batch_size
eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False)
for batch in eval_loader:
# Model forward
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
state.params, batch, min_device_batch=per_device_eval_batch_size
)
eval_metrics.append(metrics)
eval_step_progress_bar.update(1)
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
# Print metrics and update progress bar
eval_step_progress_bar.close()
desc = (
f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {round(eval_metrics['loss'].item(), 4)} | "
f"Eval Accuracy: {round(eval_metrics['accuracy'].item(), 4)})"
)
epochs.write(desc)
epochs.desc = desc
# Save metrics
if has_tensorboard and jax.process_index() == 0:
cur_step = epoch * (len(train_dataset) // train_batch_size)
write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
model.save_pretrained(training_args.output_dir, params=params)
if training_args.push_to_hub:
api.upload_folder(
commit_message=f"Saving weights and logs of epoch {epoch}",
folder_path=training_args.output_dir,
repo_id=repo_id,
repo_type="model",
token=training_args.hub_token,
)
if __name__ == "__main__":
main()

View File

@ -13,7 +13,7 @@ cd examples/metrics-monitoring
docker compose up
```
Then, in your srcipt running CB, you will need to create a MeterProvider and TracerProvider as follows:
Then, in your script running CB, you will need to create a MeterProvider and TracerProvider as follows:
```py
from opentelemetry import metrics, trace

View File

@ -136,6 +136,7 @@ class DummyBertSelfAttention(nn.Module):
1, 2
)
is_updated = False
is_cross_attention = encoder_hidden_states is not None
if past_key_values is not None:
if isinstance(past_key_values, EncoderDecoderCache):
@ -170,7 +171,7 @@ class DummyBertSelfAttention(nn.Module):
key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
)
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
if is_cross_attention:
if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
past_key_values.is_updated[self.layer_idx] = True
# Take the dot product between "query" and "key" to get the raw attention scores.
@ -266,6 +267,7 @@ class DummyBertSdpaSelfAttention(DummyBertSelfAttention):
self.query(hidden_states).view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
)
is_updated = False
is_cross_attention = encoder_hidden_states is not None
current_states = encoder_hidden_states if is_cross_attention else hidden_states
if past_key_values is not None:
@ -303,7 +305,7 @@ class DummyBertSdpaSelfAttention(DummyBertSelfAttention):
key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
)
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
if is_cross_attention:
if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
past_key_values.is_updated[self.layer_idx] = True
# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom

View File

@ -232,7 +232,7 @@ class NewTaskModelModel(NewTaskModelPreTrainedModel):
self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
):
"""
Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
equal to the length of multimodal features. If the lengths are different, an error is raised.
"""
if input_ids is None:
@ -406,7 +406,7 @@ class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
def get_image_features(self, pixel_values):
return self.model.get_image_features(pixel_values)
# Make modules available throught conditional class for BC
# Make modules available through conditional class for BC
@property
def language_model(self):
return self.model.language_model

View File

@ -139,6 +139,7 @@ class RobertaSelfAttention(nn.Module):
1, 2
)
is_updated = False
is_cross_attention = encoder_hidden_states is not None
if past_key_values is not None:
if isinstance(past_key_values, EncoderDecoderCache):
@ -173,7 +174,7 @@ class RobertaSelfAttention(nn.Module):
key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
)
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
if is_cross_attention:
if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
past_key_values.is_updated[self.layer_idx] = True
# Take the dot product between "query" and "key" to get the raw attention scores.
@ -269,6 +270,7 @@ class RobertaSdpaSelfAttention(RobertaSelfAttention):
self.query(hidden_states).view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
)
is_updated = False
is_cross_attention = encoder_hidden_states is not None
current_states = encoder_hidden_states if is_cross_attention else hidden_states
if past_key_values is not None:
@ -306,7 +308,7 @@ class RobertaSdpaSelfAttention(RobertaSelfAttention):
key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
)
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
if is_cross_attention:
if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
past_key_values.is_updated[self.layer_idx] = True
# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom

View File

@ -55,7 +55,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.56.0.dev0")
check_min_version("4.57.0.dev0")
require_version("datasets>=1.14.0", "To fix: pip install -r examples/pytorch/audio-classification/requirements.txt")

View File

@ -229,7 +229,9 @@ if __name__ == "__main__":
use_cuda_graph=args.use_cuda_graph,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=False,
do_sample=True,
temperature=0.8,
top_p=0.9,
num_blocks=args.num_blocks,
max_batch_tokens=args.max_batch_tokens,
)

View File

@ -63,7 +63,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.56.0.dev0")
check_min_version("4.57.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/contrastive-image-text/requirements.txt")

View File

@ -68,7 +68,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.56.0.dev0")
check_min_version("4.57.0.dev0")
require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt")

View File

@ -61,7 +61,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.56.0.dev0")
check_min_version("4.57.0.dev0")
logger = get_logger(__name__)

View File

@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.56.0.dev0")
check_min_version("4.57.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -56,7 +56,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.56.0.dev0")
check_min_version("4.57.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -61,7 +61,7 @@ Any model supported by the AutoModelForMaskedImageModeling API can be used.
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.56.0.dev0")
check_min_version("4.57.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/image-pretraining/requirements.txt")

View File

@ -57,7 +57,7 @@ from transformers.utils.versions import require_version
logger = logging.getLogger(__name__)
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.56.0.dev0")
check_min_version("4.57.0.dev0")
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/instance-segmentation/requirements.txt")

Some files were not shown because too many files have changed in this diff Show More