Compare commits


73 Commits

Author SHA1 Message Date
50042518db Some stuff 2025-09-16 12:08:49 +00:00
571ca0200d Merge branch 'main' into feat/async-checkpointing 2025-09-13 14:16:07 +00:00
0cb1a33475 fix Multi node CUDA error: invalid device ordinal #3775 (#3779) 2025-09-13 15:32:47 +02:00
dfdc219018 use reset_peak_memory_stats on xpu (#3772)
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
2025-09-12 15:05:31 +02:00
45959d7b96 fix FSDP2 test case failure on XPU (#3771)
* fix FSDP2 test case failure on XPU

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
2025-09-12 15:05:05 +02:00
8b493524c8 Fix: typo makes tests fail (#3765) 2025-09-09 12:06:05 +02:00
9ead94e556 fix: torch_npu import error (#3764) 2025-09-09 11:38:57 +02:00
a0bc36e8ed feat: allow mixed precision policy as dtype (#3751)
* feat: allow mixed precision as dtype

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* feat: allow mixed precision as dtype

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* feat: allow mixed precision as dtype

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* test: extend test for MP as str dtype

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* Fix: style

---------

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
Co-authored-by: S1ro1 <matej.sirovatka@gmail.com>
2025-09-08 23:29:20 +02:00
8830e58a91 Fix typos (#3753)
* Fix typos

Signed-off-by: cyy <cyyever@outlook.com>

* Fix: style

---------

Signed-off-by: cyy <cyyever@outlook.com>
Co-authored-by: S1ro1 <matej.sirovatka@gmail.com>
2025-09-08 13:33:18 +02:00
40ebb4bea3 make torch_native_parallelism examples device agnostic (#3759)
* make torch_native_parallelism examples device agnostic

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* xxx

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* xxx

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* Style + deprecation warning

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: S1ro1 <matej.sirovatka@gmail.com>
2025-09-08 12:16:56 +02:00
ec92b1af7a fix: model.set_requires_gradient_sync(False) should be called to turn off gradient synchronization in FSDP2 (#3762)
* fix :`model.set_requires_gradient_sync(False)` should be called to turn off gradient synchronization in FSDP2.

* fix: remove trailing whitespace
2025-09-06 23:57:46 +02:00
62ede1ed2a CP docs typos fixed (#3761) 2025-09-05 12:23:33 +02:00
9f9c490c6b fix: specify device for process_tensor in example usage (#3755) 2025-09-03 11:05:24 +02:00
8b55e62b2c xpu INT64 all_gather issue fixed in 2.9 (#3756)
* xpu gather issue fixed in 2.9 and validated config_yamls on XPU

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* xxx

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
2025-09-03 10:56:14 +02:00
0e4419b347 Add bf16/fp16 support for amp with mps device (#3373)
* Fix tests

* format

* amp mps support for fp16/bf16

* add error

* revert

* revert

* fix

* ruff
2025-08-28 14:20:56 +02:00
3b67c21696 Add support for TE MXFP8 recipe in accelerate (#3688)
* Add support for MXFP8 recipe in accelerate

* ruff reformat

* add and fix test for deepspeed / fp8 from config

* minor lints

Signed-off-by: Peter St. John <pstjohn@nvidia.com>

---------

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
2025-08-27 14:08:34 +02:00
7b981788ca [ND Parallel] Update examples, cleanup (#3737)
* Fix: update cp example

* Feat: add rename examples

* WIP: Cleanup with_trainer

* Feat: more cleanup

* Feat: more refactor + better readme + more configs

* Fin
2025-08-26 14:41:14 +02:00
c4460e33ef fix: specify device_ids in torch.distributed.barrier for PartialState (#3744) 2025-08-26 14:05:33 +02:00
5dd3d0b690 Protect import for device_mesh (#3742) 2025-08-22 15:44:56 +02:00
5fe4460ccd Feat: add to_json (#3743) 2025-08-22 15:25:38 +02:00
979d81e4a9 fix: cpu ram efficient loading for nd or hsdp parallelisms (#3740)
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
2025-08-21 13:40:06 +02:00
7c25f696b8 Fix convert LayerNorm without bias to fp8 (#3725) 2025-08-18 22:28:48 +02:00
a7d6f28f99 feat: add ignored_params support for fsdp2 (#3731)
* feat: add ignored_params support for fsdp2

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* feat: add ignored_params support for fsdp2

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* feat: add ignored_params support for fsdp2

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* feat: add ignored_params support for fsdp2

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* test: update testcase for fsdp2 ignored_params

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* fix: add defensive use of ignored params

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* fix: styling errors

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

---------

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
2025-08-18 14:31:19 +02:00
23cf4ef8a3 Fix tests (#3722)
* fix tests

* fix skorch tests

* fix deepspeed

* pin torch as compile tests don't pass and create segmentation fault

* skip compile tests

* fix

* forgot v ...

* style
2025-08-07 16:59:29 +02:00
ff872f5f71 bump to 1.11.0dev0 2025-08-07 12:58:08 +02:00
2941a6b0fb remove (#3721) 2025-08-07 12:48:11 +02:00
c0a3aefea8 feature: CpuOffload pre_forward don't attempt to move if already on device (#3695)
* feature: added optimisation to not attempt to move devices if already on that device. This is more noticeable in large step iterations on diffusion loops when the pre_forward can get called many times

* fix: linting

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-08-06 19:46:13 +02:00
42fdda1c1f Remove ParallelismConfig from PartialState (#3720)
* remove

* style

* fix

* valueerror instead

* add device_mesh
2025-08-06 19:00:26 +02:00
e23b004b30 TST Add test for FSDP ignored_modules as str (#3719)
Follow up to #3698.
2025-08-06 18:05:54 +02:00
898cad39e8 Fix: tp size wouldn't read from env (#3716) 2025-08-06 15:08:55 +02:00
24c8157bba Set parallelism_config in constructor due to Trainer reset of State (#3713) 2025-08-06 13:47:49 +02:00
6891c57072 Feat: context parallel v2.0 (#3700)
* Cleanup: context parallel

* Feat: cleanup

* Feat: concept guide

* Fix: rename + version check

* Style

* Fix: add to namespace in a test

* Fix: add skip_if on dataclass tests

* Fix: proper version for version check

* Feat: add tests and cleanup

* Fix: properly version check added tests

* Feat: address comments

* Fix: add both shift_labels and labels to make the model.forward calculate loss

* Fix: remove import, improve comment

* Fix: final checks

* Fix: style

* Fix: style
2025-08-05 16:17:13 +02:00
24e48f3d20 ENH: Allow FSDP ignored modules to be regex (#3698)
* ENH: Allow FSDP ignored modules to be regex

Description

For FSDP, there is an option to indicate ignored_modules, which should
be a list of modules that are ignored by FSDP. Even though this argument was
supported in accelerate, it was not very usable:

1. Listing all modules can be tricky, especially with something like PEFT,
where the whole model is wrapped and thus the module structure changes.
2. When configuring this argument, accelerate takes a detour via
environment variables. These can only be strings. Therefore, passing a
list of modules is not feasible.

Moreover, I noticed that the environment variable for ignored_modules
was not even set, so configuring this argument didn't even work.

Status

This PR is lacking tests. I would be happy for pointers on how to add
those.

Context

When using PEFT with LoRA and the target_parameters feature, I ran into
an issue training such a model with FSDP. The only working fix I found
was to ignore the layers targeted by LoRA. However, I could not
configure accelerate to do that. With this PR, it is possible. I could
successfully train such a PEFT model that targets q_proj and v_proj by
setting fsdp_ignored_modules: '.*\.(q_proj$|v_proj$)' (a small regex illustration follows this entry).

* Fix type annotation

* Fix failing test
2025-08-05 14:23:14 +02:00
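As a quick illustration of how such a regex selects module names (plain Python only; the module names below are made-up examples, not taken from the PR):

```python
import re

pattern = re.compile(r".*\.(q_proj$|v_proj$)")
names = [
    "base_model.model.layers.0.self_attn.q_proj",
    "base_model.model.layers.0.self_attn.k_proj",
    "base_model.model.layers.0.self_attn.v_proj",
]
ignored = [name for name in names if pattern.match(name)]
# ignored -> the q_proj and v_proj entries; k_proj would stay FSDP-wrapped
```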
jp
6640ff415c Fix: Ensure environment variable values are case-insensitive in Accelerate (#3712)
* Add: lower

* apply ruff
2025-08-05 13:22:00 +02:00
c173b4fdd6 Fix: prepare works even if nothing except tp specified (rare) (#3707) 2025-08-05 13:07:37 +02:00
cb343c63d7 Add Parallelism getter property to Accelerator class (#3703)
* Add rank property to Accelerator class

Signed-off-by: WoosungMyung <dntjd517@naver.com>

* Raise errors when parallelism configuration is not enabled

Signed-off-by: WoosungMyung <dntjd517@naver.com>

* Fix: PR feedback

Signed-off-by: WoosungMyung <dntjd517@naver.com>

* Fix: style

---------

Signed-off-by: WoosungMyung <dntjd517@naver.com>
Co-authored-by: S1ro1 <matej.sirovatka@gmail.com>
2025-08-02 18:20:08 +02:00
354b0b5da3 WIP: very much wip but works (probably) 2025-08-01 01:28:49 +00:00
9359a0194f Parallelism config + TP + HSDP + BYODM (Bring Your Own Device Mesh) (#3682)
* Feat: init

* Feat: add validation + init from kwargs

* Fix: minor fixes

* Feat: more cleanup

* Minor refactor

* remove import

* adding support for pre-configured device mesh

* adding device mesh to fsdp2

* moving mesh dim defn to parralismconfig

* tests

* WIP device mesh/accelerator validation

* WIP more tests

* Test Driven Development (TDD)

* fixing build_device_mesh

* FSDP dim names

* adding example

* WIP

* fixing HSDP

* Feat: add back old options

* working example

* debugging

* adding parallelism config to partialstate

* Feat: revert ddp changes

* Revert DDP

* Feat: (untested) update mesh dims and some minor tweaks

* adding dp_cp dims

* updating comments

* WIP

* wip 2

* reverting

* storing state in accelerator rather than acceleratorstate

* Fix: minor tweaks

* wip example update

* Fixes for non-fsdp2 case

* Feat: ensure ddp/tp only works

* updating example

* updating example

* updating examples, fixing state

* fixed state

* comments

* fixing partial state check

* linting

* comments

* removing fn

* WIP: fix tp

* comments

* removing return

* reverting upcast

* add guards

* guards for empty self.parallelism_config

* use len on tuple to check if empty

* Feat: cleanup example

* Feat: some cleanup of example

* Feat: add trackio

* Fix: improve trackio

* Feat: TP works

* Feat: some fsdp2 improv

* Feat: working examples

* handle clipping for tensor parallel

* Implicit replicate

* Refactor: move to separate file + cleanup + basic comments

* Fix: add unadded files, fix circular import

* Feat: better readme

* Feat: add blog + ultrascale links

* Tmp: should_save_model now returns only true

* Fix: remove implicit_replication and style

* Fix: remove optional

* add guard on parallelism_config.tp_enabled

* fix import

* fixing empty parallelism_config

* fix import path for test patch

* fixing patch

---------

Co-authored-by: S1ro1 <matej.sirovatka@gmail.com>
Co-authored-by: Salman Mohammadi <“salman.mohammadi@outlook.com”>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-07-30 21:03:13 +02:00
2f075c724c set default submesh_tp_size to prevent unset local variable error (#3687)
* set default submesh_tp_size to prevent unset local variable error

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-07-22 12:31:03 +02:00
7ecc2d7f39 bump to v1.10.0-release 2025-07-16 16:26:03 +00:00
12f89bb754 do not call partial state if not initialized 2025-07-16 13:42:58 +00:00
348aabaaaf Update Gaudi runner image to latest SynapseAI and enable previously disabled tests (#3653)
* update synapse and add tp tests

* only skip regional compile speedup check

* pass sdp test on hpu
2025-07-16 14:33:36 +02:00
3b13453bbf “Stop Halving My Batch!” · Default back-off 0.5 → 0.9 (#3684)
* feat(memory): change default find_executable_batch_size to change by 10% instead of 50%

* Update test_memory_utils.py

* Apply style fixes

---------

Co-authored-by: Amit Moryossef <amitmoryossef@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-07-16 12:32:46 +02:00
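For context, a hedged usage sketch of the utility this commit tunes; the decorator itself is existing Accelerate API, while `inner_training_loop` and the batch size are illustrative:

```python
from accelerate.utils import find_executable_batch_size

@find_executable_batch_size(starting_batch_size=128)
def inner_training_loop(batch_size):
    # Runs one training attempt; on an out-of-memory error the decorator retries
    # this function with a smaller batch size (per this commit, shrinking by ~10%
    # per retry instead of halving).
    print(f"Trying batch size {batch_size}")

inner_training_loop()  # called without arguments; the decorator injects batch_size
```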
0408ab12d7 warn for invalid keys (#3613)
* warn for invalid keys

* add test for check_device_map invalid keys

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-07-16 12:23:41 +02:00
55e518a762 accelerate/data_loader.py: do not yield if the base_dataloader is empty (#3659)
* accelerate/data_loader.py: do not yield if the base_dataloader is empty

in the code:
```
        dataloader_iter = self.base_dataloader.__iter__()
        # We iterate one batch ahead to check when we are at the end
        try:
            current_batch = next(dataloader_iter)
        except StopIteration:
            yield
```

If the base dataloader is empty then the exception is raised but `yield`
yields nothing.

Then, at the time of:
```
if self.device is not None:
    current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)
```

would lead to an uncaught exception like:
 File "/root/rl-swarm/.venv/lib/python3.10/site-packages/accelerate/data_loader.py", line 575, in __iter__
    current_batch = send_to_device(current_batch, self.device, non_blocking=self._non_blocking)
UnboundLocalError: local variable 'current_batch' referenced before assignment
because `current_batch` was never assigned, since `next(dataloader_iter)` raised `StopIteration`.

Signed-off-by: 0xnightwind <nightwind1899@gmail.com>

* Update src/accelerate/data_loader.py

---------

Signed-off-by: 0xnightwind <nightwind1899@gmail.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
2025-07-16 12:04:25 +02:00
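A minimal sketch of the kind of guard the commit describes (simplified, not the actual patch; `send_to_device` is real Accelerate API, the rest is illustrative):

```python
from accelerate.utils import send_to_device

def iterate_one_batch_ahead(base_dataloader, device, non_blocking=False):
    dataloader_iter = iter(base_dataloader)
    # We iterate one batch ahead to check when we are at the end.
    try:
        current_batch = next(dataloader_iter)
    except StopIteration:
        return  # empty base dataloader: yield nothing instead of a bare `yield`
    for next_batch in dataloader_iter:
        if device is not None:
            current_batch = send_to_device(current_batch, device, non_blocking=non_blocking)
        yield current_batch
        current_batch = next_batch
    # last batch
    if device is not None:
        current_batch = send_to_device(current_batch, device, non_blocking=non_blocking)
    yield current_batch
```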
7e11ac43f0 fix: wandb config not saved in offline mode (#3648)
* fix: wandb config not saved in offline mode

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-07-15 17:51:44 +02:00
e2cc537db8 trackio (#3669)
* trackio

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>

* seven -> eight

* Add trackio as a real tracker instead

* Sort

* Style

* Style

* Remove step

* Disable trackio on Python < 3.10

* Update src/accelerate/tracking.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* More style

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
2025-07-15 17:17:49 +02:00
847ae58c74 Fix FP8 tests, enable FP8 to be used without direct Accelerator() configuring (#3677)
* single-gpu tests passing

* install deepspeed in fp8 container

* revert mixed_precision check
2025-07-15 15:20:57 +02:00
6e104f31de unpin datasets (#3681) 2025-07-15 15:00:35 +02:00
524e5f9828 Speedup model loading by 4-5x in Diffusers (#3674)
* update

* update

* make style

* update

* merge if statements
2025-07-11 16:58:35 +02:00
d6c986c3f2 Bunch of FSDP improvements (#3671)
* Feat: split tests

* Feat: finito

* Fix

* Final, tests pass
2025-07-09 16:05:22 +02:00
1ac8643df7 xpu enablement on left cases (#3654)
* 1. enable xpu for launcher 2. expand cuda only ds uts to xpu 3. expand profiler example to xpu

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* rename

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* Update profiler.py

* Apply style fixes

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-07-07 18:10:53 +02:00
07ce74868c Fix: properly error when DDP + Dtensor model (#3629)
* Feat: add check

* Refactor: nits
2025-06-27 01:33:45 +02:00
175fe91589 Added a check in the no_sync() function to avoid errors when using deepspeed zero2/3. (#3656) 2025-06-26 14:39:04 +02:00
fe16ce8bce Fix fsdp2 example (#3657) 2025-06-26 14:08:51 +02:00
5987d79a53 Update gradient_accumulation.md (#3649) 2025-06-23 11:58:31 +02:00
31af8d4e8e shards (#3645) 2025-06-20 11:24:20 +02:00
b7493a82b1 Add support for e5e2 and default to hybrid when launcher is used (#3640)
* add support for e5e2 and default to hybrid when launcher is used

* style
2025-06-20 11:11:32 +02:00
a16d2bb3c1 bump to v1.9.0dev 2025-06-19 15:13:41 +02:00
cac22ed980 fix grad acc deepspeed (#3638)
* fix grad acc deepspeed

* style
2025-06-19 12:06:21 +02:00
be826a6b7b Fix: correct labels (#3637) 2025-06-19 11:01:56 +02:00
5939640829 Feat: add cpu offload (#3636) 2025-06-18 18:13:45 +02:00
7f9c8cbe34 [DeepSpeed] sync gradient accum steps from deepspeed plugin (#3632)
* sync steps

* add a debug log when overriding

* make grad accum always consistent

* remove debug
2025-06-18 16:45:57 +02:00
9888c7ed23 feat: use datasets.IterableDataset shard if possible (#3635)
* feat: use datasets.IterableDataset shard if possible.

When `accelerator.prepare` is called on a
`datasets.IterableDataset`, use the `shard` method to
split the dataset across the available processes. This
allows for more efficient data loading and processing,
without the load-and-slice overhead of `IterableDatasetShard`.

* dataset

* remove unused import

* style

---------

Co-authored-by: wuwenxu.01 <wuwenxu.01@bytedance.com>
2025-06-18 16:45:17 +02:00
42a68c30dc Fix Typos in Documentation and Comments (#3621)
* Update state.py

* Update tracking.py
2025-06-18 15:53:02 +02:00
6597dae780 Integrate SwanLab for offline/online experiment tracking for Accelerate (#3605)
* add support for SwanLabTracker and update related documentation

* add emoji in FRAMWORK

* apply the style corrections and quality control

* add support for SwanLabTracker in tests

* fix bug in test_tracking
2025-06-18 15:42:29 +02:00
8878d93745 remove hardcoded cuda from fsdpv2 (#3631) 2025-06-17 14:32:10 +02:00
2eaf5cdbbc remove ipex.optimize in accelerate (#3608)
* remove ipex.optimize in accelerate

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix mis-style

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* Update intel_cpu.md

* Update launch.py

* fix comments

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* add logging

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* Update launch.py

* Apply style fixes

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-06-17 11:08:19 +02:00
23c1d8db89 [Deepspeed] deepspeed auto grad accum (#3630)
* deepspeed auto grad accum

* add tests for grad accum

* use tiny-random-gpt2

* Update tests/deepspeed/test_deepspeed_gradient_accumulation.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix redundant code

* set_gradient_accumulation_boundary is always there

* remove unused helper

* no need for this

* full revert

* Apply style fixes

* get_global_grad_norm is always there

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-06-16 16:28:24 +02:00
0af621bbec add xpu support in TorchTensorParallelPlugin (#3627)
* add xpu support in TorchTensorParallelPlugin

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix typo

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
2025-06-13 17:45:51 +02:00
bee04f1b01 Add fp8_e5m2 support in dtype_byte_size (#3625)
* float8_e5m2 device_map

* remove prints
2025-06-12 16:27:32 +02:00
8a953f08c6 fix xpu 8bit value loading (#3623)
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
2025-06-12 14:55:14 +02:00
3518c03584 small fix (#3619) 2025-06-11 14:02:45 +02:00
86 changed files with 3465 additions and 1108 deletions

View File

@ -15,7 +15,7 @@ jobs:
outputs:
version: ${{ steps.step1.outputs.version }}
steps:
- uses: actions/checkout@4
- uses: actions/checkout@v4
- id: step1
run: echo "version=$(python setup.py --version)" >> $GITHUB_OUTPUT

View File

@ -15,7 +15,7 @@ jobs:
group: itac-bm-emr-gaudi3-dell-2gaudi
container:
image: docker://vault.habana.ai/gaudi-docker/1.20.0/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest
image: docker://vault.habana.ai/gaudi-docker/1.21.1/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest
options: --runtime=habana --shm-size=64G --cap-add=sys_nice --env HABANA_VISIBLE_DEVICES
env:
OMPI_MCA_btl_vader_single_copy_mechanism: none
@ -66,16 +66,21 @@ jobs:
run: |
make test_big_modeling
- name: Run FSDP integration tests
if: ${{ !cancelled() && (success() || failure()) }}
run: |
make test_fsdp
- name: Run DeepSpeed integration tests
if: ${{ !cancelled() && (success() || failure()) }}
run: |
make test_deepspeed
- name: Run FSDP integration tests
if: ${{ !cancelled() && (success() || failure()) }}
run: |
make test_fsdp
- name: Run TP integration tests
if: ${{ !cancelled() && (success() || failure()) }}
run: |
make test_tp
- name: Run Examples tests
if: ${{ !cancelled() && (success() || failure()) }}
run: |

View File

@ -112,7 +112,7 @@ jobs:
cd skorch;
git config --global --add safe.directory '*'
git checkout master && git pull
pip install .[testing]
pip install .[test]
pip install flaky
- name: Show installed libraries

View File

@ -23,16 +23,23 @@ style:
doc-builder style src/accelerate docs/source --max_len 119
# Run tests for the library
test_big_modeling:
python -m pytest -s -v ./tests/test_big_modeling.py ./tests/test_modeling_utils.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_big_modeling.log",)
test_core:
python -m pytest -s -v ./tests/ --ignore=./tests/test_examples.py --ignore=./tests/deepspeed --ignore=./tests/test_big_modeling.py \
--ignore=./tests/fsdp --ignore=./tests/tp --ignore=./tests/test_cli.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_core.log",)
python -m pytest -s -v ./tests/ \
--ignore=./tests/test_big_modeling.py \
--ignore=./tests/test_modeling_utils.py \
--ignore=./tests/test_examples.py \
--ignore=./tests/test_cli.py \
--ignore=./tests/deepspeed \
--ignore=./tests/fsdp \
--ignore=./tests/tp \
$(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_core.log",)
test_cli:
python -m pytest -s -v ./tests/test_cli.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_cli.log",)
test_big_modeling:
python -m pytest -s -v ./tests/test_big_modeling.py ./tests/test_modeling_utils.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_big_modeling.log",)
test_deepspeed:
python -m pytest -s -v ./tests/deepspeed $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_deepspeed.log",)
@ -57,7 +64,7 @@ test_examples:
# Broken down example tests for the CI runners
test_integrations:
python -m pytest -s -v ./tests/deepspeed ./tests/fsdp ./tests/tp $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_integrations.log",)
python -m pytest -s -v ./tests/fsdp ./tests/tp ./tests/deepspeed $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_integrations.log",)
test_example_differences:
python -m pytest -s -v ./tests/test_examples.py::ExampleDifferenceTests $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_example_diff.log",)

View File

@ -7,7 +7,7 @@ RUN pip install transformers evaluate datasets
RUN git clone https://github.com/huggingface/accelerate.git
RUN cd accelerate && \
pip install -e . && \
pip install -e .[deepspeed] && \
cd benchmarks/fp8
RUN /bin/bash

View File

@ -62,8 +62,8 @@
title: Amazon SageMaker
- local: usage_guides/mps
title: Apple M1 GPUs
- local: usage_guides/ipex
title: IPEX training with CPU
- local: usage_guides/intel_cpu
title: Intel CPU
- local: usage_guides/gaudi
title: Intel Gaudi
- local: usage_guides/compilation
@ -82,8 +82,6 @@
title: Accelerate's internal mechanism
- local: concept_guides/big_model_inference
title: Loading big models into memory
- local: concept_guides/context_parallel
title: Context parallelism
- local: concept_guides/performance
title: Comparing performance across distributed setups
- local: concept_guides/deferring_execution
@ -94,6 +92,8 @@
title: FSDP vs DeepSpeed
- local: concept_guides/fsdp1_vs_fsdp2
title: FSDP1 vs FSDP2
- local: concept_guides/context_parallelism
title: Context parallelism
- local: concept_guides/low_precision_training
title: Low precision training methods
- local: concept_guides/training_tpu

View File

@ -19,54 +19,70 @@ This guide will cover basics of using context parallelism in 🤗`accelerate`, f
## Why context parallelism?
With the advent of large language models, and recently reasoning models, the sequence length has been growing rapidly. This, combined with quadratic memory complexity of attention, has lead to a need for more efficient ways to train models with long sequences.
With the advent of large language models, and recently reasoning models, the sequence length has been growing rapidly. This, combined with quadratic memory complexity of attention, has led to a need for more efficient ways to train models with long sequences.
With sequence length of 128k, the memory requirement of the attention matrix is `128k * 128k * 2 bytes * num_heads = ~32 GB * num_heads` for `bf16` precision, given vanilla attention implementation. Granted, with usage of `flash attention` or `SDPA` which do not materialize these attention weights, this decreases drastically, but the growth in memory requirements is still considerable.
Context parallelism allows us to shard the inputs to the attention computation along the sequence dimension and compute the attention in parallel on multiple GPUs. With this, we can train models with long sequences, scaling potentially to 1M+ sequence length.
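A quick back-of-the-envelope check of that 32 GB figure (plain arithmetic, per attention head, before the `num_heads` factor):

```python
seq_len = 128 * 1024      # 128k tokens
bytes_per_element = 2     # bf16
attn_matrix_bytes = seq_len * seq_len * bytes_per_element
print(f"{attn_matrix_bytes / 2**30:.0f} GiB per head")  # ~32 GiB
```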
## How to use context parallelism?
As with any other feature in 🤗`accelerate`, enabling context parallelism is as simple as passing the corresponding flags to `accelerate launch`.
```diff
from accelerate.utils import ParallelismConfig, TorchContextParallelConfig
+ cp_config = TorchContextParallelConfig(
+ cp_comm_strategy="alltoall", # no need to use cp_config at all, if you want to use the default "allgather"
+ )
+ parallelism_config = ParallelismConfig(
+ cp_size=8,
+ cp_handler=cp_config, # or just cp_size=8, if you want to use the default "allgather"
+ )
accelerator = Accelerator(
...,
parallelism_config=parallelism_config,
)
```
As with any other feature in 🤗`accelerate`, you can enable context parallelism also by passing the corresponding flags to `accelerate launch`.
In this case, it's no different:
```bash
accelerate launch --context-parallel-size 8 --context-parallel-shard-rotation [allgather|alltoall] ...
accelerate launch --parallelism-config-cp-size 8 --parallelism-config-cp-comm-strategy [allgather|alltoall] ...
```
Context parallelism is tightly coupled (for now) with `FSDP2`, which you can learn more about in the [FSDP2 introduction](fsdp1_vs_fsdp2.md). Meaning, context parallelism is applied only if `FSDP2` is enabled.
You can also enable context parallelism programatically, by passing it in the `FullyShardedDataParallelPlugin` constructor:
> [!Tip]
> You can also set the `cp_size` and `cp_comm_strategy` in the `accelerate config` command, which will save them in your `accelerate` configuration file, so you don't have to pass them every time you launch your script.
```diff
from accelerate.utils import FullyShardedDataParallelPlugin
> [!Tip]
> Context parallelism is compatible with other parallelism strategies, such as data parallelism, tensor parallelism and FSDP2.
> You can simply combine them by setting your parallelism sizes to the desired values, e.g. `--parallelism-config-dp-size 8 --parallelism-config-tp-size 2 --parallelism-config-cp-size 8`. Or you can use the `ParallelismConfig` class to set them programmatically.
plugin = FullyShardedDataParallelPlugin(
...
fsdp_version=2,
+ cp_size=8,
+ cp_comm_strategy="allgather",
)
accelerator = Accelerator(fsdp_plugin=plugin)
```
> [!Warning]
> Context parallelism is tightly coupled with `FSDP2`, which you can learn more about in the [FSDP2 introduction](fsdp1_vs_fsdp2.md). Meaning, context parallelism only works if you use `FullyShardedDataParallelPlugin` or `--use-fsdp` with version set to 2 to your
> program. If no `FSDP2` is used, error will be raised.
After enabling context parallelism with the methods mentioned above, you can then apply it to your training loop. We provide a thin wrapper around [`torch.distributed.tensor.experimental.context_parallel`](https://docs.pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel) that you can use in your training loop, that abstracts some of the complexity of using it (more on this later).
> [!Warning]
> Context parallelism works only with [SDPA](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) and only with no mask or causal mask. We can't properly detect this for you, so it's your responsibility to ensure that you are using `SDPA` with no mask or causal mask. If you use any other attention implementation, it will raise an error.
After enabling context parallelism with the methods mentioned above, you can then apply it to your training loop. We provide a thin wrapper around [`torch.distributed.tensor.experimental.context_parallel`](https://docs.pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel) that you can use in your training loop, that abstracts some of the complexity of using it (more on this later). To minimize the changes you have to do in your training loop, we provide a context manager that is a `noop` if context parallelism is not enabled, and applies the context parallelism if it is enabled. This way, you can use it in your training loop without changing any code based on your parallelism configuration.
You can use it as follows:
```python
for batch in dataloader:
with accelerator.context_parallel(
with accelerator.maybe_context_parallel(
buffers=[batch["input_ids"], batch["attention_mask"]],
buffer_seq_dims=[1, 1],
no_restore_buffers={batch["input_ids"]},
no_restore_buffers={batch["input_ids"], batch["labels"]},
):
outputs = model(batch)
outputs = model(**batch)
...
```
> [!Warning]
> This context manager has to be recreated with each training step, as shown in the example above. It's crucial to do so.
This can scale your context size to 1M+ sequence length potentially. Below, we showcase speed and memory usage of context parallelism for up-to 256k context size. We can see that when we double the context size and number of GPUs, we can achieve consistent memory usage, potentiall enabling endless context length scaling.
This can scale your context size to 1M+ sequence length potentially. Below, we showcase speed and memory usage of context parallelism for up-to 256k context size. We can see that when we double the context size and number of GPUs, we can achieve consistent memory usage, potentially enabling endless context length scaling.
<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_perf.png" alt="context parallelism memory usage" />
@ -75,7 +91,10 @@ This can scale your context size to 1M+ sequence length potentially. Below, we s
</p>
> [!Tip]
> These examples were created with a script you can find [in the examples folder](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/fsdp2_context_parallel.py). For instructions on how to run it, see the [README](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/README.md) in the same folder.
> These examples were created with a script you can find [in the examples folder](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/nd_parallel.py). To run the example on 8 H100 GPUs (128k sequence length), you can use the following command:
> ```bash
> accelerate launch --use-fsdp --fsdp-activation-checkpointing=TRUE examples/fsdp2/nd_parallel.py --cp-size=8 --sequence-length=128000
> ```
## Accelerate's interface
@ -83,19 +102,32 @@ This can scale your context size to 1M+ sequence length potentially. Below, we s
The context manager takes a few arguments, that are used to configure the context parallelism.
- `buffers`: This is a list of tensors that are to be sharded across the sequence dimension. These tensors are usually input ids, labels and attention mask.
- `buffer_seq_dims`: This is a list of integers, that specify the sequence dimension of the buffers, in the order of the `buffers` list.
- `no_restore_buffers`: The implementation of context parallelism modifies the buffers in-place, converting them to `torch.distributed.tensor.Dtensor`s. After the context manager is exited, a communication kernel would need to be launched to restore the buffers to their original state (usually all-gather). This takes some time, so it is reccomended to pass the same arguments as to the `buffers` argument, to avoid unnecessary communication, unless you are sure that you need to use the buffers after the context manager is exited.
- `buffer_seq_dims`: This is a list of integers, that specify the sequence dimension of the buffers, in the order of the `buffers` list. If you pass `buffers=[input_ids, shift_labels]` with both having shape `[batch_size, sequence_length]`, you would pass `buffer_seq_dims=[1, 1]`.
as the sequence dimension is the second dimension of the tensors. This is required for correct computation of the model outputs.
- `no_restore_buffers`: The implementation of context parallelism modifies the buffers in-place, converting them to `torch.distributed.tensor.Dtensor`s. After the context manager exits, a communication kernel would need to be launched to restore the buffers to their original state (usually all-gather). This takes some time, so it is recommended to pass the same tensors as in the `buffers` argument, to avoid unnecessary communication, unless you are sure that you need to use the buffers after the context manager exits.
> [!Warning]
> Context parallelism is not compatible with `labels` that are a copy of `input_ids`, which models from 🤗 transformers can shift to enable causal language modeling themselves.
> Imagine this case:
> labels = [l1, l2, l3, l4, ... li]
> if we apply context parallelism, each rank would end up with a part of labels, such as this:
> labels_rank_0 = [l1, l2], labels_rank_1 = [l3, l4], ...
> after transformers modelling code shifts the labels, it would end up with:
> labels_rank_0 = [l2, PAD], labels_rank_1 = [l3, PAD], ...
> where `PAD` is a padding token. This would result in incorrect loss computation, as the labels are not aligned with the inputs anymore.
> Because of this, you need to manually shift the labels before passing them in the model
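A minimal sketch of such manual shifting, assuming `batch` holds `labels` of shape `[batch_size, seq_len]`; the helper name and the `-100` ignore index are assumptions, not the guide's API:

```python
import torch

def shift_labels_before_cp(batch: dict) -> dict:
    # Shift once on the full, unsharded sequence so that every context-parallel
    # shard keeps input/label pairs that are still aligned after sharding.
    labels = batch["labels"]
    shifted = labels.new_full(labels.shape, -100)  # -100 = ignore index for cross-entropy
    shifted[:, :-1] = labels[:, 1:]
    batch["labels"] = shifted
    return batch
```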
## Configurable options
Accelerate provides only a few options to configure context parallelism, which are:
Accelerate provides only a single option to configure context parallelism (except for `cp_size`)
- `cp_size`: The number of ranks to shard the inputs to the attention computation across the sequence dimension.
- `cp_comm_strategy`: The rotation method to use for the shards. We strongly reccomend keeping this as `"allgather"`, as it's very likely it will outperform `"alltoall"` in most cases.
- `cp_comm_strategy`: The rotation method to use for the shards. We strongly recommend keeping this as `"allgather"`, as it's very likely it will outperform `"alltoall"` in most cases.
Context parallel size is rather self-explanatory, it's the number of ranks across which the inputs are to be-sharded.
Context parallel shard rotation defines how the shards of the inputs are rotated across ranks. We'll cover the 2 options in more detail in the next section.
You can see an end-to-end example in the [FSDP2 context parallel example](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/fsdp2_context_parallel.py) file, where you can train an 8B model with 128k sequence length on 8x H100 SXM GPUs. Using multi-node training, you can scale this to 1M+ sequence length on 64x H100 SXM GPUs.
You can see an end-to-end example in the [ND parallel example](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/nd_parallel.py) file, where you can train an 8B model with up-to 128k context length on a single 8xH100 node. Using multi-node training, you can scale this to 1M+ sequence length on multiple GPUs. You can also seamlessly combine it with other parallelism strategies to fit your needs.
## Technical details
@ -110,7 +142,7 @@ We're going to be using word `shard` extensively in the following sections, so l
Context parallelism works on sharding the `Q, K and V` matrices across the sequence dimension. Each rank has its assigned shard of `Q`, let's call it `Q_i`. This matrix stays only on this rank, during the whole computation. Similarly, each rank has its own shard of `K` and `V`, let's call them `K_i` and `V_i`. Then, each rank calculates attention with its own shard of `Q_i`, `K_i` and `V_i`, let's call it `attn_i`. During this computation, a communication kernel is launched to gather the `Ks` and `Vs` from all other ranks. What communication primitive is used, depends on the `context_parallel_shard_rotation` option.
This way, each rank gets to calculate local attention, first with `Q_i`, `K_i` and `V_i`, then with `K_j` and `V_j` from all other ranks. As each rank holds `Q, K and V` matrices that are sharded across the sequence dimension, the resulting matrices are smaller and can fit on a single GPU.
We can formalize this in a following pseudocode:
We can formalize this in the following pseudocode:
```python
comm_kernel = {"allgather": allgather, "alltoall": alltoall}[context_parallel_shard_rotation]
Qi, Ki, Vi = shard(Q, K, V, seq_dim)
@ -132,7 +164,7 @@ In ideal scenario, all-gather finishes in the exact moment as the calculation of
All-to-all, or sometimes called `ring-rotation` utilizes a ring-like communication pattern. After concluding `attn_i` computation, an all-to-all is launched to send `K_i` and `V_i` to the neighbouring ranks. We then repeat this `context_parallel_size-1` times, so that each rank sees all the shards of `K` and `V` from all other ranks once. In ideal scenario, we prefetch shards `K_i+1` and `V_i+1` from the neighbouring rank and this communication is exactly overlapped with computation of our current `attn_i`. Again, realistically, this perfect overlap doesn't ever happen. Given the nature of this approach, if we don't achieve perfect overlap, the penalty is way larger than with all-gather.
## How to choose the right rotation method?
In theory, all-to-all should be the better choice. Though in practice, it rarely is. Therefore, we default to all-gather, as it's more likely to achieve better performance. Extensive [benchmarks](https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082) from the `torchtitan` team also shows that all-to-all rarely outperforms all-gather. Though, we still provide both options, as you might find one to be better for your use case.
In theory, all-to-all should be the better choice. Though in practice, it rarely is. Therefore, we default to all-gather, as it's more likely to achieve better performance. Extensive [benchmarks](https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082) from the `torchtitan` team also show that all-to-all rarely outperforms all-gather. Though, we still provide both options, as you might find one to be better for your use case.
You can directly see this issue in the profiler output in the image below:
<p align="center">
@ -144,8 +176,10 @@ You can directly see this issue in the profiler output in the image below:
## Why only FSDP2?
We only support context parallelism with `FSDP2` for now, as we create a joint mesh of `context_parallel_size` and `dp_shard_size` to
utilize its full potential. In the profiler output in the image below, you can see why this is the case.
We only support context parallelism with `FSDP2`, as we create a joint mesh of `context_parallel_size` and `dp_shard_size` to
utilize its full potential.
How it works is: we shard the model across the joint mesh of size `cp_size*dp_shard_size`, which maximizes the memory savings.
This is a "free lunch" of sorts, as `FSDP` communication is fully overlapped with the computation of attention, as shown in the images below.
<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_why_fsdp2.png" alt="why FSDP2+CP" />
@ -154,3 +188,17 @@ utilize its full potential. In the profiler output in the image below, you can s
</p>
In the figure above, you can also note the difference between all-to-all and all-gather. While in all-to-all (Figure 1), we launch a communication kernel N-1 times for each attention call, in all-gather (Figure 2), we launch a communication kernel only once. This results in a bigger bubble, but it only happens once per attention call, while in all-to-all, it happens N-1 times.
## Data dispatching in joint mesh
We make sure to dispatch the same batch of data to the whole `cp` subgroup, so that the results are correct. (Meaning each rank in `cp` subgroup gets the same batch of data.) However, we also dispatch different batches to each rank of `dp_shard` group.
Imagine it like this:
```
# 8 GPUS, --dp_shard_size 4, --cp_size 2
# mesh = [[0, 1], [2, 3], [4, 5], [6, 7]]
# model is sharded across the whole mesh (each GPU holds 1/8 of the model)
# GPUs 0,1 = batch 0
# GPUs 2,3 = batch 1
... and so on.
```
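A tiny sketch of that rank-to-batch mapping, given the row-major mesh layout shown above (illustrative helper, not Accelerate API):

```python
def batch_index_for_rank(global_rank: int, cp_size: int) -> int:
    # Ranks in the same context-parallel subgroup receive the same batch,
    # so consecutive groups of cp_size ranks map to one batch index.
    return global_rank // cp_size

# 8 GPUs, dp_shard_size=4, cp_size=2 -> ranks (0,1)->0, (2,3)->1, (4,5)->2, (6,7)->3
assert [batch_index_for_rank(r, 2) for r in range(8)] == [0, 0, 1, 1, 2, 2, 3, 3]
```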

View File

@ -139,7 +139,7 @@ values. They can also be passed in manually.
* `--cpu` (`bool`) -- Whether or not to force the training on the CPU.
* `--multi_gpu` (`bool`) -- Whether or not this should launch a distributed GPU training.
* `--tpu` (`bool`) -- Whether or not this should launch a TPU training.
* `--ipex` (`bool`) -- Whether or not this should launch an Intel Pytorch Extension (IPEX) training.
* `--ipex` (`bool`) -- Whether or not this should launch an Intel Pytorch Extension (IPEX) training. **This argument is deprecated, will be removed in Accelerate v1.10**
**Resource Selection Arguments**:
@ -158,7 +158,7 @@ The following arguments are useful for selecting which training paradigm to use.
* `--use_deepspeed` (`bool`) -- Whether or not to use DeepSpeed for training.
* `--use_fsdp` (`bool`) -- Whether or not to use FullyShardedDataParallel for training.
* `--use_megatron_lm` (`bool`) -- Whether or not to use Megatron-LM for training.
* `--use_xpu` (`bool`) -- Whether to use IPEX plugin to speed up training on XPU specifically. **This argument is deprecated and ignored, will be removed in Accelerate v1.20**
* `--use_xpu` (`bool`) -- Whether to use IPEX plugin to speed up training on XPU specifically. **This argument is deprecated and ignored, will be removed in Accelerate v1.10**
**Distributed GPU Arguments**:

View File

@ -29,6 +29,11 @@ rendered properly in your Markdown viewer.
[[autodoc]] tracking.WandBTracker
- __init__
## Trackio
[[autodoc]] tracking.TrackioTracker
- __init__
## CometMLTracker
[[autodoc]] tracking.CometMLTracker
@ -48,3 +53,8 @@ rendered properly in your Markdown viewer.
[[autodoc]] tracking.ClearMLTracker
- __init__
## SwanLabTracker
[[autodoc]] tracking.SwanLabTracker
- __init__

View File

@ -245,7 +245,7 @@ As was pointed out in this [blog-post](https://huggingface.co/blog/gradient_accu
> [...] for gradient accumulation across token-level tasks like causal LM training, the correct loss should be computed by the **total loss across all batches in a gradient accumulation step** divided by the **total number of all non padding tokens in those batches**. This is not the same as the average of the per-batch loss values.
In other words, some adjustements must be made on losses that operate on a token-level basis.
In other words, some adjustments must be made on losses that operate on a token-level basis.
### Skeleton code
@ -282,7 +282,7 @@ for update_step in range(total_updates):
num_items_in_batch = accelerator.gather(num_items_in_batch).sum().item()
for i, batch in enumerate(batch_samples):
# if we perform gradient accumulation in a multi-devices set-up, we want to avoid unecessary communications when accumulating
# if we perform gradient accumulation in a multi-devices set-up, we want to avoid unnecessary communications when accumulating
# cf: https://muellerzr.github.io/blog/gradient_accumulation.html
if (i < len(batch_samples) - 1 and accelerator.num_processes > 1):
ctx = model.no_sync
@ -294,7 +294,7 @@ for update_step in range(total_updates):
with ctx():
inputs, targets = batch
outputs = model(inputs)
loss = loss_function(outputs, targets) # the loss function shoud sum over samples rather than averaging
loss = loss_function(outputs, targets) # the loss function should sum over samples rather than averaging
# We multiply by num_processes because the DDP calculates the average gradient across all devices whereas dividing by num_items_in_batch already takes into account all devices
# Same reason for gradient_accumulation_steps, but this times it's Accelerate that calculate the average gradient across the accumulated steps
@ -394,7 +394,7 @@ for update_step in range(total_gradient_updates):
for i, batch in enumerate(batch_samples):
inputs, labels = batch["input_ids"], batch["labels"]
total_batched_samples += 1
# if we perform gradient accumulation in a multi-devices set-up, we want to avoid unecessary communications when accumulating
# if we perform gradient accumulation in a multi-devices set-up, we want to avoid unnecessary communications when accumulating
# cf: https://muellerzr.github.io/blog/gradient_accumulation.html
if (i < len(batch_samples) - 1 and accelerator.num_processes > 1):
ctx = model.no_sync

View File

@ -13,34 +13,11 @@ specific language governing permissions and limitations under the License.
rendered properly in your Markdown viewer.
-->
# Intel® Extension for PyTorch
[IPEX](https://github.com/intel/intel-extension-for-pytorch) is optimized for CPUs with AVX-512 or above, and functionally works for CPUs with only AVX2. So, it is expected to bring performance benefit for Intel CPU generations with AVX-512 or above while CPUs with only AVX2 (e.g., AMD CPUs or older Intel CPUs) might result in a better performance under IPEX, but not guaranteed. IPEX provides performance optimizations for CPU training with both Float32 and BFloat16. The usage of BFloat16 is the main focus of the following sections.
Low precision data type BFloat16 has been natively supported on the 3rd Generation Xeon® Scalable Processors (aka Cooper Lake) with AVX512 instruction set and will be supported on the next generation of Intel® Xeon® Scalable Processors with Intel® Advanced Matrix Extensions (Intel® AMX) instruction set with further boosted performance. The Auto Mixed Precision for CPU backend has been enabled since PyTorch-1.10. At the same time, the support of Auto Mixed Precision with BFloat16 for CPU and BFloat16 optimization of operators has been massively enabled in Intel® Extension for PyTorch, and partially upstreamed to PyTorch master branch. Users can get better performance and user experience with IPEX Auto Mixed Precision.
## IPEX installation:
IPEX release is following PyTorch, to install via pip:
| PyTorch Version | IPEX version |
| :---------------: | :----------: |
| 2.0 | 2.0.0 |
| 1.13 | 1.13.0 |
| 1.12 | 1.12.300 |
| 1.11 | 1.11.200 |
| 1.10 | 1.10.100 |
```
pip install intel_extension_for_pytorch==<version_name> -f https://developer.intel.com/ipex-whl-stable-cpu
```
Check more approaches for [IPEX installation](https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/installation.html).
# Training on Intel CPU
## How It Works For Training optimization in CPU
Accelerate has integrated [IPEX](https://github.com/intel/intel-extension-for-pytorch), all you need to do is enabling it through the config.
Accelerate has full support for Intel CPU, all you need to do is enabling it through the config.
**Scenario 1**: Acceleration of No distributed CPU training
@ -55,7 +32,6 @@ This machine
Which type of machine are you using?
No distributed training
Do you want to run your training on CPU only (even if a GPU / Apple Silicon device is available)? [yes/NO]:yes
Do you want to use Intel PyTorch Extension (IPEX) to speed up training on CPU? [yes/NO]:yes
Do you wish to optimize your script with torch dynamo?[yes/NO]:NO
Do you want to use DeepSpeed? [yes/NO]: NO
-----------------------------------------------------------------------------------------------------------------------------------------------------------
@ -69,15 +45,12 @@ default options when doing
accelerate launch my_script.py --args_to_my_script
```
For instance, here is how you would run the NLP example `examples/nlp_example.py` (from the root of the repo) with IPEX enabled.
default_config.yaml that is generated after `accelerate config`
For instance, here is how you would run the NLP example `examples/nlp_example.py` (from the root of the repo) with `default_config.yaml` which is generated by `accelerate config`
```bash
compute_environment: LOCAL_MACHINE
distributed_type: 'NO'
downcast_bf16: 'no'
ipex_config:
ipex: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
@ -117,7 +90,6 @@ What is the rank of this machine?
What is the IP address of the machine that will host the main process? 36.112.23.24
What is the port you will use to communicate with the main process? 29500
Are all the machines on the same local network? Answer `no` if nodes are on the cloud and/or on different network hosts [YES/no]: yes
Do you want to use Intel PyTorch Extension (IPEX) to speed up training on CPU? [yes/NO]:yes
Do you want accelerate to launch mpirun? [yes/NO]: yes
Please enter the path to the hostfile to use with mpirun [~/hostfile]: ~/hostfile
Enter the number of oneCCL worker threads [1]: 1
@ -129,13 +101,11 @@ bf16
```
For instance, here is how you would run the NLP example `examples/nlp_example.py` (from the root of the repo) with IPEX enabled for distributed CPU training.
default_config.yaml that is generated after `accelerate config`
`default_config.yaml` which is generated by `accelerate config`
```bash
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_CPU
downcast_bf16: 'no'
ipex_config:
ipex: true
machine_rank: 0
main_process_ip: 36.112.23.24
main_process_port: 29500
@ -156,8 +126,10 @@ use_cpu: true
Set following env and using intel MPI to launch the training
In node0, you need to create a configuration file which contains the IP addresses of each node (for example hostfile) and pass that configuration file path as an argument.
If you selected to have Accelerate launch `mpirun`, ensure that the location of your hostfile matches the path in the config.
In `node0`, you need to create a configuration file which contains the IP addresses of each node (for example hostfile) and pass that configuration file path as an argument.
If you selected to let Accelerate launch `mpirun`, ensure that the location of your hostfile matches the path in the config.
```bash
$ cat hostfile
xxx.xxx.xxx.xxx #node0 ip
@ -165,18 +137,18 @@ xxx.xxx.xxx.xxx #node1 ip
xxx.xxx.xxx.xxx #node2 ip
xxx.xxx.xxx.xxx #node3 ip
```
When Accelerate is launching `mpirun`, source the oneCCL bindings setvars.sh to get your Intel MPI environment, and then
run your script using `accelerate launch`. Note that the python script and environment needs to exist on all of the
machines being used for multi-CPU training.
Before executing `accelerate launch` command, you need source the oneCCL bindings `setvars.sh` to get your Intel MPI environment properly. Note that both the python script and environment need to be available on all of the machines being used for multi-CPU training.
```bash
oneccl_bindings_for_pytorch_path=$(python -c "from oneccl_bindings_for_pytorch import cwd; print(cwd)")
source $oneccl_bindings_for_pytorch_path/env/setvars.sh
accelerate launch examples/nlp_example.py
```
Otherwise, if you selected not to have Accelerate launch `mpirun`, run the following command in node0 and **16DDP** will
be enabled in node0,node1,node2,node3 with BF16 mixed precision. When using this method, the python script, python
environment, and accelerate config file need to be present on all of the machines used for multi-CPU training.
You can also directly launch distributed training with `mpirun` command, you need to run the following command in node0 and **16DDP** will be enabled in node0,node1,node2,node3 with BF16 mixed precision. When using this method, the python script, python environment, and accelerate config file need to be available on all of the machines used for multi-CPU training.
```bash
oneccl_bindings_for_pytorch_path=$(python -c "from oneccl_bindings_for_pytorch import cwd; print(cwd)")
source $oneccl_bindings_for_pytorch_path/env/setvars.sh
@ -185,11 +157,3 @@ export MASTER_ADDR=xxx.xxx.xxx.xxx #node0 ip
export CCL_ATL_TRANSPORT=ofi
mpirun -f hostfile -n 16 -ppn 4 accelerate launch examples/nlp_example.py
```
## Related Resources
- [Project's github](https://github.com/intel/intel-extension-for-pytorch)
- [API docs](https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/api_doc.html)
- [Tuning guide](https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html)
- [Blogs & Publications](https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/blogs_publications.html)

View File

@ -20,10 +20,11 @@ Accelerate provides a general tracking API that can be used to log useful items
## Integrated Trackers
Currently `Accelerate` supports seven trackers out-of-the-box:
Currently `Accelerate` supports eight trackers out-of-the-box:
- TensorBoard
- WandB
- WandB
- Trackio
- CometML
- Aim
- MLFlow

View File

@ -218,7 +218,7 @@ def parse_args():
default="all",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"`, `"comet_ml"`, and `"dvclive"`. Use `"all"` (default) to report to all integrations.'
' `"wandb"`, `"comet_ml"`, `"dvclive"`, and `"swanlab"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)

View File

@ -215,7 +215,7 @@ def parse_args():
default="all",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"`, `"comet_ml"`, and `"dvclive"`. Use `"all"` (default) to report to all integrations.'
' `"wandb"`, `"comet_ml"`, and `"dvclive"`, and `"swanlab"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)

View File

@ -31,8 +31,8 @@ from accelerate.utils import ProfileKwargs
#
# This example trains a Bert base model on GLUE MRPC
# in any of the following settings (with the same script):
# - single CPU or single GPU
# - multi GPUS (using PyTorch distributed mode)
# - single CPU or single device (CUDA GPU, Intel XPU etc.)
# - multi devices (using PyTorch distributed mode)
# - (multi) TPUs
# - fp16 (mixed-precision) or fp32 (normal precision)
#
@ -183,7 +183,8 @@ def training_function(config, args):
# New Code #
accelerator.print(
prof.key_averages().table(
sort_by="self_cpu_time_total" if args.cpu else "self_cuda_time_total", row_limit=-1
sort_by="self_cpu_time_total" if args.cpu else f"self_{accelerator.device.type}_time_total",
row_limit=-1,
)
)
@ -215,7 +216,7 @@ def main():
choices=["no", "fp16", "bf16", "fp8"],
help="Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU.",
"and an Nvidia Ampere GPU or an Intel XPU.",
)
# New Code #
parser.add_argument(

View File

@ -8,7 +8,7 @@ deepspeed_config:
# `transformers` uses the right `init` function
zero3_init_flag: false # true
# Finally we need to specify the number of GPUs to use
# Finally we need to specify the number of accelerators to use
num_processes: 2
# Optionally we can set the mixed precision now instead of in the deepspeed config file,
# however this requires the `fp16` and `bf16` options to be set to `auto` in the deepspeed config file

View File

@ -11,8 +11,8 @@ fp8_config:
fp8_format: E4M3
interval: 1
margin: 0
override_linear_precision: (false, false, false)
override_linear_precision: [false, false, false]
# Generally this should always be set to `false` to have the most realistic fp8 eval performance
use_autocast_during_eval: false
# If using MS-AMP, we ignore all of the prior and set a opt_level
#opt_level: O1
#opt_level: O1

View File

@ -1,8 +1,8 @@
# Since we are doing FSDP (even though it's multi-GPU), we need to specify the distributed type as FSDP
# Since we are doing FSDP (even though it's multi-accelerator), we need to specify the distributed type as FSDP
distributed_type: FSDP
# Can be one of "no", "fp16", or "bf16" (see `transformer_engine.yaml` for `fp8`, but it works for FSDP as well)
mixed_precision: 'bf16'
# Specify the number of GPUs to use
# Specify the number of accelerators to use
num_processes: 2
# Then we can specify the FSDP config
fsdp_config:

View File

@ -0,0 +1,6 @@
# Specify distributed_type as `MULTI_XPU` for DDP
distributed_type: "MULTI_XPU"
# Can be one of "no", "fp16", or "bf16" (see `transformer_engine.yaml` for `fp8`)
mixed_precision: "bf16"
# Specify the number of XPUs to use
num_processes: 2

View File

@ -1,4 +1,4 @@
# Since this is single GPU, we don't need distributed training
# Since this is single GPU/XPU, we don't need distributed training
distributed_type: "NO"
# Can be one of "no", "fp16", or "bf16" (see `transformer_engine.yaml` for `fp8`)
mixed_precision: "bf16"
mixed_precision: "bf16"

View File

@ -1,58 +0,0 @@
# FSDP2 Examples
This folder contains examples of using FSDP2 with Accelerate, utilizing extra methods to improve training speed, performance or accuracy.
## FSDP2 + ao Float8Linear (`fsdp2_fp8.py`)
In file `fsdp2_fp8.py` we use `Float8Linear` from `ao` to train a model partially in FP8 precision. We utilize `AORecipeKwargs` to pass the `Float8LinearConfig` to the accelerator,
which replaces the default `torch.nn.Linear` with `Float8Linear`. We also utilize `TorchDynamoPlugin` together with regional compilation to compile the model,
gaining even more speed and memory savings, as `ao` doesn't ship with any kernels by default, so we have to gain the performance from compiling the model.
Replacing linear layers with `Float8Linear` can greatly improve performance, if used correctly and on hardware that supports FP8 tensor cores. This highly depends on the model dimensions and sequence length used for training.
You can view the performance of `Float8Linear` as a function of matrix dimensions in [this document](https://github.com/pytorch/ao/blob/main/torchao/float8/README.md#performance).
In our example, we use an 8B Llama3.1 model, which has a hidden dimension of 4096, and we train on a sequence length of 8192. In the images below, we can see that this improves performance by ~25% compared to `bf16`, reaching ~10000 tokens per second per device on 8x H100 GPUs, compared to ~8000 tokens per second using `bf16`, while the loss stays roughly the same. We can also see that the FLOPS rise when using FP8.
<div style="display: flex; gap: 25px;">
<div style="text-align: center; width: 49%;">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/fp8_tps.png" alt="tps" style="width: 100%;">
<p style="text-align: center; margin-top: 8px;">TPs per device, bf16 vs fp8</p>
</div>
<div style="text-align: center; width: 49%;">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/fp8_tflops.png" alt="tflops" style="width: 100%;">
<p style="text-align: center; margin-top: 8px;">TFLOPS per device, bf16 vs fp8. We cannot really compare MFU as fp8 tensor cores are used as well.</p>
</div>
<div style="text-align: center; width: 49%;">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/fp8_loss.png" alt="loss" style="width: 100%; max-width: 900px;">
<p style="text-align: center; margin-top: 8px;">Loss curve, bf16 vs fp8, it's hard to see the difference as the curves mostly overlap</p>
</div>
</div>
The figures above were generated on 8x H100 SXM GPUs, with 8192 sequence length and 1000 steps. To run the example, you can use the following command, where you can specify the precision to train in:
```bash
accelerate launch --fsdp2_fp8.py --sequence_length 8192 --num_steps 1000 --log_with wandb --precision [fp8 | bf16]
```
## FSDP2 + context parallelism (`fsdp2_context_parallel.py`)
In this file, we showcase integration of context parallelism with FSDP2. Context parallelism is a technique that allows us to scale the training to sequence length of up to a million tokens. With `accelerator.context_parallel` context manager, we replace the attention implementation with a context parallel version, which enables us to train on a sequence length of up to 128k tokens on 8x H100 GPUs, with possibility of endless scaling if we have enough GPUs.
For a detailed explanation and more details, please refer to [this guide](https://huggingface.co/docs/accelerate/concept_guides/context_parallel). You can run the example with the following command:
```bash
accelerate launch --fsdp2_context_parallel.py --sequence_length 128000 --num_steps 1000 --log_with wandb --cp_size 8 --cp_comm_strategy allgather
```
More details about the context parallelism can be found in the [concept guide](https://huggingface.co/docs/accelerate/concept_guides/context_parallel). You can see some results below:
<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_perf.png" alt="context parallelism memory usage" />
<br>
<em>Figure 1: Memory usage and speed of context parallelism for up-to 256k context size.</em>
</p>

View File

@ -1,179 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example of training with Context Parallel using FSDP2 via Accelerate.
This example demonstrates how to use Accelerate's context_parallel feature for efficient long sequence training.
"""
import argparse
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin, set_seed
from utils import PerformanceTracker, create_collate_fn, get_dataset, setup_tokenizer
MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--sequence-length", type=int, default=128_000, help="Sequence length for the dataset")
parser.add_argument("--num-steps", type=int, default=100, help="Number of training steps")
parser.add_argument("--log-with", type=str, default="wandb", help="Logging service to use")
parser.add_argument("--cp-size", type=int, default=8, help="Context parallel size")
parser.add_argument("--cp-comm-strategy", type=str, default="allgather", help="Context parallel shard rotation")
return parser.parse_args()
def training_step(batch, model, optimizer, accelerator: Accelerator):
"""
Perform a single training step with context parallel.
Args:
batch: Input batch containing input_ids and labels
model: The model to train
optimizer: Optimizer
accelerator: Accelerator instance
Returns:
loss: Training loss
"""
# Use context parallel for efficient long sequence processing
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG)
return loss
def main():
set_seed(42)
args = parse_args()
# Configure FSDP2 plugin
fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy="transformer_based_wrap",
transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
cpu_ram_efficient_loading=True,
activation_checkpointing=True,
fsdp_version=2,
cp_size=args.cp_size,
cp_comm_strategy=args.cp_comm_strategy,
)
# Initialize accelerator
accelerator = Accelerator(
log_with=args.log_with,
fsdp_plugin=fsdp_plugin,
mixed_precision="bf16",
)
accelerator.init_trackers(
project_name="FSDP2_context_parallel",
config={
"sequence_length": args.sequence_length,
"num_steps": args.num_steps,
"cp_size": args.cp_size,
"cp_comm_strategy": args.cp_comm_strategy,
},
)
# Prepare model and optimizer
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16,
use_cache=False,
)
tokenizer = setup_tokenizer(MODEL_ID)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
model, optimizer = accelerator.prepare(model, optimizer)
accelerator.print("Preparing dataset... this might take a while")
dataset = get_dataset(
accelerator,
tokenizer,
args.sequence_length,
processing_batch_size=args.sequence_length
// 20, # we need to override the default processing batch size to avoid empty packed sequences
)
dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())
dataloader = accelerator.prepare(dataloader)
model.train()
total_num_steps = min(args.num_steps, len(dataloader))
performance_tracker = PerformanceTracker(warmup_steps=10)
accelerator.print(f"Starting training with context parallel for {total_num_steps} steps...")
accelerator.print(f"Sequence length: {args.sequence_length}")
accelerator.print("Warming up for 10 steps...")
accelerator.print(
"Each step takes ~10 seconds with default settings on 8x H100 SXM GPUs, seeing logs takes a while"
)
for step, batch in enumerate(dataloader):
print(f"Step {step}")
if step >= total_num_steps:
break
# get number of tokens before context_parallel shards the batch
batch_tokens = batch["input_ids"].shape[0] * batch["input_ids"].shape[1]
loss = training_step(batch, model, optimizer, accelerator)
# each batch gets the same data, we divide by the number of processes to get the number of tokens per process
metrics = performance_tracker.step(batch_tokens // accelerator.num_processes)
log_metrics = {"loss": loss.item()}
if "warmup_completed" in metrics:
accelerator.print("Warmup completed! Starting performance tracking...")
elif metrics:
log_metrics.update(
{
"tokens_per_second": int(metrics["tokens_per_second"]),
"steps_per_second": metrics["steps_per_second"],
}
)
if (step % 10 == 0 or step == total_num_steps - 1) and metrics:
accelerator.print(
f"Step {step}/{total_num_steps} | "
f"Loss: {loss.item():.4f} | "
f"Tokens/s: {int(metrics['tokens_per_second'])} | "
f"Steps/s: {metrics['steps_per_second']:.2f} | "
)
accelerator.log(log_metrics)
accelerator.wait_for_everyone()
accelerator.end_training()
accelerator.print("Training completed!")
if __name__ == "__main__":
main()

View File

@ -177,6 +177,7 @@ def training_function(config, args):
outputs = model(**batch)
predictions = outputs.logits.argmax(dim=-1)
predictions, references = accelerator.gather_for_metrics((predictions, batch["labels"]))
print(f"===== {predictions}")
metric.add_batch(
predictions=predictions,
references=references,

View File

@ -1,5 +1,5 @@
accelerate # used to be installed in Amazon SageMaker environment
evaluate
datasets==2.3.2
datasets
schedulefree
huggingface_hub>=0.20.0

View File

@ -0,0 +1,77 @@
## Torch Native Parallelism
With recent versions of Torch, there have been steady improvements in native parallelism using `DeviceMesh` and `DTensor`. 🤗 accelerate allows you to use these through our `ParallelismConfig` abstraction and/or `FullyShardedDataParallelPlugin(fsdp_version=2)`.
This folder contains various examples of such use cases, such as composing multiple parallelism strategies, low-bit training, etc.
### ND Parallelism
With `ParallelismConfig`, you can use 🤗 accelerate to train models with n-dimensional parallelism. This builds on top of 🤗 transformers, which we utilize for tensor parallelism sharding.
Accelerate then takes care of everything else, such as data parallelism, FSDP or context parallelism.
Script `nd_parallel.py` showcases this. We enable you to configure 4 different parallel dimensions (for now 👀):
- dp_replicate_size: how many replicas of the model to create, each replica is trained on a different subset of the data and averaged at the end of each step, same as DDP in Torch
- dp_shard_size: across how many devices is the model sharded, this is utilizing FSDP2 to shard the model across devices, so each device has a different part of the model
- tp_size: how many devices to use for tensor parallelism, this is utilizing the tensor parallelism from 🤗 transformers
- cp_size: how many devices to use for context parallelism, this will also shard the model, optimizer and gradients using `FSDP2` across
the same group of devices, to further optimize memory usage (this comes with no slowdown)
For example, with 8 GPUs, you can run the script as follows:
```bash
accelerate launch --num-processes 8 nd_parallel.py \
--dp-replicate-size 2 \
--dp-shard-size 2 \
    --tp-size 2
```
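The same 2x2x2 layout can also be built programmatically; below is a minimal sketch mirroring the `ParallelismConfig` usage in `nd_parallel.py` (the FSDP wrap class is an assumption for a Llama-style model):
```python
from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import FullyShardedDataParallelPlugin

# 2 replicas x 2 FSDP shards x 2 TP ranks = 8 processes, matching the launch command above
pc = ParallelismConfig(dp_replicate_size=2, dp_shard_size=2, tp_size=2)

# An FSDP2 plugin is needed whenever dp_shard_size (or cp_size) is greater than 1
fsdp_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    auto_wrap_policy="transformer_based_wrap",
    transformer_cls_names_to_wrap=["LlamaDecoderLayer"],  # assumed wrap class for a Llama-style model
)

accelerator = Accelerator(
    mixed_precision="bf16",
    parallelism_config=pc,
    fsdp_plugin=fsdp_plugin,
)
```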
<Tip>
Only use TP intra-node; the maximum TP size you should need is therefore 8. You can also use a lower size, as FSDP (`--dp-shard-size`) can be faster on smaller models with
shorter sequence lengths. If you cannot fit your model into memory, utilize `--dp-shard-size` as much as you can. Afterwards, to scale up and utilize all your resources, use `--dp-replicate-size`. This is only a general guideline; you can (and should) experiment with different parallelism configurations to find the best one for your model and hardware. You can learn more about the general strategies for parallelism in our [blog](https://huggingface.co/blog/accelerate-nd-parallel), or, if you really want to dive deep, read the [Ultra-Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook).
</Tip>
This feature is also fully integrated into 🤗 transformers `Trainer`. To use it, simply launch your script with the path to your accelerate configuration file. You can see a minimal example of such a script in `nd_parallel_trainer.py`.
We provide two pre-configured configuration files:
#### HSDP + TP (3D parallelism)
```bash
accelerate launch --config-file configs/tp_hsdp.yaml nd_parallel_trainer.py
```
#### Context parallelism (128k sequence length)
```bash
accelerate launch --config-file configs/cp.yaml nd_parallel_trainer.py --sequence-length=128000
```
### FSDP2 + ao Float8Linear
In file `fsdp2_fp8.py` we use `Float8Linear` from `ao` to train a model partially in FP8 precision. We utilize `AORecipeKwargs` to pass the `Float8LinearConfig` to the accelerator,
which replaces the default `torch.nn.Linear` with `Float8Linear`. We also utilize `TorchDynamoPlugin` together with regional compilation to compile the model,
gaining even more speed and memory savings, as `ao` doesn't ship with any kernels by default, so we have to gain the performance from compiling the model.
Replacing linear layers with `Float8Linear` can greatly improve performance, if used correctly and on hardware that supports FP8 tensor cores. This highly depends on the model dimensions and sequence length used for training.
You can view the performance of `Float8Linear` as a function of matrix dimensions in [this document](https://github.com/pytorch/ao/blob/main/torchao/float8/README.md#performance).
In our example, we use an 8B Llama3.1 model, which has a hidden dimension of 4096, and we train on a sequence length of 8192. In the images below, we can see that this improves performance by ~25% compared to `bf16`, reaching ~10000 tokens per second per device on 8x H100 GPUs, compared to ~8000 tokens per second using `bf16`, while the loss stays roughly the same. We can also see that the FLOPS rise when using FP8.
<div style="display: flex; gap: 25px;">
<div style="text-align: center; width: 49%;">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/fp8_tps.png" alt="tps" style="width: 100%;">
<p style="text-align: center; margin-top: 8px;">TPS per device, BF16 vs FP8</p>
</div>
<div style="text-align: center; width: 49%;">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/fp8_tflops.png" alt="tflops" style="width: 100%;">
<p style="text-align: center; margin-top: 8px;">TFLOPS per device, BF16 vs FP8. We cannot really compare MFU as FP8 tensor cores are used as well.</p>
</div>
<div style="text-align: center; width: 49%;">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/fp8_loss.png" alt="loss" style="width: 100%; max-width: 900px;">
<p style="text-align: center; margin-top: 8px;">Loss curve, BF16 vs FP8; the difference is hard to see as the curves mostly overlap</p>
</div>
</div>
The figures above were generated on 8x H100 SXM GPUs, with 8192 sequence length and 1000 steps. To run the example, you can use the following command, where you can specify the precision to train in:
```bash
accelerate launch fsdp2_fp8.py --sequence-length 8192 --num-steps 1000 --log_with wandb --precision [fp8 | bf16]
```
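A rough sketch of how these pieces fit together when constructing the accelerator (this assumes `AORecipeKwargs` receives the `Float8LinearConfig` through its `config` argument, as in `fsdp2_fp8.py`; the dynamo backend shown is illustrative):
```python
from torchao.float8 import Float8LinearConfig

from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs, TorchDynamoPlugin

# FP8 recipe: ask the ao recipe handler to swap nn.Linear layers for Float8Linear
fp8_config = Float8LinearConfig(enable_fsdp_float8_all_gather=True)
kwargs_handlers = [AORecipeKwargs(config=fp8_config)]  # assumed kwarg name

# ao ships no kernels of its own, so torch.compile is where the speedup comes from
dynamo_plugin = TorchDynamoPlugin(backend="inductor")

accelerator = Accelerator(
    mixed_precision="bf16",
    dynamo_plugin=dynamo_plugin,
    kwargs_handlers=kwargs_handlers,
)
```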

View File

@ -0,0 +1,29 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
fsdp_activation_checkpointing: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: false
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
parallelism_config:
parallelism_config_cp_size: 8
parallelism_config_dp_replicate_size: 1
parallelism_config_dp_shard_size: 1
parallelism_config_tp_size: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -0,0 +1,29 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
fsdp_activation_checkpointing: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: false
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
parallelism_config:
parallelism_config_cp_size: 1
parallelism_config_dp_replicate_size: 2
parallelism_config_dp_shard_size: 2
parallelism_config_tp_size: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -22,13 +22,15 @@ import argparse
import torch
from torch.utils.data import DataLoader
from torchao.float8 import Float8LinearConfig
from transformers import AutoConfig, AutoModelForCausalLM
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs, FullyShardedDataParallelPlugin, TorchDynamoPlugin, set_seed
from utils import PerformanceTracker, create_collate_fn, get_dataset, get_model_flops_per_token, setup_tokenizer
from utils import PerformanceTracker, create_collate_fn, get_dataset, get_model_flops_per_token
WARMUP_STEPS = 10
MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
@ -66,7 +68,6 @@ def main():
fp8_config = Float8LinearConfig(
enable_fsdp_float8_all_gather=True, # extra saving by gathering parameters in fp8 and upcasting after
force_recompute_fp8_weight_in_bwd=True,
)
kwargs = []
@ -89,24 +90,22 @@ def main():
torch_dtype=torch.bfloat16,
)
tokenizer = setup_tokenizer(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
model, optimizer = accelerator.prepare(model, optimizer)
dataset = get_dataset(accelerator, tokenizer, args.sequence_length)
dataset = get_dataset(tokenizer, args.sequence_length, accelerator)
dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())
dataloader = accelerator.prepare(dataloader)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
accelerator.wait_for_everyone()
model.train()
total_num_steps = min(args.num_steps, len(dataloader))
model_flops_per_token = get_model_flops_per_token(model, args.sequence_length)
performance_tracker = PerformanceTracker(warmup_steps=10)
accelerator.print(f"Starting training with {args.precision} precision for {total_num_steps} steps...")
accelerator.print(f"Sequence length: {args.sequence_length}")
accelerator.print("Warming up for 10 steps...")
performance_tracker = PerformanceTracker(warmup_steps=5)
for step, batch in enumerate(dataloader):
if step >= total_num_steps:
@ -118,35 +117,18 @@ def main():
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
batch_tokens = batch["input_ids"].shape[1]
metrics = performance_tracker.step(batch_tokens)
metrics = performance_tracker.step(batch["input_ids"].shape[1], model_flops_per_token)
print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
log_metrics = {"loss": loss.item()}
if "warmup_completed" in metrics:
accelerator.print("Warm up completed! Starting performance tracking...")
accelerator.print("Warm up completed! Starting training")
elif metrics:
tps = metrics["tokens_per_second"]
tflops = metrics["total_tokens"] * model_flops_per_token / (metrics["total_time"] * 1e12)
# it's rather hard to get a good estimate of MFU as we train with FP8, so both FP8 and BF16 tensor cores are used, therefore we just report TFLOPS (Tera floating point operations per second)
# Given H100 SXM, the theoretical peak flops are ~990 TFLOPS for bf16 and ~1980 TFLOPS for fp8 [https://resources.nvidia.com/en-us-gpu-resources/h100-datasheet-24306]
# This is WITH sparsity, so we divide by 2 to get the answer w/o sparsity
print_msg += f" | Average steps/s: {metrics['steps_per_second']:.2f} | TPS per device: {tps:.2f} | TFLOPS per device: {tflops:.2f}"
log_metrics.update(
{
"steps_per_second": metrics["steps_per_second"],
"tps_per_device": tps,
"tflops_per_device": tflops,
}
)
print_msg += performance_tracker.get_print_message(metrics)
if step % 10 == 0 or step == total_num_steps - 1:
accelerator.print(print_msg)
accelerator.log(log_metrics)
accelerator.log(metrics)
accelerator.wait_for_everyone()
accelerator.end_training()

View File

@ -0,0 +1,173 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example of training with ND parallel using accelerate's ParallelismConfig
"""
import argparse
import warnings
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM
from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import FullyShardedDataParallelPlugin, set_seed
from utils import (
PerformanceTracker,
create_collate_fn,
get_dataset,
get_model_flops_per_token,
setup_tokenizer,
)
MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--dp-replicate-size", type=int, default=1)
parser.add_argument("--dp-shard-size", type=int, default=1)
parser.add_argument("--tp-size", type=int, default=1)
parser.add_argument("--cp-size", type=int, default=1)
parser.add_argument("--sequence-length", type=int, default=1024)
parser.add_argument("--num-steps", type=int, default=1000)
parser.add_argument("--save-dir", type=str, default="./outputs")
parser.add_argument("--checkpoint-frequency", type=int, default=100)
parser.add_argument("--model-name", type=str, default=MODEL_ID)
return parser.parse_args()
def forward(model, batch, optimizer, accelerator: Accelerator):
batch["position_ids"] = torch.arange(0, batch["input_ids"].size(1), device=batch["input_ids"].device).unsqueeze(0)
# We need both labels and shift_labels, as the loss computation in the model is hidden behind `if labels is not None`, but the loss computation
# itself prioritizes shift_labels (if provided), which are the correct ones (the labels are wrong when CP is enabled)
buffers = [batch["input_ids"], batch["shift_labels"], batch["labels"], batch["position_ids"]]
with accelerator.maybe_context_parallel(
buffers=buffers, buffer_seq_dims=[1, 1, 1, 1], no_restore_buffers=set(buffers)
):
# To get the proper loss value, we need to average across devices that are participating in data parallel/context parallel training
# As for DP we have a different batch on each device and for CP we essentially have a different part of sequences on each device
# I.e. with causal modelling and seq_len 1024, this dimension becomes another batch dimension of sorts
loss_reduce_grp = (
accelerator.torch_device_mesh["dp_cp"].get_group()
if accelerator.parallelism_config.dp_cp_dim_names
else None
)
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=loss_reduce_grp)
return loss
def train(args):
parallelism_config = ParallelismConfig(
dp_replicate_size=args.dp_replicate_size,
dp_shard_size=args.dp_shard_size,
tp_size=args.tp_size,
cp_size=args.cp_size,
)
# FSDP needs extra configuration, so we properly shard the model
fsdp2_plugin = None
if parallelism_config.dp_shard_enabled or parallelism_config.cp_enabled:
fsdp2_plugin = FullyShardedDataParallelPlugin(
fsdp_version=2,
auto_wrap_policy="transformer_based_wrap",
transformer_cls_names_to_wrap=["Qwen3DecoderLayer"],
state_dict_type="SHARDED_STATE_DICT",
)
accelerator = Accelerator(
log_with=["wandb"], mixed_precision="bf16", parallelism_config=parallelism_config, fsdp_plugin=fsdp2_plugin
)
accelerator.init_trackers("nd_parallel_training")
# If TP was enabled, we need to tell transformers to prepare the model for us
model_kwargs = (
{"tp_size": args.tp_size, "tp_plan": "auto", "device_mesh": accelerator.torch_device_mesh}
if args.tp_size > 1
else {}
)
model = AutoModelForCausalLM.from_pretrained(
args.model_name,
torch_dtype=torch.bfloat16,
use_cache=False,
**model_kwargs,
)
tokenizer = setup_tokenizer(args.model_name)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)
dataset = get_dataset(tokenizer, args.sequence_length, accelerator)
dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
total_num_steps = min(args.num_steps, len(dataloader))
model_flops_per_token = get_model_flops_per_token(model, args.sequence_length)
performance_tracker = PerformanceTracker(warmup_steps=5)
accelerator.print("Starting training...")
for step, batch in enumerate(dataloader):
if step >= total_num_steps:
break
loss = forward(model, batch, optimizer, accelerator)
# We report TPS per device, so we divide by the number of devices in the non-data parallel dimension
metrics = performance_tracker.step(
batch["input_ids"].shape[1] / parallelism_config.non_data_parallel_size,
model_flops_per_token=model_flops_per_token,
)
print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
if "warmup_completed" in metrics:
accelerator.print("Warm up completed! Starting performance tracking...")
elif metrics:
print_msg += performance_tracker.get_print_message(metrics, with_memory=True)
if step % 10 == 0 or step == total_num_steps - 1:
accelerator.print(print_msg)
if step % args.checkpoint_frequency == 0 and step > 0 and parallelism_config.dp_shard_enabled:
accelerator.print(f"Saving checkpoint at step {step}...")
accelerator.save_state(args.save_dir + f"/checkpoint-{step}")
accelerator.log({"loss": loss.item()})
accelerator.print("Training completed!")
model.save_pretrained(args.save_dir + f"/{args.model_name}")
accelerator.print(f"Model saved to {args.save_dir}/{args.model_name}")
accelerator.end_training()
if __name__ == "__main__":
set_seed(42)
args = parse_args()
if args.dp_shard_size == 1 and args.tp_size > 1:
# We currently don't support saving with `save_state` when using only
# tensor parallelism, fsdp must be enabled
warnings.warn(
"Accelerator.save_state() is not yet supported with pure tensor parallel training. Training will work, but intermediate checkpoints will not be saved."
)
train(args)

View File

@ -0,0 +1,82 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from accelerate.utils import ParallelismConfig
from utils import get_dataset
MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--sequence-length", type=int, default=4096)
parser.add_argument("--checkpoint-frequency", type=int, default=100)
parser.add_argument("--model-name", type=str, default=MODEL_ID)
parser.add_argument("--save-dir", type=str, default=f"./accelerate-nd-parallel-{MODEL_ID.split('/')[-1]}")
parser.add_argument("--device-type", type=str, default="auto")
return parser.parse_args()
def main():
# If ParallelismConfig is not initialized with __init__ arguments, it reads from env vars
# which were set by using the accelerate config file
pc = ParallelismConfig()
args = parse_args()
if args.device_type == "auto":
args.device_type = torch.accelerator.current_accelerator().type
model_kwargs = {}
if pc.tp_enabled:
model_kwargs["tp_plan"] = "auto"
model_kwargs["device_mesh"] = pc.build_device_mesh(args.device_type)
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
model = AutoModelForCausalLM.from_pretrained(args.model_name, use_cache=False, **model_kwargs)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
packed_dataset = get_dataset(tokenizer, args.sequence_length)
training_args = TrainingArguments(
output_dir=args.save_dir,
parallelism_config=pc,
num_train_epochs=1,
per_device_train_batch_size=1,
logging_steps=5,
save_steps=args.checkpoint_frequency,
learning_rate=5e-5,
remove_unused_columns=False,
bf16=True,
)
trainer = Trainer(
model=model,
args=training_args,
processing_class=tokenizer,
train_dataset=packed_dataset,
)
trainer.train()
trainer.save_model()
if __name__ == "__main__":
main()

View File

@ -13,10 +13,11 @@
# limitations under the License.
"""
Common utilities for FSDP2 examples.
Common utilities for torch-native-parallelism examples.
"""
import time
from contextlib import nullcontext
import torch
from datasets import Dataset, load_dataset
@ -25,12 +26,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator
def get_dataset(
accelerator: Accelerator,
tokenizer: AutoTokenizer,
seq_len: int,
processing_batch_size: int = 1000,
) -> Dataset:
def get_dataset(tokenizer: AutoTokenizer, seq_len: int, accelerator: Accelerator | None = None) -> Dataset:
"""
Load and prepare TinyStories dataset.
@ -38,11 +34,11 @@ def get_dataset(
accelerator (Accelerator): Accelerate accelerator instance
tokenizer (AutoTokenizer): Hugging Face tokenizer
seq_len (int): Sequence length for the dataset
processing_batch_size (int): Batch size for processing the dataset
Returns:
Dataset: Packed dataset
"""
processing_ctx = accelerator.main_process_first if accelerator else nullcontext
raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:50%]")
def tokenize_function(examples):
@ -56,10 +52,8 @@ def get_dataset(
tokenized_batch["labels"] = tokenized_batch["input_ids"].copy()
return tokenized_batch
with accelerator.main_process_first():
tokenized_dataset = raw_dataset.map(
tokenize_function, batched=True, remove_columns=["text"], batch_size=processing_batch_size
)
with processing_ctx():
tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
def create_packed_sequences(examples):
all_tokens = []
@ -69,6 +63,7 @@ def get_dataset(
num_sequences = len(all_tokens) // (seq_len + 1)
packed_input_ids = []
packed_labels = []
packed_position_ids = []
for i in range(num_sequences):
start_idx = i * (seq_len + 1)
@ -76,15 +71,21 @@ def get_dataset(
full_sequence = all_tokens[start_idx:end_idx]
packed_input_ids.append(full_sequence[:-1])
packed_labels.append(full_sequence[1:])
packed_position_ids.append(torch.arange(0, seq_len))
return {"input_ids": packed_input_ids, "labels": packed_labels}
return {
"input_ids": packed_input_ids,
"shift_labels": packed_labels,
"position_ids": packed_position_ids,
"labels": packed_labels,
}
with accelerator.main_process_first():
with processing_ctx():
packed_dataset = tokenized_dataset.map(
create_packed_sequences,
batched=True,
remove_columns=tokenized_dataset.column_names,
batch_size=processing_batch_size,
batch_size=1000,
)
return packed_dataset.shuffle(seed=42)
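As a toy illustration of the packing and label shifting performed above (token values are made up):
```python
# Toy version of `create_packed_sequences`: concatenate tokens, cut into
# (seq_len + 1)-token windows, then shift to build next-token targets.
seq_len = 4
all_tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]  # concatenation of tokenized examples

num_sequences = len(all_tokens) // (seq_len + 1)  # -> 2 full windows
packed = []
for i in range(num_sequences):
    window = all_tokens[i * (seq_len + 1):(i + 1) * (seq_len + 1)]
    packed.append({
        "input_ids": window[:-1],    # e.g. [1, 2, 3, 4] for the first window
        "shift_labels": window[1:],  # e.g. [2, 3, 4, 5] -> next-token targets
    })

assert packed[0]["input_ids"] == [1, 2, 3, 4]
assert packed[0]["shift_labels"] == [2, 3, 4, 5]
```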
@ -119,8 +120,8 @@ def create_collate_fn():
def collate_fn(batch):
input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
return {"input_ids": input_ids, "labels": labels}
shift_labels = torch.tensor([item["shift_labels"] for item in batch], dtype=torch.long)
return {"input_ids": input_ids, "shift_labels": shift_labels, "labels": shift_labels}
return collate_fn
@ -139,7 +140,7 @@ class PerformanceTracker:
self.is_in_warmup = True
self.step_count = 0
def step(self, batch_tokens: int) -> dict:
def step(self, batch_tokens: int, model_flops_per_token: float | None = None) -> dict:
"""
Update performance tracking with a new step.
@ -158,20 +159,43 @@ class PerformanceTracker:
return {"warmup_completed": True}
if not self.is_in_warmup and self.start_time is not None:
dct = {}
self.num_tokens += batch_tokens
total_time = time.perf_counter() - self.start_time
steps_from_warmup = self.step_count - self.warmup_steps
if total_time > 0 and steps_from_warmup > 0:
return {
memory_stats = gpu_memory_usage_all()
dct = {
"tokens_per_second": self.num_tokens / total_time,
"steps_per_second": steps_from_warmup / total_time,
"total_tokens": self.num_tokens,
"total_time": total_time,
**memory_stats,
}
if model_flops_per_token is not None:
flops = model_flops_per_token * self.num_tokens
dct["tflops_per_device"] = flops / (total_time * 1e12)
return dct
return {}
def get_print_message(self, metrics: dict, with_memory: bool = False) -> str:
print_msg = f" | Average steps/s: {metrics['steps_per_second']:.2f} | Average tokens/s: {metrics['tokens_per_second']:.2f}"
if "tflops_per_device" in metrics:
print_msg += f" | Average TFLOPS: {metrics['tflops_per_device']:.2f}\n"
else:
print_msg += "\n"
if with_memory:
print_msg += (
f"\tMemory (GB): active={metrics['peak_memory_active']:.1f}, "
f"alloc={metrics['peak_memory_alloc']:.1f}, "
f"reserved={metrics['peak_memory_reserved']:.1f}"
)
return print_msg
def setup_tokenizer(model_id: str) -> AutoTokenizer:
"""Setup tokenizer with proper padding token."""
@ -179,3 +203,21 @@ def setup_tokenizer(model_id: str) -> AutoTokenizer:
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
def gpu_memory_usage_all(device=0):
device_type = torch.accelerator.current_accelerator().type
device = torch.device(f"{device_type}:{device}")
torch_device_module = getattr(torch, device_type, torch.cuda)
_BYTES_IN_GIB = 1024**3
peak_memory_active = torch_device_module.memory_stats().get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
peak_memory_alloc = torch_device_module.max_memory_allocated(device) / _BYTES_IN_GIB
peak_memory_reserved = torch_device_module.max_memory_reserved(device) / _BYTES_IN_GIB
memory_stats = {
"peak_memory_active": peak_memory_active,
"peak_memory_alloc": peak_memory_alloc,
"peak_memory_reserved": peak_memory_reserved,
}
torch_device_module.reset_peak_memory_stats(device)
return memory_stats

View File

@ -41,7 +41,16 @@ extras["deepspeed"] = ["deepspeed"]
extras["rich"] = ["rich"]
extras["test_fp8"] = ["torchao"] # note: TE for now needs to be done via pulling down the docker image directly
extras["test_trackers"] = ["wandb", "comet-ml", "tensorboard", "dvclive", "mlflow", "matplotlib"]
extras["test_trackers"] = [
"wandb",
"comet-ml",
"tensorboard",
"dvclive",
"mlflow",
"matplotlib",
"swanlab",
"trackio",
]
extras["dev"] = extras["quality"] + extras["testing"] + extras["rich"]
extras["sagemaker"] = [
@ -50,7 +59,7 @@ extras["sagemaker"] = [
setup(
name="accelerate",
version="1.8.0.dev0",
version="1.11.0.dev0",
description="Accelerate",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "1.8.0.dev0"
__version__ = "1.11.0.dev0"
from .accelerator import Accelerator
from .big_modeling import (
@ -26,6 +26,7 @@ from .big_modeling import (
from .data_loader import skip_first_batches
from .inference import prepare_pippy
from .launchers import debug_launcher, notebook_launcher
from .parallelism_config import ParallelismConfig
from .state import PartialState
from .utils import (
AutocastKwargs,

File diff suppressed because it is too large

View File

@ -747,3 +747,43 @@ def _attach_layerwise_casting_hooks(
non_blocking,
_prefix=layer_name,
)
def _attach_context_parallel_hooks(
model: nn.Module,
):
"""
Monkeypatch huggingface's `transformers` model to fix attention mask issues when using context parallelism.
This function attaches forward_pre_hooks to each self_attn module of the model. Each hook checks the
args/kwargs for an attention mask; if one is present, the hook removes it, checks whether it is a causal mask and,
if so, adds the kwarg `is_causal=True`, otherwise raises an error. This is because context parallelism does
not support attention masks. This function modifies the model in place.
Args:
model (`nn.Module`):
The model to attach the hooks to.
"""
def _self_attn_pre_forward_hook(_module, module_args, module_kwargs):
if "attention_mask" in module_kwargs:
module_kwargs["attention_mask"] = None
module_kwargs["is_causal"] = True
return module_args, module_kwargs
for name, module in model.named_modules():
# We hope (assume) that if the user uses their own model (without the structure which transformers uses), they read the docs saying they can't pass in attention masks
# Then these cases can happen:
# 1) some modules end with a `self_attn` module, in which case we attach the hook, but
# no attention mask kwarg is passed -> hook is a no-op
# 2) some modules end with a `self_attn` module, in which case we attach the hook, and the
# attention mask kwarg is passed -> hook will remove the attention mask and add
# `is_causal=True` kwarg, which either crashes the training or fixes it
# (training would crash anyway as attention mask isn't supported)
# 3) no modules end with a `self_attn` module, in which case we don't attach the hook; this is
# a no-op as well
if name.endswith("self_attn"):
# we want the hook to be executed first, to avoid any other hooks doing work on the attention mask
module.register_forward_pre_hook(_self_attn_pre_forward_hook, with_kwargs=True, prepend=True)
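For context, the hook mechanism itself can be demonstrated on a toy module (everything below is illustrative, not part of the library):
```python
import torch
from torch import nn

class ToyAttention(nn.Module):
    def forward(self, hidden_states, attention_mask=None, is_causal=False):
        # A real attention layer would use the mask / is_causal flag here.
        return hidden_states

def drop_mask_hook(module, args, kwargs):
    # Mirrors _self_attn_pre_forward_hook: strip the mask and force causal attention.
    if "attention_mask" in kwargs:
        kwargs["attention_mask"] = None
        kwargs["is_causal"] = True
    return args, kwargs

attn = ToyAttention()
attn.register_forward_pre_hook(drop_mask_hook, with_kwargs=True, prepend=True)

out = attn(torch.randn(1, 8, 16), attention_mask=torch.ones(1, 8))
```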

View File

@ -505,17 +505,48 @@ def get_cluster_input():
error_message="Please enter yes or no.",
)
if fsdp_version == 2:
fsdp_config["fsdp_cp_size"] = _ask_field(
"What should be your FSDP's context parallel size? (Input 1 or leave blank for no context parallel) [1]: ",
int,
default=1,
error_message="Please enter an integer.",
)
parallelism_config = {}
if fsdp_version == 2 and fsdp_config.get("fsdp_cp_size", 1) != 1:
fsdp_config["fsdp_cp_comm_strategy"] = _ask_options(
"What should be your FSDP's context parallel communication strategy? [allgather]: ",
if fsdp_config.get("fsdp_version", 1) == 2:
use_parallelism_config = _ask_field(
"Do you want to use the parallelism config? [yes/NO]: ",
_convert_yes_no_to_bool,
default=False,
error_message="Please enter yes or no.",
)
if use_parallelism_config:
prefix = "parallelism_config_"
parallelism_config[prefix + "dp_replicate_size"] = _ask_field(
"What is the data parallelism replicate size? [1]: ",
int,
default=1,
error_message="Please enter an integer.",
)
parallelism_config[prefix + "dp_shard_size"] = _ask_field(
"What is the FSDP shard size? [1]: ",
int,
default=1,
error_message="Please enter an integer.",
)
parallelism_config[prefix + "tp_size"] = _ask_field(
"What is the tensor parallelism size? [1]: ",
int,
default=1,
error_message="Please enter an integer.",
)
parallelism_config[prefix + "cp_size"] = _ask_field(
"What is the context parallelism size? [1]: ",
int,
default=1,
error_message="Please enter an integer.",
)
if parallelism_config[prefix + "cp_size"] > 1:
parallelism_config[prefix + "cp_comm_strategy"] = _ask_options(
"What is the context parallelism communication strategy?",
["allgather", "alltoall"],
lambda x: ["allgather", "alltoall"][int(x)],
default=0,
@ -790,8 +821,8 @@ def get_cluster_input():
)
fp8_config["fp8_format"] = _ask_options(
"Which weight format should be used?",
["HYBRID", "E4M3"],
lambda x: "HYBRID" if x == 0 else "E4M3",
["HYBRID", "E4M3", "E5M2"],
lambda i: ["HYBRID", "E4M3", "E5M2"][i],
default=0,
)
fp8_config["amax_history_length"] = _ask_field(
@ -865,6 +896,7 @@ def get_cluster_input():
fp8_config=fp8_config,
deepspeed_config=deepspeed_config,
fsdp_config=fsdp_config,
parallelism_config=parallelism_config,
megatron_lm_config=megatron_lm_config,
ipex_config=ipex_config,
mpirun_config=mpirun_config,

View File

@ -194,6 +194,8 @@ class ClusterConfig(BaseConfig):
deepspeed_config: dict = None
# args for fsdp
fsdp_config: dict = None
# args for parallelism config
parallelism_config: dict = None
# args for megatron_lm
megatron_lm_config: dict = None
# args for ipex
@ -229,6 +231,8 @@ class ClusterConfig(BaseConfig):
self.mpirun_config = {}
if self.fp8_config is None:
self.fp8_config = {}
if self.parallelism_config is None:
self.parallelism_config = {}
return super().__post_init__()

View File

@ -60,4 +60,4 @@ def update_command_parser(parser, parents):
def update_config_command(args):
config_file = update_config(args)
print(f"Sucessfully updated the configuration file at {config_file}.")
print(f"Successfully updated the configuration file at {config_file}.")

View File

@ -182,13 +182,6 @@ def launch_command_parser(subparsers=None):
hardware_args.add_argument(
"--tpu", default=False, action="store_true", help="Whether or not this should launch a TPU training."
)
hardware_args.add_argument(
"--ipex",
default=False,
action="store_true",
help="Whether or not this should launch a Intel PyTorch Extension (IPEX) training.",
)
# Resource selection arguments
resource_args = parser.add_argument_group(
"Resource Selection Arguments", "Arguments for fine-tuning how available hardware should be used."
@ -269,6 +262,12 @@ def launch_command_parser(subparsers=None):
action="store_true",
help="Whether to use fsdp.",
)
paradigm_args.add_argument(
"--use_parallelism_config",
default=False,
action="store_true",
help="Whether to use the parallelism config to configure the N-d distributed training.",
)
paradigm_args.add_argument(
"--use_megatron_lm",
default=False,
@ -494,13 +493,13 @@ def launch_command_parser(subparsers=None):
"--deepspeed_exclusion_filter",
default=None,
type=str,
help="DeepSpeed exclusion filter string when using mutli-node setup.",
help="DeepSpeed exclusion filter string when using multi-node setup.",
)
deepspeed_args.add_argument(
"--deepspeed_inclusion_filter",
default=None,
type=str,
help="DeepSpeed inclusion filter string when using mutli-node setup.",
help="DeepSpeed inclusion filter string when using multi-node setup.",
)
deepspeed_args.add_argument(
"--deepspeed_multinode_launcher",
@ -586,7 +585,7 @@ def launch_command_parser(subparsers=None):
"--fsdp_use_orig_params",
default="true",
type=str,
help="If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable paramteres."
help="If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable parameters."
" (useful only when `use_fsdp` flag is passed).",
)
fsdp_args.add_argument(
@ -610,18 +609,6 @@ def launch_command_parser(subparsers=None):
type=str,
help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).",
)
fsdp_args.add_argument(
"--fsdp_cp_size",
type=int,
default=1,
help="FSDP's context parallel size. (useful only when `use_fsdp` flag is passed and `fsdp_version` is 2). Defaults to 1 (CP not applied).",
)
fsdp_args.add_argument(
"--fsdp_cp_comm_strategy",
type=str,
default="allgather",
help="FSDP's context parallel communication strategy. (useful only when `use_fsdp` flag is passed and `fsdp_version` is 2). Defaults to `allgather`.",
)
# megatron_lm args
megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.")
@ -704,8 +691,8 @@ def launch_command_parser(subparsers=None):
fp8_args.add_argument(
"--fp8_format",
type=str,
default="E4M3",
choices=["E4M3", "HYBRID"],
default="HYBRID",
choices=["HYBRID", "E4M3", "E5M2"],
help="The format to use for the FP8 recipe (useful only when `--fp8_backend=te` is passed).",
)
fp8_args.add_argument(
@ -779,6 +766,45 @@ def launch_command_parser(subparsers=None):
help="The number of oneCCL worker threads when using Accelerate to launch multi-CPU training with mpirun.",
)
# ParallelismConfig arguments
parallelism_config_args = parser.add_argument_group(
"ParallelismConfig Arguments",
"Arguments related to the ParallelismConfig used for distributed training.",
)
parallelism_config_args.add_argument(
"--parallelism_config_dp_replicate_size",
type=int,
default=1,
help="The number of processes for data parallel training. Defaults to 1 (no data parallelism).",
)
parallelism_config_args.add_argument(
"--parallelism_config_dp_shard_size",
type=int,
default=1,
help="The number of processes for FSDP sharding. Defaults to 1 (No FSDP sharding).",
)
parallelism_config_args.add_argument(
"--parallelism_config_tp_size",
type=int,
default=1,
help="The number of processes for tensor parallel training. Defaults to 1 (no tensor parallelism).",
)
parallelism_config_args.add_argument(
"--parallelism_config_cp_size",
type=int,
default=1,
help="The number of processes for context parallel training. Defaults to 1 (no context parallelism).",
)
parallelism_config_args.add_argument(
"--parallelism_config_cp_comm_strategy",
type=str,
default="allgather",
help="The communication strategy for context parallel training. Defaults to 'allgather'. The other option is 'alltoall'.",
)
# Other arguments of the training scripts
parser.add_argument("training_script_args", nargs=argparse.REMAINDER, help="Arguments of the training script.")
@ -1006,6 +1032,9 @@ def _validate_launch_command(args):
if args.multi_gpu and (args.num_processes is not None) and (args.num_processes < 2):
raise ValueError("You need to use at least 2 processes to use `--multi_gpu`.")
if (not args.use_fsdp or args.fsdp_version == 1) and args.use_parallelism_config:
raise ValueError("You cannot use `--use_parallelism_config` without `--use_fsdp` and `--fsdp_version=2`. ")
defaults = None
warned = []
mp_from_config_flag = False
@ -1039,6 +1068,7 @@ def _validate_launch_command(args):
args.use_fsdp = defaults.distributed_type == DistributedType.FSDP
args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM
args.tpu_use_cluster = defaults.tpu_use_cluster if args.tpu else False
args.use_parallelism_config = defaults.parallelism_config != {}
if args.gpu_ids is None:
if defaults.gpu_ids is not None:
args.gpu_ids = defaults.gpu_ids

View File

@ -89,7 +89,7 @@ def convert_config_to_fsdp2(config: dict) -> dict:
new_fsdp_config = {}
if fsdp_config.get("fsdp_version", 1) == 2:
logger.warning("Config already specfies FSDP2, skipping conversion...")
logger.warning("Config already specifies FSDP2, skipping conversion...")
logger.warning(
"If the config doesn't use new argument names, change `fsdp_version` to `1` and rerun the command."
)

View File

@ -32,6 +32,7 @@ from .utils import (
find_batch_size,
get_data_structure,
initialize_tensors,
is_datasets_available,
is_torch_version,
is_torchdata_stateful_dataloader_available,
send_to_device,
@ -74,7 +75,7 @@ class SeedableRandomSampler(RandomSampler):
Same as a random sampler, except that in `__iter__` a seed can be used.
Needed specifically in distributed cases, when the random generator for each GPU needs to start from the same seed
and be fully reproducable on multiple iterations.
and be fully reproducible on multiple iterations.
If a custom `generator` is passed, it will rely on its initial seed as well as the current iteration it is on
(stored in `self.epoch`).
@ -407,7 +408,7 @@ class DataLoaderStateMixin:
class DataLoaderAdapter:
"""
A class which wraps around a PyTorch `DataLoader` (or variants of it) to be used with the `Accelerator`. For
compatability reasons, this class inherits from the class it wraps around, so it can be used as a drop-in.
compatibility reasons, this class inherits from the class it wraps around, so it can be used as a drop-in.
"""
def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs):
@ -450,8 +451,8 @@ class DataLoaderAdapter:
@property
def __class__(self):
"""
In order to maintain backwards compatability with other code, we need to ensure `isinstance(obj, DataLoader)`
returs true. This is because some downstream code assumes that the `DataLoader` is the base class of the
In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)`
returns true. This is because some downstream code assumes that the `DataLoader` is the base class of the
object.
"""
return self.base_dataloader.__class__
@ -565,7 +566,8 @@ class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
try:
current_batch = next(dataloader_iter)
except StopIteration:
yield
self.end()
return
batch_index = 0
while True:
@ -761,12 +763,12 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
# if a device mesh is provided extract each dimension (dp, fsdp, tp)
# device mesh may hold any number of dimensions, however,
# below code is for targetted support for dp, fsdp and tp
# below code is for targeted support for dp, fsdp and tp
# device mesh will be used only if there is tp involved
# or any multi-dimensional parallelism involving tp
# (dp, tp) (fsdp, tp) (dp, fsdp, tp)
# otherwise the default behavour not using device mesh should be sufficient
# otherwise the default behaviour not using device mesh should be sufficient
# since multi dimensional parallelism devoid of tp would anyway need
# different batches for each process irrespective of dp or fsdp
self.submesh_tp = None
@ -1061,7 +1063,7 @@ def prepare_data_loader(
ignored otherwise.
use_seedable_sampler (`bool`, *optional*, defaults to `False`):
Whether to use the [`~data_loader.SeedableRandomSampler`] instead of a `RandomSampler` for better
reproducability. Comes at a cost of potentially different performances due to different shuffling
reproducibility. Comes at a cost of potentially different performances due to different shuffling
algorithms but ensures results will be the *exact* same. Should be paired with `set_seed()` at every
`self.set_epoch`
data_seed (`int`, *optional*, defaults to `None`):
@ -1111,32 +1113,34 @@ def prepare_data_loader(
# Given a device mesh (dp, tp) = (2, 3):
# - From the data parallel perspective, ranks should be structured as: 0 0 0 1 1 1
# - Processes with the same DP rank will receive the same batch.
submesh_tp_size = 1
if "tp" in torch_device_mesh.mesh_dim_names:
submesh_tp_size = torch_device_mesh["tp"].size()
process_index = process_index // submesh_tp_size
num_processes = num_processes // submesh_tp_size
else:
# when device mesh is used, specifically with TP or CP
# when device mesh is used, specifically with TP
# then there is need to update process_index and num_processes
# to bring in the effect of generating same batch across TP/CP ranks
# to bring in the effect of generating same batch across TP ranks
# and different batch across FSDP and DP ranks.
# Example:
# if device mesh is (dp,fsdp,tp,cp) = (2, 2, 2, 3)
# ranks would range from 0...23
# from data angle ranks should look like 0 0 0 0 0 0 1 1 1 1 1 1 ...
# if device mesh is (dp,fsdp,tp) = (2, 2, 3)
# ranks would range from 0...11
# from data angle ranks should look like 0 0 0 1 1 1 2 2 2 3 3 3
# processes with same ranks/ids would receive the same batch
# for CP the same as TP applies
submesh_fsdp_size = 1
submesh_dp_size = 1
submesh_tp_size = 1
submesh_cp_size = 1
if "tp" in torch_device_mesh.mesh_dim_names:
submesh_tp_size = torch_device_mesh["tp"].size()
if "dp" in torch_device_mesh.mesh_dim_names:
submesh_dp_size = torch_device_mesh["dp"].size()
if "fsdp" in torch_device_mesh.mesh_dim_names:
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
if "cp" in torch_device_mesh.mesh_dim_names:
submesh_cp_size = torch_device_mesh["cp"].size()
if "dp_replicate" in torch_device_mesh.mesh_dim_names:
submesh_dp_size = torch_device_mesh["dp_replicate"].size()
if "dp_shard" in torch_device_mesh.mesh_dim_names:
submesh_fsdp_size = torch_device_mesh["dp_shard"].size()
process_index = process_index // (submesh_tp_size * submesh_cp_size)
num_processes = submesh_fsdp_size * submesh_dp_size
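To make the remapping in the comment above concrete, a small arithmetic sketch for a (dp, fsdp, tp) = (2, 2, 3) mesh (CP omitted, i.e. cp = 1):
```python
# 12 global ranks; all TP ranks in the same TP group should receive the same batch,
# so the dataloader only distinguishes the dp x fsdp "data ranks".
dp_replicate, dp_shard, tp = 2, 2, 3

data_world_size = dp_replicate * dp_shard  # 4 distinct data streams
for global_rank in range(dp_replicate * dp_shard * tp):
    data_rank = global_rank // tp
    print(f"rank {global_rank:2d} -> data rank {data_rank} of {data_world_size}")
# Ranks 0-2 map to data rank 0, ranks 3-5 to 1, 6-8 to 2, 9-11 to 3,
# i.e. the "0 0 0 1 1 1 2 2 2 3 3 3" pattern from the comment above.
```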
@ -1197,7 +1201,16 @@ def prepare_data_loader(
dataloader.sampler.generator = generator
# No change if no multiprocess
if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:
if isinstance(new_dataset, IterableDataset):
if is_datasets_available():
from datasets import IterableDataset as DatasetsIterableDataset
if (
is_datasets_available()
and isinstance(new_dataset, DatasetsIterableDataset)
and not split_batches
and new_dataset.n_shards > num_processes
):
new_dataset = new_dataset.shard(num_shards=num_processes, index=process_index)
elif isinstance(new_dataset, IterableDataset):
if getattr(dataloader.dataset, "generator", None) is not None:
synchronized_generator = dataloader.dataset.generator
new_dataset = IterableDatasetShard(

View File

@ -0,0 +1,189 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import os
import pickle
import queue
from io import UnsupportedOperation
from pathlib import Path
from typing import TYPE_CHECKING, cast
import torch
import torch.distributed.checkpoint as dcp
import torch.distributed.checkpoint.state_dict as dcs
from torch.distributed.checkpoint.filesystem import (
FileSystemWriter,
SavePlan,
SavePlanner,
_generate_uuid,
_split_by_size_and_type,
)
from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta
from torch.distributed.checkpoint.storage import WriteResult
if TYPE_CHECKING:
from accelerate import Accelerator
class AccelerateStorageWriter(FileSystemWriter):
_DEFAULT_SUFFIX = ".distcp"
_OPTIM_FILE_PATH = "optimizer_0"
_MODEL_FILE_PATH = "pytorch_model_fsdp_0"
def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
self.optim_path = self.fs.concat_path(self.path, self._OPTIM_FILE_PATH)
self.model_path = self.fs.concat_path(self.path, self._MODEL_FILE_PATH)
self.fs.mkdir(self.optim_path)
self.fs.mkdir(self.model_path)
return super().prepare_local_plan(plan)
def write_data(
self,
plan: SavePlan,
planner: SavePlanner,
):
storage_plan = plan.storage_data
optim_file_count = 0
model_file_count = 0
def gen_file(is_optimizer: bool = False) -> str:
nonlocal optim_file_count, model_file_count
if is_optimizer:
optim_file_count += 1
return f"{storage_plan.prefix}{optim_file_count}{self._DEFAULT_SUFFIX}"
else:
model_file_count += 1
return f"{storage_plan.prefix}{model_file_count}{self._DEFAULT_SUFFIX}"
file_queue: queue.Queue = queue.Queue()
for bucket in _split_by_size_and_type(1, plan.items):
optim_states = [wi for wi in bucket if "optim" in wi.index.fqn]
model_states = [wi for wi in bucket if "model" in wi.index.fqn]
for is_optimizer, states, prefix_path in (
(True, optim_states, self.optim_path),
(False, model_states, self.model_path),
):
file_name = gen_file(is_optimizer=is_optimizer)
path = self.fs.concat_path(prefix_path, file_name)
file_queue.put((path, file_name, states))
return self._write_data(planner, file_queue)
def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
try:
metadata = dataclasses.replace(metadata, version="1.0.0")
except TypeError:
pass
def _split_metadata(
metadata: Metadata,
) -> tuple[Metadata, Metadata]:
result = []
for to_get in ["model", "optim"]:
result.append(
Metadata(
state_dict_metadata={
k.removeprefix("state."): v for k, v in metadata.state_dict_metadata.items() if to_get in k
},
planner_data={
k.removeprefix("state."): tuple([x for x in v if x != "state"])
for k, v in metadata.planner_data.items()
if to_get in k
},
)
)
return tuple(result)
model_metadata, optim_metadata = _split_metadata(metadata)
model_storage_md, optim_storage_md = {}, {}
for wr_list in results:
for wr in wr_list:
new_index = dataclasses.asdict(wr.index)
new_index["fqn"] = new_index["fqn"].removeprefix("state.")
wr = WriteResult(
index=MetadataIndex(**new_index),
size_in_bytes=wr.size_in_bytes,
storage_data=wr.storage_data,
)
if "optim" in wr.index.fqn:
optim_storage_md.update({wr.index: wr.storage_data})
else:
model_storage_md.update({wr.index: wr.storage_data})
model_metadata.storage_data = model_storage_md
optim_metadata.storage_data = optim_storage_md
model_metadata.storage_meta = StorageMeta(self.model_path, save_id=_generate_uuid())
optim_metadata.storage_meta = StorageMeta(self.optim_path, save_id=_generate_uuid())
tmp_optim_path = cast(Path, self.fs.concat_path(self.optim_path, ".metadata.tmp"))
tmp_model_path = cast(Path, self.fs.concat_path(self.model_path, ".metadata.tmp"))
for meta, tmp_path, final_path in zip(
[model_metadata, optim_metadata],
[tmp_model_path, tmp_optim_path],
[self.model_path, self.optim_path],
):
with self.fs.create_stream(tmp_path, "wb") as metadata_file:
pickle.dump(meta, metadata_file)
if self.sync_files:
try:
os.fsync(metadata_file.fileno())
except (AttributeError, UnsupportedOperation):
os.sync()
metadata_path = self.fs.concat_path(final_path, ".metadata")
if self.fs.exists(metadata_path):
self.fs.rm_file(metadata_path)
self.fs.rename(tmp_path, metadata_path)
def save_model_and_optimizer(
accelerator: "Accelerator",
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
save_path: str,
async_save: bool = False,
) -> None:
# Wait for any in-flight asynchronous save to finish before starting a new one.
if getattr(accelerator, "_async_save_handle", None) is not None:
accelerator._async_save_handle.result()
options = dcs.StateDictOptions()
model_sd, optimizer_sd = dcs.get_state_dict(model, optimizer, options=options)
stateful = {
"model": model_sd,
"optimizer": optimizer_sd,
}
# Use the asynchronous DCP entry point only when requested.
save_fn = dcp.async_save if async_save else dcp.save
potential_handle = save_fn(
state_dict=stateful,
storage_writer=AccelerateStorageWriter(save_path),
)
if async_save:
accelerator._async_save_handle = potential_handle
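A possible call site for this helper (a sketch only; the import path depends on where this module ends up, and `accelerator`, `model`, and `optimizer` are assumed to be FSDP2-prepared beforehand):

```python
# Hypothetical usage of save_model_and_optimizer with asynchronous checkpointing.
save_model_and_optimizer(accelerator, model, optimizer, "checkpoints/step_100", async_save=True)

# Training can continue while the checkpoint is written in the background;
# the next call (or an explicit wait) blocks on the previous handle.
if getattr(accelerator, "_async_save_handle", None) is not None:
    accelerator._async_save_handle.result()
```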

View File

@ -714,9 +714,20 @@ class CpuOffload(ModelHook):
return module.to("cpu")
def pre_forward(self, module, *args, **kwargs):
if self.prev_module_hook is not None:
self.prev_module_hook.offload()
clear_device_cache()
if self.prev_module_hook is not None and isinstance(self.prev_module_hook, UserCpuOffloadHook):
prev_module = self.prev_module_hook.model
prev_device = next(prev_module.parameters()).device
# Only offload the previous module if it is not already on CPU.
if prev_device != torch.device("cpu"):
self.prev_module_hook.offload()
clear_device_cache()
# If the module is already on its execution device, we can skip the transfer.
current_device = next(module.parameters()).device
if current_device == self.execution_device:
return args, kwargs
module.to(self.execution_device)
return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)
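For context, a small runnable sketch of chained CPU offloading through the public `cpu_offload_with_hook` helper, which creates these hooks; with the change above, a module that already sits on CPU or on its execution device no longer triggers a redundant transfer (model sizes and device choice are illustrative):

```python
import torch
from accelerate import cpu_offload_with_hook

device = "cuda" if torch.cuda.is_available() else "cpu"
model_1 = torch.nn.Linear(4, 4)
model_2 = torch.nn.Linear(4, 4)

# Chain the hooks so that running model_2 offloads model_1 back to CPU.
model_1, hook_1 = cpu_offload_with_hook(model_1, execution_device=device)
model_2, hook_2 = cpu_offload_with_hook(model_2, execution_device=device, prev_module_hook=hook_1)

with torch.no_grad():
    out = model_2(model_1(torch.randn(2, 4)))
```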

View File

@ -60,8 +60,8 @@ def notebook_launcher(
<Tip warning={true}>
To use this function absolutely zero calls to a CUDA device must be made in the notebook session before calling. If
any have been made, you will need to restart the notebook and make sure no cells use any CUDA capability.
To use this function absolutely zero calls to a device must be made in the notebook session before calling. If any
have been made, you will need to restart the notebook and make sure no cells use any device capability.
Setting `ACCELERATE_DEBUG_MODE="1"` in your environment will run a test before truly launching to ensure that none
of those calls have been made.
@ -76,11 +76,11 @@ def notebook_launcher(
Tuple of arguments to pass to the function (it will receive `*args`).
num_processes (`int`, *optional*):
The number of processes to use for training. Will default to 8 in Colab/Kaggle if a TPU is available, to
the number of GPUs available otherwise.
the number of devices available otherwise.
mixed_precision (`str`, *optional*, defaults to `"no"`):
If `fp16` or `bf16`, will use mixed precision training on multi-GPU.
If `fp16` or `bf16`, will use mixed precision training on multi-device.
use_port (`str`, *optional*, defaults to `"29500"`):
The port to use to communicate between processes when launching a multi-GPU training.
The port to use to communicate between processes when launching a multi-device training.
master_addr (`str`, *optional*, defaults to `"127.0.0.1"`):
The address to use for communication between processes.
node_rank (`int`, *optional*, defaults to 0):
@ -105,7 +105,7 @@ def notebook_launcher(
Example:
```python
# Assume this is defined in a Jupyter Notebook on an instance with two GPUs
# Assume this is defined in a Jupyter Notebook on an instance with two devices
from accelerate import notebook_launcher
@ -158,27 +158,27 @@ def notebook_launcher(
else:
if num_processes is None:
raise ValueError(
"You have to specify the number of GPUs you would like to use, add `num_processes=...` to your call."
"You have to specify the number of devices you would like to use, add `num_processes=...` to your call."
)
if node_rank >= num_nodes:
raise ValueError("The node_rank must be less than the number of nodes.")
if num_processes > 1:
# Multi-GPU launch
# Multi-device launch
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
from torch.multiprocessing import start_processes
from torch.multiprocessing.spawn import ProcessRaisedException
if len(AcceleratorState._shared_state) > 0:
raise ValueError(
"To launch a multi-GPU training from your notebook, the `Accelerator` should only be initialized "
"To launch a multi-device training from your notebook, the `Accelerator` should only be initialized "
"inside your training function. Restart your notebook and make sure no cells initializes an "
"`Accelerator`."
)
# Check for specific libraries known to initialize CUDA that users constantly use
# Check for specific libraries known to initialize device that users constantly use
problematic_imports = are_libraries_initialized("bitsandbytes")
if len(problematic_imports) > 0:
err = (
"Could not start distributed process. Libraries known to initialize CUDA upon import have been "
"Could not start distributed process. Libraries known to initialize device upon import have been "
"imported already. Please keep these imports inside your training function to try and help with this:"
)
for lib_name in problematic_imports:
@ -203,24 +203,26 @@ def notebook_launcher(
# process here (the other ones will be set be the launcher).
with patch_environment(**patched_env):
# First dummy launch
device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
distributed_type = "MULTI_XPU" if device_type == "xpu" else "MULTI_GPU"
if os.environ.get("ACCELERATE_DEBUG_MODE", "false").lower() == "true":
launcher = PrepareForLaunch(test_launch, distributed_type="MULTI_GPU")
launcher = PrepareForLaunch(test_launch, distributed_type=distributed_type)
try:
start_processes(launcher, args=(), nprocs=num_processes, start_method="fork")
except ProcessRaisedException as e:
err = "An issue was found when verifying a stable environment for the notebook launcher."
if "Cannot re-initialize CUDA in forked subprocess" in e.args[0]:
if f"Cannot re-initialize {device_type.upper()} in forked subprocess" in e.args[0]:
raise RuntimeError(
f"{err}"
"This likely stems from an outside import causing issues once the `notebook_launcher()` is called. "
"Please review your imports and test them when running the `notebook_launcher()` to identify "
"which one is problematic and causing CUDA to be initialized."
f"which one is problematic and causing {device_type.upper()} to be initialized."
) from e
else:
raise RuntimeError(f"{err} The following error was raised: {e}") from e
# Now the actual launch
launcher = PrepareForLaunch(function, distributed_type="MULTI_GPU")
print(f"Launching training on {num_processes} GPUs.")
launcher = PrepareForLaunch(function, distributed_type=distributed_type)
print(f"Launching training on {num_processes} {device_type.upper()}s.")
try:
if rdzv_conf is None:
rdzv_conf = {}
@ -244,23 +246,25 @@ def notebook_launcher(
launch_config_kwargs["log_line_prefix_template"] = log_line_prefix_template
elastic_launch(config=LaunchConfig(**launch_config_kwargs), entrypoint=function)(*args)
except ProcessRaisedException as e:
if "Cannot re-initialize CUDA in forked subprocess" in e.args[0]:
if f"Cannot re-initialize {device_type.upper()} in forked subprocess" in e.args[0]:
raise RuntimeError(
"CUDA has been initialized before the `notebook_launcher` could create a forked subprocess. "
f"{device_type.upper()} has been initialized before the `notebook_launcher` could create a forked subprocess. "
"This likely stems from an outside import causing issues once the `notebook_launcher()` is called. "
"Please review your imports and test them when running the `notebook_launcher()` to identify "
"which one is problematic and causing CUDA to be initialized."
f"which one is problematic and causing {device_type.upper()} to be initialized."
) from e
else:
raise RuntimeError(f"An issue was found when launching the training: {e}") from e
else:
# No need for a distributed launch otherwise as it's either CPU, GPU or MPS.
# No need for a distributed launch otherwise as it's either CPU, GPU, XPU or MPS.
if is_mps_available():
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
print("Launching training on MPS.")
elif torch.cuda.is_available():
print("Launching training on one GPU.")
elif torch.xpu.is_available():
print("Launching training on one XPU.")
else:
print("Launching training on CPU.")
function(*args)
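A short, device-agnostic usage sketch matching the wording changes above (the number of processes and mixed-precision setting are illustrative):

```python
from accelerate import Accelerator, notebook_launcher

def training_function(mixed_precision="no"):
    accelerator = Accelerator(mixed_precision=mixed_precision)
    accelerator.print(f"Running on {accelerator.num_processes} process(es) of type {accelerator.device.type}")

# Launches identically on multi-GPU or multi-XPU, and falls back to a single CPU/MPS process.
notebook_launcher(training_function, ("bf16",), num_processes=2)
```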

View File

@ -0,0 +1,322 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union
from accelerate.utils.dataclasses import TorchContextParallelConfig, TorchTensorParallelConfig
from accelerate.utils.versions import is_torch_version
if TYPE_CHECKING:
from accelerate import Accelerator
@dataclass
class ParallelismConfig:
"""
A dataclass to configure parallelisms applied to the model. Inspired by torchtitan's `ParallelDims`
https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/parallel_dims.py
Args:
dp_replicate_size (`int`, defaults to `1`):
The size of the data parallel group. If `dp_replicate_size` is set to 1, the data parallel replication
group will not be used.
dp_shard_size (`int`, defaults to `1`):
The size of the model shard group. If `dp_replicate_size > 1` and `tp_size > 1`, `dp_shard_size` must also
be greater than 1, as composing DDP + TP is currently not supported.
tp_size (`int`, defaults to `1`):
The size of the tensor parallel group. If `tp_size` is set to `1`, the tensor parallel group will not be
used.
cp_size (`int`, defaults to `1`):
The size of the context parallel group. Currently not supported, but reserved for future use and enabled
for downstream libraries.
tp_handler (`~utils.TorchTensorParallelConfig`, defaults to `None`):
The handler for the tensor parallel group.
You may obtain different distributed data parallel paradigms by configuring `dp_replicate_size` and `dp_shard_size`
together:
- `dp_replicate_size == 1` and `dp_shard_size > 1`, we obtain Fully Sharded Data Parallel (FSDP).
- `dp_replicate_size > 1` and `dp_shard_size > 1`, we obtain Hybrid Sharded Data Parallel (HSDP).
- `dp_replicate_size > 1` and `dp_shard_size == 1` is an invalid configuration; to use pure DP, use
`DistributedDataParallelKwargs` instead.
"""
dp_replicate_size: int = None
dp_shard_size: int = None
tp_size: int = None
cp_size: int = None
# we use Union because we might support other x parallel plugins (i.e. deepspeed, etc)
tp_handler: Union[None, TorchTensorParallelConfig] = None
cp_handler: Union[None, TorchContextParallelConfig] = None
device_mesh = None
def __repr__(self):
return (
"ParallelismConfig(\n "
f"\tdp_replicate_size={self.dp_replicate_size},\n"
f"\tdp_shard_size={self.dp_shard_size},\n"
f"\ttp_size={self.tp_size},\n"
f"\tcp_size={self.cp_size},\n"
f"\ttotal_size={self.total_size}\n"
f"\ttp_handler={self.tp_handler},\n"
f"\tcp_handler={self.cp_handler})\n"
)
def to_json(self):
import copy
_non_serializable_fields = ["device_mesh"]
return copy.deepcopy(
{
k: copy.deepcopy(v.__dict__) if hasattr(v, "__dict__") else v
for k, v in self.__dict__.items()
if k not in _non_serializable_fields
}
)
@property
def dp_dim_names(self):
"""Names of enabled dimensions across which data parallelism is applied."""
dims = []
if self.dp_replicate_enabled:
dims += ["dp_replicate"]
if self.dp_shard_enabled:
dims += ["dp_shard"]
return dims
@property
def non_dp_dim_names(self):
"""Names of enabled dimensions which will receive the same batch (non-data parallel dimensions)."""
dims = []
if self.tp_enabled:
dims += ["tp"]
if self.cp_enabled:
dims += ["cp"]
return dims
@property
def dp_shard_cp_dim_names(self):
"""Names of enabled dimensions which will be flattened into a joint mesh across which is model sharded in FSDP."""
dims = []
if self.dp_shard_enabled:
dims += ["dp_shard"]
if self.cp_enabled:
dims += ["cp"]
return dims
@property
def dp_cp_dim_names(self):
"""Names of enabled dimensions across which loss should be averaged"""
dims = []
if self.dp_replicate_enabled:
dims += ["dp_replicate"]
if self.dp_shard_enabled:
dims += ["dp_shard"]
if self.cp_enabled:
dims += ["cp"]
return dims
@property
def fsdp_dim_names(self):
"""Names of enabled dimensions across which FSDP is applied, including data parallel replication."""
dims = []
if self.dp_replicate_enabled:
dims += ["dp_replicate"]
dims += ["dp_shard_cp"]
return dims
@property
def total_size(self):
"""The total size of the parallelism configuration, which is the product of all sizes."""
return self.dp_replicate_size * self.dp_shard_size * self.tp_size * self.cp_size
@property
def non_data_parallel_size(self):
"""The size of the non-data parallel dimensions, which is the product of tensor and context parallel sizes."""
return self.tp_size * self.cp_size
@property
def data_parallel_size(self):
"""The size of the data parallel dimensions, which is the product of data parallel replication and"""
return self.dp_replicate_size * self.dp_shard_size
@property
def dp_replicate_enabled(self):
"""True if data parallel replication is enabled, i.e. `dp_replicate_size > 1`."""
return self.dp_replicate_size > 1
@property
def dp_shard_enabled(self):
"""True if data parallel sharding is enabled, i.e. `dp_shard_size > 1`."""
return self.dp_shard_size > 1
@property
def tp_enabled(self):
"""True if tensor parallelism is enabled, i.e. `tp_size > 1`."""
return self.tp_size > 1
@property
def cp_enabled(self):
"""True if context parallelism is enabled, i.e. `cp_size > 1`."""
return self.cp_size > 1
@property
def active_mesh_dims(self):
"""Names of all active mesh dimensions."""
return self.dp_dim_names + self.non_dp_dim_names
def build_device_mesh(self, device_type: str):
"""Builds a device mesh for the given device type based on the parallelism configuration.
This method will also create required joint meshes (e.g. `dp_shard_cp`, `dp_cp`, `dp`).
Args:
device_type (`str`): The type of device for which to build the mesh, e.g. `"cuda"`.
"""
if is_torch_version(">=", "2.2.0"):
from torch.distributed.device_mesh import init_device_mesh
else:
raise RuntimeError("Building a device_mesh requires to have torch>=2.2.0")
mesh = self._get_mesh()
if len(mesh) == 0:
return None
mesh_dim_names, mesh_shape = mesh
device_mesh = init_device_mesh(
device_type,
mesh_shape,
mesh_dim_names=mesh_dim_names,
)
if self.dp_dim_names:
device_mesh[self.dp_dim_names]._flatten("dp")
if self.dp_shard_cp_dim_names:
device_mesh[self.dp_shard_cp_dim_names]._flatten("dp_shard_cp")
if self.dp_cp_dim_names:
device_mesh[self.dp_cp_dim_names]._flatten("dp_cp")
return device_mesh
def get_device_mesh(self, device_type: Optional[str] = None):
if self.device_mesh is None:
if device_type is not None:
self.device_mesh = self.build_device_mesh(device_type)
else:
raise ("You need to pass a device_type e.g cuda to build the device mesh")
else:
if device_type is not None:
if self.device_mesh.device_type != device_type:
raise ValueError(
f"The device_mesh is already created with device type {self.device_mesh.device_type}. However, you are trying to get a device mesh with device_type {device_type}. Please check if you correctly initialized your device_mesh"
)
return self.device_mesh
def _get_mesh(self) -> tuple[tuple[str, ...], tuple[int, ...]]:
"""Generate mesh dimension names and shape for `torch.distributed.init_device_mesh()`."""
# Build mesh dimensions dictionary
mesh_dims = {parallelism: self._sizes[parallelism] for parallelism in self.active_mesh_dims}
# Apply canonical ordering
mesh_order = ["dp_replicate", "dp_shard", "cp", "tp"]
sorted_items = sorted(
mesh_dims.items(),
key=lambda x: (mesh_order.index(x[0])),
)
return tuple(zip(*sorted_items))
def __post_init__(self):
# Basic size validation
if self.dp_replicate_size is None:
self.dp_replicate_size = int(os.environ.get("PARALLELISM_CONFIG_DP_REPLICATE_SIZE", "1"))
if self.dp_shard_size is None:
self.dp_shard_size = int(os.environ.get("PARALLELISM_CONFIG_DP_SHARD_SIZE", "1"))
if self.tp_size is None:
self.tp_size = int(os.environ.get("PARALLELISM_CONFIG_TP_SIZE", "1"))
if self.cp_size is None:
self.cp_size = int(os.environ.get("PARALLELISM_CONFIG_CP_SIZE", "1"))
if self.tp_size > 1:
if self.tp_handler is None:
self.tp_handler = TorchTensorParallelConfig()
if self.cp_size > 1:
if self.cp_handler is None:
self.cp_handler = TorchContextParallelConfig()
if self.dp_replicate_size < 1:
raise ValueError(f"dp_replicate_size must be at least 1, but got {self.dp_replicate_size}")
if self.dp_shard_size < 1:
raise ValueError(f"dp_shard_size must be at least 1, but got {self.dp_shard_size}")
if self.tp_size < 1:
raise ValueError(f"tp_size must be at least 1, but got {self.tp_size}")
if self.cp_size < 1:
raise ValueError(f"cp_size must be at least 1, but got {self.cp_size}")
if (self.tp_size > 1 or self.cp_size > 1) and self.dp_replicate_size > 1 and self.dp_shard_size == 1:
raise ValueError(
"Tensor/Context parallelism (tp/cp_size > 1) cannot be used with pure data parallelism (dp_replicate_size > 1 and dp_shard_size == 1). "
"Please set dp_shard_size > 1 and dp_replicate_size == 1 to compose FSDP + TP/CP for 2D parallel, "
"or set dp_replicate_size == 1 and dp_shard_size > 1 to compose HSDP + TP/CP for 3D parallel."
)
self._sizes = {
"dp_replicate": self.dp_replicate_size,
"dp_shard": self.dp_shard_size,
"tp": self.tp_size,
"cp": self.cp_size,
}
def _set_size(self, parallelism: str, size: int):
assert parallelism in self._sizes.keys(), f"Parallelism must be one of {self._sizes.keys()}"
self._sizes[parallelism] = size
setattr(self, f"{parallelism}_size", size)
def _validate_accelerator(self, accelerator: "Accelerator"):
_warnings = set()
if not accelerator.multi_device and self.total_size == 1:
# No distributed setup, valid parallelism config
return
# We need this to ensure DDP works
if self.total_size == 1:
self._set_size("dp_replicate", accelerator.num_processes)
if self.total_size != accelerator.num_processes:
raise ValueError(
f"ParallelismConfig total_size ({self.total_size}) does not match "
f"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ "
f"dp_shard_size/tp_size/cp_size."
)
if self.total_size > 1 and not (accelerator.is_fsdp2 or accelerator.multi_device):
raise ValueError(
f"ParallelismConfig is only compatible DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}}, but got {accelerator.distributed_type}."
)
for parallelism, size in self._sizes.items():
if size == 1 and getattr(self, f"{parallelism}_handler", None) is not None:
_warnings.add(
f"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. This handler will be ignored."
)
if _warnings and accelerator.is_main_process:
warnings.warn(
"ParallelismConfig has the following warnings:\n" + "\n".join(_warnings),
UserWarning,
)
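A usage sketch for this dataclass (sizes are illustrative and must multiply to the number of launched processes; the device mesh itself can only be built inside an initialized distributed job):

```python
from accelerate.parallelism_config import ParallelismConfig

# HSDP (2 replicas x 2 shards) combined with TP 2 -> 8 processes in total.
pc = ParallelismConfig(dp_replicate_size=2, dp_shard_size=2, tp_size=2)
print(pc.active_mesh_dims)  # ['dp_replicate', 'dp_shard', 'tp']
print(pc.total_size)        # 8

# Inside a distributed run, the mesh (with the flattened "dp", "dp_shard_cp" and
# "dp_cp" joint dimensions) would be created with:
# mesh = pc.build_device_mesh("cuda")
```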

View File

@ -132,7 +132,7 @@ class PartialState:
Whether or not to force the script to execute on CPU. Will ignore any accelerators available if set to
`True` and force the execution on the CPU.
kwargs (additional keyword arguments, *optional*):
Additional keyword arguments to pass to the relevent `init_process_group` function. Valid `kwargs` can be
Additional keyword arguments to pass to the relevant `init_process_group` function. Valid `kwargs` can be
found in [`utils.InitProcessGroupKwargs`]. See the example section for detailed usage.
**Available attributes:**
@ -187,7 +187,7 @@ class PartialState:
dist_information = None
if use_sagemaker_dp is None:
use_sagemaker_dp = (
os.environ.get("ACCELERATE_USE_SAGEMAKER", "false") == "true"
os.environ.get("ACCELERATE_USE_SAGEMAKER", "false").lower() == "true"
and os.environ.get("ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE") != SageMakerDistributedType.NO
)
@ -195,14 +195,14 @@ class PartialState:
original_backend = kwargs.pop("backend", None)
backend, distributed_type = self._prepare_backend(cpu, use_sagemaker_dp, original_backend)
if original_backend is not None and backend != original_backend:
raise ValueError(f"Your assigned backend {original_backend} is not avaliable, please use {backend}")
raise ValueError(f"Your assigned backend {original_backend} is not available, please use {backend}")
self.backend = backend
self.distributed_type = distributed_type
use_deepspeed = False
if not cpu and self.backend != "xla":
if int(os.environ.get("LOCAL_RANK", -1)) != -1:
# Deal with spawning deepspeed
if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false").lower() == "true":
if not is_deepspeed_available():
raise ImportError(
"DeepSpeed is not available => install it using `pip3 install deepspeed` or build it from source"
@ -213,12 +213,6 @@ class PartialState:
if self.backend == "tccl":
local_rank = os.environ.get("LOCAL_RANK", -1)
torch.sdaa.set_device(f"sdaa:{local_rank}")
if (
self.backend == "nccl"
and os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
and os.environ.get("FSDP_OFFLOAD_PARAMS", "false") == "true"
):
self.backend = "cuda:nccl,cpu:gloo"
dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs)
# We need to flag to `use_deepspeed` to be True to override `distributed_type` later
use_deepspeed = True
@ -230,6 +224,16 @@ class PartialState:
if self.backend == "tccl":
local_rank = os.environ.get("LOCAL_RANK", -1)
torch.sdaa.set_device(f"sdaa:{local_rank}")
if (
self.backend == "nccl"
and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
and (
os.environ.get("FSDP_OFFLOAD_PARAMS", "false").lower() == "true"
or os.environ.get("FSDP_STATE_DICT_TYPE", "SHARDED_STATE_DICT") == "FULL_STATE_DICT"
)
):
self.backend = "cuda:nccl,cpu:gloo"
torch.distributed.init_process_group(backend=self.backend, **kwargs)
# XPU and CPU require special env configs to be set
@ -397,7 +401,7 @@ class PartialState:
DistributedType.DEEPSPEED,
DistributedType.FSDP,
):
torch.distributed.barrier()
torch.distributed.barrier(device_ids=[self.local_process_index])
elif self.distributed_type == DistributedType.XLA:
xm.rendezvous("accelerate.utils.wait_for_everyone")
@ -866,6 +870,8 @@ class AcceleratorState:
- **device** (`torch.device`) -- The device to use.
- **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
in use.
- **parallelism_config** ([`~accelerate.utils.ParallelismConfig`]) -- The parallelism configuration for the
current training environment. This is used to configure the distributed training environment.
- **initialized** (`bool`) -- Whether or not the `AcceleratorState` has been initialized from `Accelerator`.
- **local_process_index** (`int`) -- The index of the current process on the current server.
- **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type
@ -896,6 +902,7 @@ class AcceleratorState:
fsdp_plugin=None,
torch_tp_plugin=None,
megatron_lm_plugin=None,
parallelism_config=None,
_from_accelerator: bool = False,
**kwargs,
):
@ -910,6 +917,8 @@ class AcceleratorState:
self.deepspeed_plugins = None
self.use_ipex = None
self.torch_tp_plugin = torch_tp_plugin
self.parallelism_config = parallelism_config
self.device_mesh = None
mixed_precision = (
parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
if mixed_precision is None
@ -941,8 +950,13 @@ class AcceleratorState:
"Please make sure to properly initialize your accelerator via `accelerator = Accelerator()` "
"before using any functionality from the `accelerate` library."
)
# deepspeed handles mixed_precision using deepspeed_config
self._mixed_precision = "no" if self.distributed_type == DistributedType.DEEPSPEED else mixed_precision
# deepspeed handles mixed_precision using deepspeed_config. But we need to set it to fp8
# if we're using fp8.
if self.distributed_type == DistributedType.DEEPSPEED and mixed_precision != "fp8":
self._mixed_precision = "no"
else:
self._mixed_precision = mixed_precision
if self.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_tpu=True):
if mixed_precision == "bf16":
if os.environ.get("ACCELERATE_DOWNCAST_BF16"):
@ -953,7 +967,7 @@ class AcceleratorState:
os.environ["XLA_USE_BF16"] = str(1)
os.environ["XLA_DOWNCAST_BF16"] = str(0)
self.downcast_bfloat = False
elif os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true" and not cpu:
elif os.environ.get("ACCELERATE_USE_DEEPSPEED", "false").lower() == "true" and not cpu:
self.distributed_type = DistributedType.DEEPSPEED
if not isinstance(deepspeed_plugin, dict):
deepspeed_plugin.set_mixed_precision(mixed_precision)
@ -974,19 +988,35 @@ class AcceleratorState:
DistributedType.MULTI_XPU,
DistributedType.MULTI_HPU,
]:
if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or fsdp_plugin is not None:
# TODO: Siro - remove when axolotl fixes their side
if not os.environ.get("ACCELERATE_ALLOW_CP_STANDALONE", "false").lower() == "true":
if self.parallelism_config and self.parallelism_config.cp_enabled and fsdp_plugin is None:
raise ValueError(
"`cp_size > 1` specified in the `parallelism_config`, but no `fsdp_plugin` was provided. We need a `fsdp_plugin` to use context parallelism, as we also shard the model across the device mesh to save more memory"
)
if (
self.parallelism_config is not None
and self.parallelism_config.cp_enabled
and fsdp_plugin.fsdp_version == 1
):
raise ValueError(
"Using `cp_size>1` requires FSDP2, but the provided `fsdp_plugin` is using FSDP1. "
)
if (os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true" or fsdp_plugin is not None) or (
self.parallelism_config is not None and self.parallelism_config.cp_enabled
):
self.distributed_type = DistributedType.FSDP
if self._mixed_precision != "no":
if self._mixed_precision != "no" and fsdp_plugin is not None:
fsdp_plugin.set_mixed_precision(self._mixed_precision)
self.fsdp_plugin = fsdp_plugin
if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" and self.distributed_type not in [
if os.environ.get(
"ACCELERATE_USE_MEGATRON_LM", "false"
).lower() == "true" and self.distributed_type not in [
DistributedType.MULTI_XPU,
]:
self.distributed_type = DistributedType.MEGATRON_LM
megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
self.megatron_lm_plugin = megatron_lm_plugin
if self.torch_tp_plugin is not None:
self.distributed_type = DistributedType.TP
elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
if is_ipex_available():
# check if user disables it explicitly
@ -1032,7 +1062,7 @@ class AcceleratorState:
@property
def mixed_precision(self):
if self.distributed_type == DistributedType.DEEPSPEED:
if self.distributed_type == DistributedType.DEEPSPEED and self._mixed_precision != "fp8":
config = self.deepspeed_plugin.deepspeed_config
if config.get("fp16", {}).get("enabled", False):
mixed_precision = "fp16"
@ -1055,7 +1085,7 @@ class AcceleratorState:
"""
Destroys the process group. If one is not specified, the default process group is destroyed.
If `self.fork_lauched` is `True` and `group` is `None`, nothing happens.
If `self.fork_launched` is `True` and `group` is `None`, nothing happens.
"""
PartialState().destroy_process_group(group)

View File

@ -53,6 +53,7 @@ from .testing import (
require_torchvision,
require_tpu,
require_transformer_engine,
require_transformer_engine_mxfp8,
require_xpu,
run_first,
skip,

View File

@ -1,46 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
def main():
accelerator = Accelerator()
B, S, D = 2, 3, 4
rank_data = torch.ones((B, S, D), device="cuda") * (accelerator.process_index + 1)
all_rank_data = [torch.empty_like(rank_data) for _ in range(accelerator.num_processes)]
torch.distributed.all_gather(all_rank_data, rank_data)
dataloader = DataLoader(all_rank_data, batch_size=B, shuffle=False)
dataloader = accelerator.prepare(dataloader)
for batch in dataloader:
all_rank_batch = [torch.empty_like(batch) for _ in range(accelerator.num_processes)]
torch.distributed.all_gather(all_rank_batch, batch)
if accelerator.is_main_process:
for rank_idx in range(accelerator.num_processes):
torch.testing.assert_close(
all_rank_batch[0],
all_rank_batch[rank_idx],
msg=f"Rank {rank_idx} batch {all_rank_batch[rank_idx]} differs from rank 0 batch {all_rank_batch[0]}",
)
accelerator.end_training()
if __name__ == "__main__":
main()

View File

@ -34,8 +34,7 @@ from accelerate.state import AcceleratorState
from accelerate.utils.deepspeed import get_active_deepspeed_plugin
MAX_GPU_BATCH_SIZE = 16
EVAL_BATCH_SIZE = 32
EVAL_BATCH_SIZE = 16
class NoiseModel(torch.nn.Module):
@ -318,11 +317,11 @@ def main():
parser.add_argument(
"--num_epochs",
type=int,
default=2,
default=3,
help="Number of train epochs.",
)
args = parser.parse_args()
config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16}
config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 8}
single_model_training(config, args)
AcceleratorState._reset_state(True)
multiple_model_training(config, args)

View File

@ -69,7 +69,7 @@ class TorchTracemalloc:
self.begin = torch.npu.memory_allocated()
elif is_xpu_available():
torch.xpu.empty_cache()
torch.xpu.reset_max_memory_allocated() # reset the peak gauge to zero
torch.xpu.reset_peak_memory_stats() # reset the peak gauge to zero
self.begin = torch.xpu.memory_allocated()
elif is_hpu_available():
# torch.hpu.empty_cache() # not available on hpu as it reserves all device memory for the current process

View File

@ -25,7 +25,8 @@ from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup
from accelerate import Accelerator, DistributedType
from accelerate.utils import SAFE_WEIGHTS_NAME, TorchTensorParallelPlugin, set_seed
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import SAFE_WEIGHTS_NAME, set_seed
from accelerate.utils.deepspeed import DummyOptim, DummyScheduler
@ -83,7 +84,7 @@ def training_function(config, args):
accelerator_kwargs = {}
# need this for DeepSpeed tests as `args.tp_size` would be None and `torch.distributed.init_device_mesh` would fail
if args.tp_size is not None:
accelerator_kwargs["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=args.tp_size)
accelerator_kwargs["parallelism_config"] = ParallelismConfig(tp_size=args.tp_size)
# Initialize accelerator
accelerator = Accelerator(**accelerator_kwargs)

View File

@ -79,10 +79,6 @@ def mock_training(accelerator, model):
def check_weights(operation, state_1, state_2):
for weight_1, weight_2 in zip(state_1.values(), state_2.values()):
if str(weight_1.device) != torch_device:
weight_1 = weight_1.to(torch_device)
if str(weight_2.device) != torch_device:
weight_2 = weight_2.to(torch_device)
if operation == "same":
assert torch.allclose(weight_1, weight_2)
else:
@ -91,7 +87,7 @@ def check_weights(operation, state_1, state_2):
def check_safetensors_weights(path, model):
safe_state_dict = load_file(path / "model.safetensors")
safe_loaded_model = TinyModel()
safe_loaded_model = TinyModel().to(torch_device)
check_weights("diff", model.state_dict(), safe_loaded_model.state_dict())
safe_loaded_model.load_state_dict(safe_state_dict)
check_weights("same", model.state_dict(), safe_loaded_model.state_dict())
@ -99,7 +95,7 @@ def check_safetensors_weights(path, model):
def check_pytorch_weights(path, model):
nonsafe_state_dict = torch.load(path / "pytorch_model.bin", weights_only=True)
nonsafe_loaded_model = TinyModel()
nonsafe_loaded_model = TinyModel().to(torch_device)
check_weights("diff", model.state_dict(), nonsafe_loaded_model.state_dict())
nonsafe_loaded_model.load_state_dict(nonsafe_state_dict)
check_weights("same", model.state_dict(), nonsafe_loaded_model.state_dict())

View File

@ -50,7 +50,7 @@ def test_gather_object(state):
assert gathered_obj == list(range(state.num_processes)), f"{gathered_obj} != {list(range(state.num_processes))}"
def test_gather_non_contigous(state):
def test_gather_non_contiguous(state):
# Skip this test because the 'is_contiguous' function of XLA tensor always returns True.
if state.distributed_type == DistributedType.XLA:
return
@ -160,8 +160,8 @@ def main():
test_gather(state)
state.print("testing gather_object")
test_gather_object(state)
state.print("testing gather non-contigous")
test_gather_non_contigous(state)
state.print("testing gather non-contiguous")
test_gather_non_contiguous(state)
state.print("testing broadcast")
test_broadcast(state)
state.print("testing pad_across_processes")

View File

@ -35,10 +35,12 @@ from accelerate.utils import (
gather,
gather_object,
is_bf16_available,
is_cuda_available,
is_datasets_available,
is_fp16_available,
is_hpu_available,
is_ipex_available,
is_mps_available,
is_pytest_available,
is_xpu_available,
set_seed,
@ -534,7 +536,7 @@ def training_check(use_seedable_sampler=False):
accelerator.print("Training yielded the same results on one CPU or distributed setup with batch split.")
# FP32 wrapper check
if torch.cuda.is_available():
if is_cuda_available() or is_mps_available():
# Mostly a test that model.forward will have autocast when running unwrap_model(model, keep_fp32_wrapper=True)
print("Keep fp32 wrapper check.")
AcceleratorState._reset_state()
@ -625,7 +627,7 @@ def training_check(use_seedable_sampler=False):
msg=lambda msg: f"Did not obtain the same model on CPU or distributed training.\n{msg}",
)
# IPEX support is only for CPU
# IPEX CPU tests
if is_ipex_available():
print("ipex BF16 training check.")
AcceleratorState._reset_state()

View File

@ -61,6 +61,7 @@ from ..utils import (
is_pytest_available,
is_schedulefree_available,
is_sdaa_available,
is_swanlab_available,
is_tensorboard_available,
is_timm_available,
is_torch_version,
@ -68,7 +69,9 @@ from ..utils import (
is_torchao_available,
is_torchdata_stateful_dataloader_available,
is_torchvision_available,
is_trackio_available,
is_transformer_engine_available,
is_transformer_engine_mxfp8_available,
is_transformers_available,
is_triton_available,
is_wandb_available,
@ -249,6 +252,10 @@ def require_fp8(test_case):
return unittest.skipUnless(fp8_is_available, "test requires FP8 support")(test_case)
def require_fsdp2(test_case):
return unittest.skipUnless(is_torch_version(">=", "2.5.0"), "test requires FSDP2 (torch >= 2.5.0)")(test_case)
def require_mlu(test_case):
"""
Decorator marking a test that requires MLU. These tests are skipped when there are no MLU available.
@ -454,6 +461,13 @@ def require_wandb(test_case):
return unittest.skipUnless(is_wandb_available(), "test requires wandb")(test_case)
def require_trackio(test_case):
"""
Decorator marking a test that requires trackio installed. These tests are skipped when trackio isn't installed
"""
return unittest.skipUnless(is_trackio_available(), "test requires trackio")(test_case)
def require_comet_ml(test_case):
"""
Decorator marking a test that requires comet_ml installed. These tests are skipped when comet_ml isn't installed
@ -482,6 +496,13 @@ def require_dvclive(test_case):
return unittest.skipUnless(is_dvclive_available(), "test requires dvclive")(test_case)
def require_swanlab(test_case):
"""
Decorator marking a test that requires swanlab installed. These tests are skipped when swanlab isn't installed
"""
return unittest.skipUnless(is_swanlab_available(), "test requires swanlab")(test_case)
def require_pandas(test_case):
"""
Decorator marking a test that requires pandas installed. These tests are skipped when pandas isn't installed
@ -520,6 +541,16 @@ def require_transformer_engine(test_case):
return unittest.skipUnless(is_transformer_engine_available(), "test requires transformers engine")(test_case)
def require_transformer_engine_mxfp8(test_case):
"""
Decorator marking a test that requires transformers engine MXFP8 block scaling available. These tests are skipped
when transformers engine MXFP8 block scaling isn't available
"""
return unittest.skipUnless(
is_transformer_engine_mxfp8_available(), "test requires transformers engine MXFP8 block scaling"
)(test_case)
def require_torchao(test_case):
"""
Decorator marking a test that requires torchao installed. These tests are skipped when torchao isn't installed
@ -536,7 +567,8 @@ def require_matplotlib(test_case):
_atleast_one_tracker_available = (
any([is_wandb_available(), is_tensorboard_available()]) and not is_comet_ml_available()
any([is_wandb_available(), is_tensorboard_available(), is_trackio_available(), is_swanlab_available()])
and not is_comet_ml_available()
)
@ -566,7 +598,7 @@ def require_torchdata_stateful_dataloader(test_case):
def run_first(test_case):
"""
Decorator marking a test with order(1). When pytest-order plugin is installed, tests marked with this decorator are
garanteed to run first.
guaranteed to run first.
This is especially useful in some test settings like on a Gaudi instance where a Gaudi device can only be used by a
single process at a time. So we make sure all tests that run in a subprocess are launched first, to avoid device
@ -585,7 +617,7 @@ def run_first(test_case):
class TempDirTestCase(unittest.TestCase):
"""
A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its
data at the start of a test, and then destroyes it at the end of the TestCase.
data at the start of a test, and then destroys it at the end of the TestCase.
Useful for when a class or API requires a single constant folder throughout it's use, such as Weights and Biases

View File

@ -34,7 +34,9 @@ from .utils import (
is_comet_ml_available,
is_dvclive_available,
is_mlflow_available,
is_swanlab_available,
is_tensorboard_available,
is_trackio_available,
is_wandb_available,
listify,
)
@ -63,6 +65,12 @@ if is_clearml_available():
if is_dvclive_available():
_available_trackers.append(LoggerType.DVCLIVE)
if is_swanlab_available():
_available_trackers.append(LoggerType.SWANLAB)
if is_trackio_available():
_available_trackers.append(LoggerType.TRACKIO)
logger = get_logger(__name__)
@ -103,7 +111,7 @@ class GeneralTracker:
(`bool`): Whether the logger requires a directory to store their logs. `tracker` (`object`): Should return internal
tracking mechanism used by a tracker class (such as the `run` for wandb)
Implementations can also include a `main_process_only` (`bool`) attribute to toggle if relevent logging, init, and
Implementations can also include a `main_process_only` (`bool`) attribute to toggle if relevant logging, init, and
other functions should occur on the main process or across all processes (by default will use `True`)
"""
@ -133,7 +141,7 @@ class GeneralTracker:
def start(self):
"""
Lazy initialization of the tracker inside Accelerator to avoid initalizing PartialState before
Lazy initialization of the tracker inside Accelerator to avoid initializing PartialState before
InitProcessGroupKwargs.
"""
pass
@ -332,7 +340,16 @@ class WandBTracker(GeneralTracker):
"""
import wandb
wandb.config.update(values, allow_val_change=True)
if os.environ.get("WANDB_MODE") == "offline":
# In offline mode, restart wandb with config included
if hasattr(self, "run") and self.run:
self.run.finish()
init_kwargs = self.init_kwargs.copy()
init_kwargs["config"] = values
self.run = wandb.init(project=self.run_name, **init_kwargs)
else:
wandb.config.update(values, allow_val_change=True)
logger.debug("Stored initial configuration hyperparameters to WandB")
@on_main_process
@ -411,6 +428,83 @@ class WandBTracker(GeneralTracker):
logger.debug("WandB run closed")
class TrackioTracker(GeneralTracker):
"""
A `Tracker` class that supports `trackio`. Should be initialized at the start of your script.
Args:
run_name (`str`):
The name of the experiment run. Will be used as the `project` name when instantiating trackio.
**kwargs (additional keyword arguments, *optional*):
Additional key word arguments passed along to the `trackio.init` method. Refer to this
[init](https://github.com/gradio-app/trackio/blob/814809552310468b13f84f33764f1369b4e5136c/trackio/__init__.py#L22)
to see all supported key word arguments.
"""
name = "trackio"
requires_logging_directory = False
main_process_only = False
def __init__(self, run_name: str, **kwargs):
super().__init__()
self.run_name = run_name
self.init_kwargs = kwargs
@on_main_process
def start(self):
import trackio
self.run = trackio.init(project=self.run_name, **self.init_kwargs)
logger.debug(f"Initialized trackio project {self.run_name}")
logger.debug(
"Make sure to log any initial configurations with `self.store_init_configuration` before training!"
)
@property
def tracker(self):
return self.run
@on_main_process
def store_init_configuration(self, values: dict):
"""
Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
Args:
values (Dictionary `str` to `bool`, `str`, `float` or `int`):
Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
`str`, `float`, `int`, or `None`.
"""
import trackio
trackio.config.update(values, allow_val_change=True)
logger.debug("Stored initial configuration hyperparameters to trackio")
@on_main_process
def log(self, values: dict, step: Optional[int] = None, **kwargs):
"""
Logs `values` to the current run.
Args:
values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
`str` to `float`/`int`.
step (`int`, *optional*):
The run step. If included, the log will be affiliated with this step.
kwargs:
Additional key word arguments passed along to the `trackio.log` method.
"""
self.run.log(values, **kwargs)
logger.debug("Successfully logged to trackio")
@on_main_process
def finish(self):
"""
Closes `trackio` run
"""
self.run.finish()
logger.debug("trackio run closed")
class CometMLTracker(GeneralTracker):
"""
A `Tracker` class that supports `comet_ml`. Should be initialized at the start of your script.
@ -1061,6 +1155,106 @@ class DVCLiveTracker(GeneralTracker):
self.live.end()
class SwanLabTracker(GeneralTracker):
"""
A `Tracker` class that supports `swanlab`. Should be initialized at the start of your script.
Args:
run_name (`str`):
The name of the experiment run.
**kwargs (additional keyword arguments, *optional*):
Additional key word arguments passed along to the `swanlab.init` method.
"""
name = "swanlab"
requires_logging_directory = False
main_process_only = False
def __init__(self, run_name: str, **kwargs):
super().__init__()
self.run_name = run_name
self.init_kwargs = kwargs
@on_main_process
def start(self):
import swanlab
self.run = swanlab.init(project=self.run_name, **self.init_kwargs)
swanlab.config["FRAMEWORK"] = "🤗Accelerate" # add accelerate logo in config
logger.debug(f"Initialized SwanLab project {self.run_name}")
logger.debug(
"Make sure to log any initial configurations with `self.store_init_configuration` before training!"
)
@property
def tracker(self):
return self.run
@on_main_process
def store_init_configuration(self, values: dict):
"""
Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.
Args:
values (Dictionary `str` to `bool`, `str`, `float` or `int`):
Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
`str`, `float`, `int`, or `None`.
"""
import swanlab
swanlab.config.update(values, allow_val_change=True)
logger.debug("Stored initial configuration hyperparameters to SwanLab")
@on_main_process
def log(self, values: dict, step: Optional[int] = None, **kwargs):
"""
Logs `values` to the current run.
Args:
values : Dict[str, DataType]
Values to be logged as key-value pairs. The key must be a string with 0-9, a-z, A-Z, " ", "_", "-", "/". The value must be a
`float`, `float convertible object`, `int` or `swanlab.data.BaseType`.
step : int, optional
The step number of the current data, if not provided, it will be automatically incremented.
If step is duplicated, the data will be ignored.
kwargs:
Additional key word arguments passed along to the `swanlab.log` method. Likes:
print_to_console : bool, optional
Whether to print the data to the console, the default is False.
"""
self.run.log(values, step=step, **kwargs)
logger.debug("Successfully logged to SwanLab")
@on_main_process
def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
"""
Logs `images` to the current run.
Args:
values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):
Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or `PIL.Image`.
step (`int`, *optional*):
The run step. If included, the log will be affiliated with this step.
kwargs:
Additional key word arguments passed along to the `swanlab.log` method. Likes:
print_to_console : bool, optional
Whether to print the data to the console, the default is False.
"""
import swanlab
for k, v in values.items():
self.log({k: [swanlab.Image(image) for image in v]}, step=step, **kwargs)
logger.debug("Successfully logged images to SwanLab")
@on_main_process
def finish(self):
"""
Closes `swanlab` writer
"""
self.run.finish()
logger.debug("SwanLab run closed")
LOGGER_TYPE_TO_CLASS = {
"aim": AimTracker,
"comet_ml": CometMLTracker,
@ -1069,6 +1263,8 @@ LOGGER_TYPE_TO_CLASS = {
"wandb": WandBTracker,
"clearml": ClearMLTracker,
"dvclive": DVCLiveTracker,
"swanlab": SwanLabTracker,
"trackio": TrackioTracker,
}
@ -1090,9 +1286,12 @@ def filter_trackers(
- `"all"`
- `"tensorboard"`
- `"wandb"`
- `"trackio"`
- `"aim"`
- `"comet_ml"`
- `"mlflow"`
- `"dvclive"`
- `"swanlab"`
If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can
also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.
logging_dir (`str`, `os.PathLike`, *optional*):

View File

@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..parallelism_config import ParallelismConfig
from .ao import convert_model_to_fp8_ao, filter_first_and_last_linear_layers, has_ao_layers
from .constants import (
MITA_PROFILING_AVAILABLE_PYTORCH_VERSION,
@ -60,7 +61,9 @@ from .dataclasses import (
SageMakerDistributedType,
TensorInformation,
TERecipeKwargs,
TorchContextParallelConfig,
TorchDynamoPlugin,
TorchTensorParallelConfig,
TorchTensorParallelPlugin,
add_model_config_to_megatron_parser,
)
@ -121,6 +124,7 @@ from .imports import (
is_sagemaker_available,
is_schedulefree_available,
is_sdaa_available,
is_swanlab_available,
is_tensorboard_available,
is_timm_available,
is_torch_xla_available,
@ -128,7 +132,9 @@ from .imports import (
is_torchdata_available,
is_torchdata_stateful_dataloader_available,
is_torchvision_available,
is_trackio_available,
is_transformer_engine_available,
is_transformer_engine_mxfp8_available,
is_transformers_available,
is_triton_available,
is_wandb_available,
@ -281,6 +287,7 @@ from .other import (
is_port_in_use,
load,
merge_dicts,
model_has_dtensor,
recursive_getattr,
save,
wait_for_everyone,

View File

@ -314,7 +314,7 @@ def _replace_with_bnb_layers(
"""
Private method that wraps the recursion for module replacement.
Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
Returns the converted model and a boolean that indicates if the conversion has been successful or not.
"""
# bitsandbytes will initialize CUDA on import, so it needs to be imported lazily
import bitsandbytes as bnb

View File

@ -44,7 +44,6 @@ FSDP_PYTORCH_VERSION = (
"2.1.0.a0+32f93b1" # Technically should be 2.1.0, but MS-AMP uses this specific prerelease in their Docker image.
)
FSDP2_PYTORCH_VERSION = "2.6.0"
CONTEXT_PARALLEL_PYTORCH_VERSION = "2.7.0"
FSDP_MODEL_NAME = "pytorch_model_fsdp"
DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich", "mpich", "nossh", "slurm"]
TORCH_DYNAMO_MODES = ["default", "reduce-overhead", "max-autotune"]
@ -52,7 +51,9 @@ ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION = "2.2.0"
XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0"
MITA_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.1.0"
BETA_TP_AVAILABLE_PYTORCH_VERSION = "2.3.0"
BETA_TP_AVAILABLE_TRANSFORMERS_VERSION = "4.52.0"
BETA_CP_AVAILABLE_PYTORCH_VERSION = "2.6.0"
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}

View File

@ -32,8 +32,9 @@ from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, get_a
import torch
from .constants import (
BETA_CP_AVAILABLE_PYTORCH_VERSION,
BETA_TP_AVAILABLE_PYTORCH_VERSION,
CONTEXT_PARALLEL_PYTORCH_VERSION,
BETA_TP_AVAILABLE_TRANSFORMERS_VERSION,
FSDP2_PYTORCH_VERSION,
FSDP_AUTO_WRAP_POLICY,
FSDP_BACKWARD_PREFETCH,
@ -59,6 +60,7 @@ if TYPE_CHECKING:
# Mock imports for type checking
from torchao.float8 import Float8LinearConfig
logger = logging.getLogger(__name__)
@ -185,7 +187,9 @@ class DistributedDataParallelKwargs(KwargsHandler):
comm_hook: DDPCommunicationHookType = DDPCommunicationHookType.NO
comm_wrapper: Literal[
DDPCommunicationHookType.NO, DDPCommunicationHookType.FP16, DDPCommunicationHookType.BF16
DDPCommunicationHookType.NO,
DDPCommunicationHookType.FP16,
DDPCommunicationHookType.BF16,
] = DDPCommunicationHookType.NO
comm_state_option: dict = field(default_factory=dict)
@ -193,7 +197,10 @@ class DistributedDataParallelKwargs(KwargsHandler):
return {k: v for k, v in super().to_dict().items() if k not in ignore_keys}
def register_comm_hook(self, model):
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks, powerSGD_hook
from torch.distributed.algorithms.ddp_comm_hooks import (
default_hooks,
powerSGD_hook,
)
hook_map: dict[DDPCommunicationHookType, Callable] = {
DDPCommunicationHookType.FP16: default_hooks.fp16_compress_hook,
@ -216,7 +223,11 @@ class DistributedDataParallelKwargs(KwargsHandler):
if hook:
state = (
powerSGD_hook.PowerSGDState(None, **self.comm_state_option)
if self.comm_hook in (DDPCommunicationHookType.POWER_SGD, DDPCommunicationHookType.BATCHED_POWER_SGD)
if self.comm_hook
in (
DDPCommunicationHookType.POWER_SGD,
DDPCommunicationHookType.BATCHED_POWER_SGD,
)
else None
)
model.register_comm_hook(
@ -290,7 +301,7 @@ class InitProcessGroupKwargs(KwargsHandler):
# Literals
Backend = Literal["MSAMP", "TE"]
OptLevel = Literal["O1", "O2"]
FP8Format = Literal["E4M3", "HYBRID"]
FP8Format = Literal["HYBRID", "E4M3", "E5M2"]
AmaxComputeAlgorithm = Literal["max", "most_recent"]
@ -343,8 +354,8 @@ class TERecipeKwargs(KwargsHandler):
interval (`int`, *optional*, default to 1):
The interval to use for how often the scaling factor is recomputed.
fp8_format (`str`, *optional*, default to "HYBRID"):
The format to use for the FP8 recipe. Must be one of `HYBRID` or `E4M3`. (Generally `HYBRID` for training,
`E4M3` for evaluation)
The format to use for the FP8 recipe. Must be one of `HYBRID`, `E4M3` or `E5M2`. (Generally `HYBRID` for
training, `E4M3` or `E5M2` for evaluation)
amax_history_len (`int`, *optional*, default to 1024):
The length of the history to use for the scaling factor computation
amax_compute_algo (`str`, *optional*, default to "most_recent"):
@ -360,6 +371,7 @@ class TERecipeKwargs(KwargsHandler):
amax_history_len: int = None
amax_compute_algo: AmaxComputeAlgorithm = None
override_linear_precision: tuple[bool, bool, bool] = None
use_mxfp8_block_scaling: bool = None
def __post_init__(self):
env_prefix = "ACCELERATE_FP8_"
@ -388,6 +400,8 @@ class TERecipeKwargs(KwargsHandler):
dgrad = parse_flag_from_env(env_prefix + "OVERRIDE_DGRAD")
wgrad = parse_flag_from_env(env_prefix + "OVERRIDE_WGRAD")
self.override_linear_precision = (fprop, dgrad, wgrad)
if self.use_mxfp8_block_scaling is None:
self.use_mxfp8_block_scaling = parse_flag_from_env(env_prefix + "USE_MXFP8_BLOCK_SCALING")
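A sketch of opting into the MXFP8 block-scaling recipe through the kwargs handler (assumes TransformerEngine with MXFP8 support is installed and the hardware supports it; the flag can equivalently be set via the `ACCELERATE_FP8_USE_MXFP8_BLOCK_SCALING` environment variable):

```python
from accelerate import Accelerator
from accelerate.utils import TERecipeKwargs

fp8_recipe = TERecipeKwargs(use_mxfp8_block_scaling=True)
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[fp8_recipe])
```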
@dataclass
@ -583,7 +597,6 @@ class DistributedType(str, enum.Enum):
MULTI_XPU = "MULTI_XPU"
DEEPSPEED = "DEEPSPEED"
FSDP = "FSDP"
TP = "TP"
XLA = "XLA"
MEGATRON_LM = "MEGATRON_LM"
MULTI_HPU = "MULTI_HPU"
@ -617,8 +630,10 @@ class FP8BackendType(str, enum.Enum):
"""
# Subclassing str as well as Enum allows the `FP8BackendType` to be JSON-serializable out of the box.
NO = "NO"
TE = "TE"
MSAMP = "MSAMP"
AO = "AO"
class ComputeEnvironment(str, enum.Enum):
@ -668,7 +683,7 @@ class DynamoBackend(str, BaseEnum):
more](https://github.com/pytorch/xla/blob/r2.0/docs/dynamo.md)
- **IPEX** -- Uses IPEX for inference on CPU. Inference only. [Read
more](https://github.com/intel/intel-extension-for-pytorch).
- **TVM** -- Uses Apach TVM for inference optimizations. [Read more](https://tvm.apache.org/)
- **TVM** -- Uses Apache TVM for inference optimizations. [Read more](https://tvm.apache.org/)
- **HPU_BACKEND** -- Uses HPU backend for inference optimizations.
"""
@ -700,18 +715,24 @@ class LoggerType(BaseEnum):
- **ALL** -- all available trackers in the environment that are supported
- **TENSORBOARD** -- TensorBoard as an experiment tracker
- **WANDB** -- wandb as an experiment tracker
- **TRACKIO** -- trackio as an experiment tracker
- **COMETML** -- comet_ml as an experiment tracker
- **MLFLOW** -- mlflow as an experiment tracker
- **CLEARML** -- clearml as an experiment tracker
- **DVCLIVE** -- dvclive as an experiment tracker
- **SWANLAB** -- swanlab as an experiment tracker
"""
ALL = "all"
AIM = "aim"
TENSORBOARD = "tensorboard"
WANDB = "wandb"
TRACKIO = "trackio"
COMETML = "comet_ml"
MLFLOW = "mlflow"
CLEARML = "clearml"
DVCLIVE = "dvclive"
SWANLAB = "swanlab"
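A hedged example for the new tracker entry (assumes `trackio` is installed and detected by the `is_trackio_available` check further below, which also requires Python >= 3.10; project name and metrics are placeholders):

from accelerate import Accelerator

accelerator = Accelerator(log_with="trackio")  # select the tracker by name, as with any LoggerType value
accelerator.init_trackers("my-project", config={"lr": 3e-4})
accelerator.log({"loss": 0.42}, step=1)
accelerator.end_training()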
class PrecisionType(str, BaseEnum):
@ -783,9 +804,9 @@ class DataLoaderConfiguration:
all workers.
use_seedable_sampler (`bool`, defaults to `False`):
Whether or not use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`]). Ensures
training results are fully reproducable using a different sampling technique. While seed-to-seed results
may differ, on average the differences are neglible when using multiple different seeds to compare. Should
also be ran with [`~utils.set_seed`] for the best results.
training results are fully reproducible using a different sampling technique. While seed-to-seed results
may differ, on average the differences are negligible when using multiple different seeds to compare.
Should also be ran with [`~utils.set_seed`] for the best results.
data_seed (`int`, defaults to `None`):
The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator
will use the current default seed from torch.
@ -828,8 +849,8 @@ class DataLoaderConfiguration:
default=False,
metadata={
"help": "Whether or not use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`])."
"Ensures training results are fully reproducable using a different sampling technique. "
"While seed-to-seed results may differ, on average the differences are neglible when using"
"Ensures training results are fully reproducible using a different sampling technique. "
"While seed-to-seed results may differ, on average the differences are negligible when using"
"multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results."
},
)
@ -935,7 +956,7 @@ class GradientAccumulationPlugin(KwargsHandler):
sync_with_dataloader (`bool`, *optional*, defaults to `True`):
Whether to synchronize setting the gradients when at the end of the dataloader.
sync_each_batch (`bool`, *optional*):
Whether to synchronize setting the gradients at each data batch. Seting to `True` may reduce memory
Whether to synchronize setting the gradients at each data batch. Setting to `True` may reduce memory
requirements when using gradient accumulation with distributed training, at expense of speed.
Example:
@ -948,7 +969,10 @@ class GradientAccumulationPlugin(KwargsHandler):
```
"""
num_steps: int = field(default=None, metadata={"help": "The number of steps to accumulate gradients for."})
num_steps: int = field(
default=None,
metadata={"help": "The number of steps to accumulate gradients for."},
)
adjust_scheduler: bool = field(
default=True,
metadata={
@ -999,12 +1023,22 @@ class TorchDynamoPlugin(KwargsHandler):
metadata={"help": f"Possible options are {[b.value.lower() for b in DynamoBackend]}"},
)
mode: str = field(
default=None, metadata={"help": "Possible options are 'default', 'reduce-overhead' or 'max-autotune'"}
default=None,
metadata={"help": "Possible options are 'default', 'reduce-overhead' or 'max-autotune'"},
)
fullgraph: bool = field(
default=None,
metadata={"help": "Whether it is ok to break model into several subgraphs"},
)
fullgraph: bool = field(default=None, metadata={"help": "Whether it is ok to break model into several subgraphs"})
dynamic: bool = field(default=None, metadata={"help": "Whether to use dynamic shape for tracing"})
options: Any = field(default=None, metadata={"help": "A dictionary of options to pass to the backend."})
disable: bool = field(default=False, metadata={"help": "Turn torch.compile() into a no-op for testing"})
options: Any = field(
default=None,
metadata={"help": "A dictionary of options to pass to the backend."},
)
disable: bool = field(
default=False,
metadata={"help": "Turn torch.compile() into a no-op for testing"},
)
use_regional_compilation: bool = field(
default=None,
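A short sketch of how the reflowed fields above are typically filled. The fields mirror `torch.compile()` arguments; whether this Accelerate version accepts the plugin directly on `Accelerator` is an assumption noted in the comment:

from accelerate import Accelerator
from accelerate.utils import TorchDynamoPlugin

dynamo_plugin = TorchDynamoPlugin(
    backend="inductor",
    mode="reduce-overhead",
    fullgraph=False,  # allow graph breaks
)
# Assumption: this version of Accelerator takes the plugin directly.
accelerator = Accelerator(dynamo_plugin=dynamo_plugin)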
@ -1183,7 +1217,7 @@ class DeepSpeedPlugin:
if self.zero3_save_16bit_model is None:
self.zero3_save_16bit_model = (
os.environ.get("ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL", "false") == "true"
os.environ.get("ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL", "false").lower() == "true"
)
if self.enable_msamp is None:
self.enable_msamp = os.environ.get("ACCELERATE_FP8_BACKEND", None) == "MSAMP"
@ -1236,13 +1270,13 @@ class DeepSpeedPlugin:
"stage": self.zero_stage,
"offload_optimizer": {
"device": self.offload_optimizer_device,
"nvme_path": self.offload_optimizer_nvme_path
if self.offload_optimizer_device == "nvme"
else None,
"nvme_path": (
self.offload_optimizer_nvme_path if self.offload_optimizer_device == "nvme" else None
),
},
"offload_param": {
"device": self.offload_param_device,
"nvme_path": self.offload_param_nvme_path if self.offload_param_device == "nvme" else None,
"nvme_path": (self.offload_param_nvme_path if self.offload_param_device == "nvme" else None),
},
"stage3_gather_16bit_weights_on_model_save": self.zero3_save_16bit_model,
},
@ -1255,7 +1289,13 @@ class DeepSpeedPlugin:
self.deepspeed_config["steps_per_print"] = float("inf") # this will stop deepspeed from logging @ stdout
if self.zero3_init_flag is None:
self.zero3_init_flag = (
str_to_bool(os.environ.get("ACCELERATE_DEEPSPEED_ZERO3_INIT", str(self.hf_ds_config.is_zero3()))) == 1
str_to_bool(
os.environ.get(
"ACCELERATE_DEEPSPEED_ZERO3_INIT",
str(self.hf_ds_config.is_zero3()),
)
)
== 1
)
if self.zero3_init_flag and not self.hf_ds_config.is_zero3():
warnings.warn("DeepSpeed Zero3 Init flag is only applicable for ZeRO Stage 3. Setting it to False.")
@ -1272,7 +1312,10 @@ class DeepSpeedPlugin:
)
if self.msamp_opt_level not in ["O1", "O2"]:
raise ValueError("Invalid optimization level for MS-AMP. Please use one of ['O1' or'O2'].")
self.deepspeed_config["msamp"] = {"enabled": True, "opt_level": self.msamp_opt_level}
self.deepspeed_config["msamp"] = {
"enabled": True,
"opt_level": self.msamp_opt_level,
}
def fill_match(self, ds_key_long, mismatches=None, must_match=True, **kwargs):
mismatches = [] if mismatches is None else mismatches
@ -1317,7 +1360,11 @@ class DeepSpeedPlugin:
for key, value in config.items():
if isinstance(value, dict):
self.deepspeed_config_process(
prefix=prefix + key + ".", mismatches=mismatches, config=value, must_match=must_match, **kwargs
prefix=prefix + key + ".",
mismatches=mismatches,
config=value,
must_match=must_match,
**kwargs,
)
else:
self.fill_match(prefix + key, mismatches, must_match=must_match, **kwargs)
@ -1344,7 +1391,10 @@ class DeepSpeedPlugin:
if mixed_precision == "fp8" and self.enable_msamp:
if "msamp" not in ds_config:
ds_config["msamp"] = {"enabled": True, "opt_level": self.msamp_opt_level}
ds_config["msamp"] = {
"enabled": True,
"opt_level": self.msamp_opt_level,
}
if mixed_precision != "no":
diff_dtype = "bf16" if mixed_precision == "fp16" else "fp16"
@ -1376,9 +1426,15 @@ class DeepSpeedPlugin:
del ds_config["train_batch_size"]
if compare_versions("transformers", "<", "4.46"):
from transformers.deepspeed import HfDeepSpeedConfig, unset_hf_deepspeed_config
from transformers.deepspeed import (
HfDeepSpeedConfig,
unset_hf_deepspeed_config,
)
else:
from transformers.integrations import HfDeepSpeedConfig, unset_hf_deepspeed_config
from transformers.integrations import (
HfDeepSpeedConfig,
unset_hf_deepspeed_config,
)
unset_hf_deepspeed_config()
self.dschf = HfDeepSpeedConfig(ds_config) # keep this object alive # noqa
@ -1497,10 +1553,12 @@ class FullyShardedDataParallelPlugin:
backward_prefetch (`Union[str, torch.distributed.fsdp.BackwardPrefetch]`, defaults to `'NO_PREFETCH'`):
Backward prefetch strategy to use. Should be either a `str` or an instance of
`torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`.
mixed_precision_policy (`Optional[Union[dict, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`):
mixed_precision_policy (`Optional[Union[dict, str, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`):
A config to enable mixed precision training with FullyShardedDataParallel. If passing in a `dict`, it
should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`, can be an instance of
`torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2.
`torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2. If passing in a `str`, it
should be one of the following values: fp8, fp16, bf16, fp32, and used to set `param_dtype`,
`reduce_dtype`, and `buffer_dtype`.
auto_wrap_policy (`Optional(Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]), defaults to `NO_WRAP`):
A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one
of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`. See
@ -1509,8 +1567,9 @@ class FullyShardedDataParallelPlugin:
Whether to offload parameters to CPU. Should be either a `bool` or an instance of
`torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or
`torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2.
ignored_modules (`Optional[Iterable[torch.nn.Module]]`, defaults to `None`):
A list of modules to ignore when wrapping with FSDP.
ignored_modules (`Optional[Union[Iterable[torch.nn.Module], str]]`, defaults to `None`):
A list of modules to ignore when wrapping with FSDP. When passing a string, will match the modules by name
using regex fullmatch. If `fsdp_version` is set to 2, the modules are converted to parameters and used.
state_dict_type (`Union[str, torch.distributed.fsdp.StateDictType]`, defaults to `'FULL_STATE_DICT'`):
State dict type to use. If a string, it must be one of `full_state_dict`, `local_state_dict`, or
`sharded_state_dict`.
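A hedged sketch of the two string forms documented above: `mixed_precision_policy` as a dtype name and `ignored_modules` as a regex (the module path in the regex is a made-up example):

from accelerate import Accelerator, FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    mixed_precision_policy="bf16",               # expands to param/reduce/buffer dtypes
    ignored_modules=r"model\.vision_tower\..*",  # regex, matched with re.fullmatch against module names
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)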
@ -1547,11 +1606,6 @@ class FullyShardedDataParallelPlugin:
min_num_params (`Optional[int]`, defaults to `None`):
The minimum number of parameters a module must have to be wrapped. Only applicable when `auto_wrap_policy`
is `size_based_wrap`.
cp_size (`int`, defaults to `1`):
The size of the context parallel group. Only applicable when `fsdp_version` is set to 2, else error will be
raised. Defaults to 1 (CP not applied).
cp_comm_strategy (`str`, defaults to `allgather`):
The shard rotation strategy to use, only used when `cp_size` > 1 and `fsdp_version` is set to 2.
"""
fsdp_version: int = field(
@ -1581,7 +1635,12 @@ class FullyShardedDataParallelPlugin:
},
)
mixed_precision_policy: Optional[
Union[dict, "torch.distributed.fsdp.MixedPrecision", "torch.distributed.fsdp.MixedPrecisionPolicy"]
Union[
dict,
str,
"torch.distributed.fsdp.MixedPrecision",
"torch.distributed.fsdp.MixedPrecisionPolicy",
]
] = field(
default=None,
metadata={
@ -1599,13 +1658,17 @@ class FullyShardedDataParallelPlugin:
},
)
)
cpu_offload: Union[bool, "torch.distributed.fsdp.CPUOffload", "torch.distributed.fsdp.CPUOffloadPolicy"] = field(
cpu_offload: Union[
bool,
"torch.distributed.fsdp.CPUOffload",
"torch.distributed.fsdp.CPUOffloadPolicy",
] = field(
default=None,
metadata={
"help": "Whether to offload parameters to CPU. Should be either a `bool` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2. Defaults to `False`"
},
)
ignored_modules: Optional[Iterable[torch.nn.Module]] = field(
ignored_modules: Optional[Union[Iterable[torch.nn.Module], str]] = field(
default=None,
metadata={"help": "A list of modules to ignore when wrapping with FSDP."},
)
@ -1626,7 +1689,10 @@ class FullyShardedDataParallelPlugin:
metadata={"help": "State dict config to use. Is determined based on the `state_dict_type` if not passed in."},
)
optim_state_dict_config: Optional[
Union["torch.distributed.fsdp.FullOptimStateDictConfig", "torch.distributed.fsdp.ShardedOptimStateDictConfig"]
Union[
"torch.distributed.fsdp.FullOptimStateDictConfig",
"torch.distributed.fsdp.ShardedOptimStateDictConfig",
]
] = field(
default=None,
metadata={
@ -1697,24 +1763,9 @@ class FullyShardedDataParallelPlugin:
"help": "The minimum number of parameters a module must have to be wrapped. Only applicable when `auto_wrap_policy` is `size_based_wrap`."
},
)
cp_size: int = field(
default=None,
metadata={
"help": "The size of the context parallel group. Only applicable when `fsdp_version` is set to 2, else error will be raised. Defaults to 1 (CP not applied)"
},
)
cp_comm_strategy: str = field(
default=None,
metadata={
"help": "The shard rotation strategy to use, only used when `cp_size` > 1 and `fsdp_version` is set to 2. Defaults to `allgather`."
},
)
def __post_init__(self):
from torch.distributed.fsdp import (
BackwardPrefetch,
ShardingStrategy,
)
from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
_fsdp2_warnings = set()
@ -1748,7 +1799,8 @@ class FullyShardedDataParallelPlugin:
# Fallback to `reshard_after_forward` in FSDP1 if `sharding_strategy` is not set
if self.reshard_after_forward is None and self.sharding_strategy is None:
reshard_after_forward = os.environ.get(
env_prefix + "RESHARD_AFTER_FORWARD", "true" if self.fsdp_version == 2 else "FULL_SHARD"
env_prefix + "RESHARD_AFTER_FORWARD",
"true" if self.fsdp_version == 2 else "FULL_SHARD",
)
if self.fsdp_version == 2:
self.reshard_after_forward = str_to_bool(reshard_after_forward.lower(), to_bool=True)
@ -1805,7 +1857,10 @@ class FullyShardedDataParallelPlugin:
raise ValueError(
f"Invalid auto wrap policy: {self.auto_wrap_policy}. Must be one of {FSDP_AUTO_WRAP_POLICY}"
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
transformer_auto_wrap_policy,
)
if self.auto_wrap_policy.upper() == "TRANSFORMER_BASED_WRAP":
self.auto_wrap_policy = transformer_auto_wrap_policy
@ -1849,6 +1904,9 @@ class FullyShardedDataParallelPlugin:
str_to_bool(os.environ.get(env_prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1
)
if self.ignored_modules is None:
self.ignored_modules = os.environ.get(env_prefix + "IGNORED_MODULES", None)
if self.cpu_ram_efficient_loading is None:
self.cpu_ram_efficient_loading = (
str_to_bool(os.environ.get(env_prefix + "CPU_RAM_EFFICIENT_LOADING", "False")) == 1
@ -1871,29 +1929,11 @@ class FullyShardedDataParallelPlugin:
)
os.environ[env_var] = str(self.cpu_ram_efficient_loading)
if self.cp_size is None:
self.cp_size = int(os.environ.get(env_prefix + "CP_SIZE", "1"))
if self.cp_size > 1 and self.fsdp_version != 2:
raise ValueError(
f"cp_size set to {self.cp_size}. This is not supported with FSDP1, please set to 1 or use `fsdp_version=2`"
)
if self.cp_size > 1 and not is_torch_version(">=", CONTEXT_PARALLEL_PYTORCH_VERSION):
raise ValueError(
f"cp_size set to {self.cp_size}. This is not supported with PyTorch < {CONTEXT_PARALLEL_PYTORCH_VERSION}, please set to None or upgrade your PyTorch version."
)
if self.cp_comm_strategy is None:
self.cp_comm_strategy = os.environ.get(env_prefix + "CP_COMM_STRATEGY", "allgather")
# No need to further check versions, as that check is done in the `context_parallel_size` check
if self.cp_comm_strategy not in ["allgather", "alltoall"]:
raise ValueError(
f"cp_comm_strategy set to {self.cp_comm_strategy}. Must be one of ['allgather', 'alltoall']."
)
if isinstance(self.mixed_precision_policy, dict):
if isinstance(self.mixed_precision_policy, str):
# override is True since self.mixed_precision_policy is not None
# has to be overwritten with the correct mixed precision object
self.set_mixed_precision(self.mixed_precision_policy, override=True)
elif isinstance(self.mixed_precision_policy, dict):
self.set_mixed_precision(self.mixed_precision_policy)
if self.mixed_precision_policy is not None:
self.validate_mixed_precision_policy()
@ -1918,7 +1958,12 @@ class FullyShardedDataParallelPlugin:
# Create a function that will be used to initialize the parameters of the model
# when using `sync_module_states`
self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)
if is_torch_version("<", "2.7.0") and self.fsdp_version == 2 and self.ignored_modules is not None:
_fsdp2_warnings.add(
"FSDP2 ignored_params/ignored_modules is not available for torch version < 2.7.0"
"Setting ignored_modules to None."
)
self.ignored_modules = None
# Single warning for all deprecation warnings due to FSDP2 conversion
if _fsdp2_warnings:
logger.warning("Multiple deprecation warnings due to FSDP2 conversion:\n".join(_fsdp2_warnings))
@ -1942,7 +1987,8 @@ class FullyShardedDataParallelPlugin:
if self.state_dict_type is None:
self.state_dict_type = os.environ.get(
"FSDP_STATE_DICT_TYPE", "FULL_STATE_DICT" if self.fsdp_version == 1 else "SHARDED_STATE_DICT"
"FSDP_STATE_DICT_TYPE",
"FULL_STATE_DICT" if self.fsdp_version == 1 else "SHARDED_STATE_DICT",
)
if isinstance(self.state_dict_type, str):
if self.state_dict_type.isdigit():
@ -1969,10 +2015,13 @@ class FullyShardedDataParallelPlugin:
def set_auto_wrap_policy(self, model):
"""
Given `model`, creates an `auto_wrap_policy` baesd on the passed in policy and if we can use the
Given `model`, creates an `auto_wrap_policy` based on the passed in policy and if we can use the
`transformer_cls_to_wrap`
"""
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
transformer_auto_wrap_policy,
)
# First base off of `_no_split_modules`
no_split_modules = getattr(model, "_no_split_modules", None)
@ -2109,33 +2158,57 @@ class TorchTensorParallelPlugin:
metadata={"help": "tensor parallel size will be used in the device mesh preparation"},
)
# torch_device_mesh is fo type "torch.distributed.DeviceMesh"
# torch_device_mesh is of type "torch.distributed.DeviceMesh"
torch_device_mesh: Optional["torch.distributed.DeviceMesh"] = field(default=None)
@dataclass
class TorchContextParallelConfig:
"""
This class holds the configuration for context parallelism in PyTorch.
"""
cp_comm_strategy: Optional[str] = field(
default=None,
metadata={
"help": "Communication strategy for context parallelism. Can be one of 'allgather' or 'alltoall'. Defaults to 'allgather'."
},
)
def __post_init__(self):
if not isinstance(self.tp_size, int):
raise ValueError(f"`tp_size` set to {self.tp_size}, please set to an `int`.")
if self.tp_size <= 1:
raise ValueError("`tp_size` must be greater than 1.")
if is_torch_version("<", BETA_TP_AVAILABLE_PYTORCH_VERSION):
if not is_torch_version(">=", BETA_CP_AVAILABLE_PYTORCH_VERSION):
raise ValueError(
f"Minimum PyTorch version {BETA_TP_AVAILABLE_PYTORCH_VERSION} needed to use tensor parallel."
f"Context parallelism is only available in PyTorch {BETA_CP_AVAILABLE_PYTORCH_VERSION} and later versions. "
"Please upgrade your PyTorch version."
)
if self.cp_comm_strategy is None:
self.cp_comm_strategy = os.environ.get("PARALLELISM_CONFIG_CP_COMM_STRATEGY", "allgather")
if self.cp_comm_strategy not in ["allgather", "alltoall"]:
raise ValueError(
f"Invalid cp_comm_strategy: {self.cp_comm_strategy}. Must be one of 'allgather' or 'alltoall'."
)
from torch.distributed.device_mesh import init_device_mesh
# support for other devices has to be investigated
if is_hpu_available(init_hccl=True):
device = "hpu"
else:
device = "cuda"
mesh_dim_name = "tp"
@dataclass
class TorchTensorParallelConfig:
"""
Use this object in your [`Accelerator`] to customize your torch tensor parallelism.
"""
# device mesh is not used for model sharding
# it is only used for preparing data loader
self.torch_device_mesh = init_device_mesh(device, (self.tp_size,), mesh_dim_names=(mesh_dim_name,))
enable_async_tp: bool = False
def __post_init__(self):
if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
raise ValueError(
f"Torch tensor parallelism is only available in PyTorch {BETA_TP_AVAILABLE_PYTORCH_VERSION} and later versions. "
"Please upgrade your PyTorch version."
)
if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
if self.enable_async_tp:
warnings.warn("Async tensor parallelism is currently not supported, ignoring this option.")
@dataclass
@ -2190,7 +2263,7 @@ class MegatronLMPlugin:
lr_warmup_fraction (`float`, defaults to `None`):
Fraction of lr-warmup-(iters/samples) to linearly warmup learning rate over.
min_lr (`float`, defaults to `0`):
Minumum value for learning rate. The scheduler clip values below this threshold.
Minimum value for learning rate. The scheduler clip values below this threshold.
consumed_samples (`List`, defaults to `None`):
Number of samples consumed in the same order as the dataloaders to `accelerator.prepare` call.
no_wd_decay_cond (`Optional`, defaults to `None`):
@ -2239,7 +2312,8 @@ class MegatronLMPlugin:
pp_degree: int = field(default=None, metadata={"help": "pipeline parallelism degree."})
num_micro_batches: int = field(default=None, metadata={"help": "number of micro-batches."})
gradient_clipping: float = field(
default=None, metadata={"help": "gradient clipping value based on global L2 Norm (0 to disable)"}
default=None,
metadata={"help": "gradient clipping value based on global L2 Norm (0 to disable)"},
)
sequence_parallelism: bool = field(
default=None,
@ -2254,7 +2328,8 @@ class MegatronLMPlugin:
metadata={"help": "enable distributed optimizer"},
)
pipeline_model_parallel_split_rank: int = field(
default=None, metadata={"help": "Rank where encoder and decoder should be split."}
default=None,
metadata={"help": "Rank where encoder and decoder should be split."},
)
num_layers_per_virtual_pipeline_stage: int = field(
default=None, metadata={"help": "Number of layers per virtual pipeline stage."}
@ -2315,7 +2390,7 @@ class MegatronLMPlugin:
)
min_lr: float = field(
default=0,
metadata={"help": "Minumum value for learning rate. The scheduler clip values below this threshold."},
metadata={"help": "Minimum value for learning rate. The scheduler clip values below this threshold."},
)
consumed_samples: list[int] = field(
default=None,
@ -2351,10 +2426,12 @@ class MegatronLMPlugin:
metadata={"help": "Whether to set all logging options."},
)
eval_iters: int = field(
default=100, metadata={"help": "Number of iterations to run for evaluation validation/test for."}
default=100,
metadata={"help": "Number of iterations to run for evaluation validation/test for."},
)
eval_interval: int = field(
default=1000, metadata={"help": "Interval between running evaluation on validation set."}
default=1000,
metadata={"help": "Interval between running evaluation on validation set."},
)
return_logits: bool = field(
default=False,
@ -2721,7 +2798,8 @@ class BnbQuantizationConfig:
load_in_8bit: bool = field(default=False, metadata={"help": "enable 8bit quantization."})
llm_int8_threshold: float = field(
default=6.0, metadata={"help": "value of the outliner threshold. only relevant when load_in_8bit=True"}
default=6.0,
metadata={"help": "value of the outliner threshold. only relevant when load_in_8bit=True"},
)
load_in_4bit: bool = field(default=False, metadata={"help": "enable 4bit quantization."})

View File

@ -261,22 +261,36 @@ class DeepSpeedEngineWrapper:
def __init__(self, engine):
self.engine = engine
def backward(self, loss, **kwargs):
def backward(self, loss, sync_gradients=True, **kwargs):
# Set gradient accumulation boundary based on Accelerate's sync_gradients state
# This tells DeepSpeed whether this is the final micro-batch before gradient sync
self.engine.set_gradient_accumulation_boundary(is_boundary=sync_gradients)
# runs backpropagation and handles mixed precision
self.engine.backward(loss, **kwargs)
# Deepspeed's `engine.step` performs the following operations:
# - gradient accumulation check
# - gradient clipping
# - optimizer step
# - zero grad
# - checking overflow
# - lr_scheduler step (only if engine.lr_scheduler is not None)
self.engine.step()
# Only perform step and related operations at gradient accumulation boundaries
if sync_gradients:
# Deepspeed's `engine.step` performs the following operations:
# - gradient accumulation check
# - gradient clipping
# - optimizer step
# - zero grad
# - checking overflow
# - lr_scheduler step (only if engine.lr_scheduler is not None)
self.engine.step()
# and this plugin overrides the above calls with no-ops when Accelerate runs under
# Deepspeed, but allows normal functionality for non-Deepspeed cases thus enabling a simple
# training loop that works transparently under many training regimes.
def get_global_grad_norm(self):
"""Get the global gradient norm from DeepSpeed engine."""
grad_norm = self.engine.get_global_grad_norm()
# Convert to scalar if it's a tensor
if hasattr(grad_norm, "item"):
return grad_norm.item()
return grad_norm
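A hedged sketch of the call pattern the reworked `backward` expects; the names `wrapper`, `model_engine`, `train_dataloader`, and `gradient_accumulation_steps` are illustrative, and in practice `Accelerator.backward` forwards its own `sync_gradients` state, as the test further below exercises:

for step, batch in enumerate(train_dataloader):
    loss = model_engine(**batch).loss
    # Only the last micro-batch in an accumulation window is a boundary, so
    # engine.step() (clipping, optimizer step, zero_grad) runs once per effective batch.
    is_boundary = (step + 1) % gradient_accumulation_steps == 0
    wrapper.backward(loss, sync_gradients=is_boundary)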
class DeepSpeedOptimizerWrapper(AcceleratedOptimizer):
"""

View File

@ -149,7 +149,7 @@ def check_cuda_p2p_ib_support():
Checks if the devices being used have issues with P2P and IB communications, namely any consumer GPU hardware after
the 3090.
Noteably uses `nvidia-smi` instead of torch to not initialize CUDA.
Notably uses `nvidia-smi` instead of torch to not initialize CUDA.
"""
try:
device_names, device_count = get_gpu_info()

View File

@ -14,12 +14,14 @@
import copy
import functools
import os
import re
import shutil
import warnings
from collections import defaultdict
from collections.abc import Iterable
from contextlib import nullcontext
from pathlib import Path
from typing import Callable
from typing import Callable, Union
import torch
@ -179,10 +181,9 @@ def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0, a
else nullcontext()
)
sd_options = _prepare_sd_options(fsdp_plugin)
with ctx:
if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
if type(model) is not FSDP and accelerator.process_index != 0:
if type(model) is not FSDP and accelerator.process_index != 0 and not accelerator.is_fsdp2:
if not fsdp_plugin.sync_module_states and fsdp_plugin.fsdp_version == 1:
raise ValueError(
"Set the `sync_module_states` flag to `True` so that model states are synced across processes when "
@ -192,7 +193,12 @@ def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0, a
weights_name = f"{FSDP_MODEL_NAME}.bin" if model_index == 0 else f"{FSDP_MODEL_NAME}_{model_index}.bin"
input_model_file = os.path.join(input_dir, weights_name)
logger.info(f"Loading model from {input_model_file}")
state_dict = torch.load(input_model_file, weights_only=True)
# we want an empty state dict for FSDP2 as we use `broadcast_from_rank0`
load_model = not accelerator.is_fsdp2 or accelerator.is_main_process
if load_model:
state_dict = torch.load(input_model_file, weights_only=True)
else:
state_dict = {}
logger.info(f"Model loaded from {input_model_file}")
elif fsdp_plugin.state_dict_type == StateDictType.LOCAL_STATE_DICT:
weights_name = (
@ -299,13 +305,15 @@ def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, input_dir, o
optim_state = torch.load(input_optimizer_file, weights_only=True)
logger.info(f"Optimizer state loaded from {input_optimizer_file}")
else:
from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict
ckpt_dir = (
os.path.join(input_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
if f"{OPTIMIZER_NAME}" not in input_dir
else input_dir
)
logger.info(f"Loading Optimizer from {ckpt_dir}")
optim_state = {"optimizer": optimizer.state_dict()}
optim_state = {"optimizer": get_optimizer_state_dict(model, optimizer)}
dist_cp.load(
optim_state,
checkpoint_id=ckpt_dir,
@ -498,10 +506,10 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
if accelerator.is_main_process:
for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()):
full_param = full_param.detach().cuda()
mesh = sharded_param.device_mesh
dist.broadcast(full_param, src=0, group=mesh.get_group())
sharded_tensor = distribute_tensor(full_param, mesh, sharded_param.placements)
device_mesh = sharded_param.device_mesh
full_param = full_param.detach().to(device_mesh.device_type)
dist.broadcast(full_param, src=0, group=dist.group.WORLD)
sharded_tensor = distribute_tensor(full_param, device_mesh, sharded_param.placements)
to_contiguous, casting_dtype = _infer_parameter_dtype(
model,
param_name,
@ -512,10 +520,10 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
# We need this else to have a matching `broadcast` for all of the ranks, else we deadlock
else:
for param_name, sharded_param in meta_sharded_sd.items():
full_tensor = torch.empty(sharded_param.size(), device="cuda", dtype=sharded_param.dtype)
mesh = sharded_param.device_mesh
dist.broadcast(full_tensor, src=0, group=mesh.get_group())
sharded_tensor = distribute_tensor(full_tensor, mesh, sharded_param.placements)
device_mesh = sharded_param.device_mesh
full_tensor = torch.empty(sharded_param.size(), device=device_mesh.device_type, dtype=sharded_param.dtype)
dist.broadcast(full_tensor, src=0, group=dist.group.WORLD)
sharded_tensor = distribute_tensor(full_tensor, device_mesh, sharded_param.placements)
to_contiguous, casting_dtype = _infer_parameter_dtype(
model,
param_name,
@ -544,6 +552,11 @@ def fsdp2_switch_optimizer_parameters(optimizer: torch.optim.Optimizer, mapping:
indicates a bug. If we kept the original params instead of raising, the training wouldn't be numerically
correct and weights wouldn't get updated.
"""
from torch.distributed.tensor import DTensor
accessor_mapping = {}
accessor_mapping[DTensor] = "_local_tensor"
try:
for param_group in optimizer.param_groups:
param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]]
@ -611,16 +624,19 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
fsdp2_plugin.set_auto_wrap_policy(model)
original_sd = model.state_dict()
mesh = getattr(accelerator.state, "torch_device_mesh", None)
mesh = getattr(accelerator, "torch_device_mesh", None)
fsdp2_kwargs = {
"reshard_after_forward": fsdp2_plugin.reshard_after_forward,
"offload_policy": fsdp2_plugin.cpu_offload,
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
"mesh": mesh["fsdp_cp"] if mesh else None,
"mesh": mesh[tuple(accelerator.parallelism_config.fsdp_dim_names)] if mesh is not None else None,
}
if fsdp2_plugin.ignored_modules is not None:
fsdp2_kwargs["ignored_params"] = get_parameters_from_modules(
fsdp2_plugin.ignored_modules, model, accelerator.device
)
model_has_params4bit = False
for name, param in model.named_parameters():
@ -634,7 +650,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
# Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard`
# For this reason, we need to move the model to `meta` device, as then sharding happens on `meta` device
# If we kept the model on CPU (`cpu_ram_efficient_loading` has model be on CPU on all ranks, though non-main ranks only have `torch.emtpy`), `fully_shard` would move it to GPU
# If we kept the model on CPU (`cpu_ram_efficient_loading` has model be on CPU on all ranks, though non-main ranks only have `torch.empty`), `fully_shard` would move it to GPU
# Afterwards, when we call `fsdp2_load_full_state_dict`, us creating the state_dict would result into briefly having two copies of model state_dict on the GPU -> VRAM spike
# We need to keep the original non-persistent buffers, as those MAY not be in the state_dict, resulting in them staying on meta device
@ -782,5 +798,32 @@ def fsdp2_canonicalize_names(named_params: dict) -> dict:
k.replace("_orig_mod.", "") if k.startswith("_orig_mod.") else k: v for k, v in named_params.items()
}
named_params = {k.replace("._orig_mod", ""): v for k, v in named_params.items()}
named_params = {k.replace("_cp_wrapped_model.", ""): v for k, v in named_params.items()}
return named_params
def get_parameters_from_modules(
modules: Union[Iterable[torch.nn.Module], str], model, device
) -> set[torch.nn.Parameter]:
"""Converts modules to parameters where modules can be a string or list of torch.nn.Module
Args:
modules (`Union[Iterable[torch.nn.Module], str]`): List of modules
Returns:
`List[torch.nn.Parameter]`: List of parameters
"""
if modules is None:
return None
parameters = []
# code taken from accelerate while preparing kwargs for FSDP
if isinstance(modules, str):
reg = re.compile(modules)
mapped_modules = []
for name, module in model.named_modules():
if reg.fullmatch(name):
module.to(device)
mapped_modules.append(module)
modules = mapped_modules
for module in modules:
parameters.extend(list(module.parameters()))
return set(parameters)
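A hedged usage sketch for the helper above (the module-name regex, `model`, and `accelerator` are illustrative):

# Collect the parameters of every submodule whose qualified name fullmatches the regex,
# so FSDP2 can receive them as `ignored_params`.
ignored_params = get_parameters_from_modules(r"transformer\.h\.\d+\.attn", model, accelerator.device)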

View File

@ -15,6 +15,7 @@
import importlib
import importlib.metadata
import os
import sys
import warnings
from functools import lru_cache, wraps
@ -113,6 +114,14 @@ def is_transformer_engine_available():
return _is_package_available("transformer_engine", "transformer-engine")
def is_transformer_engine_mxfp8_available():
if _is_package_available("transformer_engine", "transformer-engine"):
import transformer_engine.pytorch as te
return te.fp8.check_mxfp8_support()[0]
return False
def is_lomo_available():
return _is_package_available("lomo_optim")
@ -173,7 +182,7 @@ def is_bf16_available(ignore_tpu=False):
if is_xpu_available():
return torch.xpu.is_bf16_supported()
if is_mps_available():
return False
return torch.backends.mps.is_macos_or_newer(14, 0)
return True
@ -281,6 +290,14 @@ def is_comet_ml_available():
return _is_package_available("comet_ml")
def is_swanlab_available():
return _is_package_available("swanlab")
def is_trackio_available():
return sys.version_info >= (3, 10) and _is_package_available("trackio")
def is_boto3_available():
return _is_package_available("boto3")
@ -397,7 +414,12 @@ def is_npu_available(check_device=False):
if importlib.util.find_spec("torch_npu") is None:
return False
import torch_npu # noqa: F401
# NOTE: importing torch_npu may raise error in some envs
# e.g. inside cpu-only container with torch_npu installed
try:
import torch_npu # noqa: F401
except Exception:
return False
if check_device:
try:

View File

@ -89,9 +89,9 @@ def setup_fp8_env(args: argparse.Namespace, current_env: dict[str, str]):
value = getattr(args, arg)
if value is not None:
if arg == "fp8_override_linear_precision":
current_env[prefix + "FP8_OVERRIDE_FPROP"] = value[0]
current_env[prefix + "FP8_OVERRIDE_DGRAD"] = value[1]
current_env[prefix + "FP8_OVERRIDE_WGRAD"] = value[2]
current_env[prefix + "FP8_OVERRIDE_FPROP"] = str(value[0])
current_env[prefix + "FP8_OVERRIDE_DGRAD"] = str(value[1])
current_env[prefix + "FP8_OVERRIDE_WGRAD"] = str(value[2])
else:
current_env[f"{prefix}{arg.upper()}"] = str(getattr(args, arg))
return current_env
@ -328,8 +328,8 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> dict[str, str]:
current_env["FSDP_CPU_RAM_EFFICIENT_LOADING"] = str(args.fsdp_cpu_ram_efficient_loading).lower()
current_env["FSDP_SYNC_MODULE_STATES"] = str(args.fsdp_sync_module_states).lower()
current_env["FSDP_ACTIVATION_CHECKPOINTING"] = str(args.fsdp_activation_checkpointing).lower()
current_env["FSDP_CP_SIZE"] = str(args.fsdp_cp_size)
current_env["FSDP_CP_COMM_STRATEGY"] = str(args.fsdp_cp_comm_strategy)
if getattr(args, "fsdp_ignored_modules", None) is not None:
current_env["FSDP_IGNORED_MODULES"] = str(args.fsdp_ignored_modules)
if args.use_megatron_lm:
prefix = "MEGATRON_LM_"
@ -349,6 +349,20 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> dict[str, str]:
current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process)
if args.enable_cpu_affinity:
current_env["ACCELERATE_CPU_AFFINITY"] = "1"
if not args.use_parallelism_config:
return current_env
prefix = "PARALLELISM_CONFIG_"
if args.use_parallelism_config:
current_env["ACCELERATE_USE_PARALLELISM_CONFIG"] = "true"
current_env[prefix + "DP_REPLICATE_SIZE"] = str(args.parallelism_config_dp_replicate_size)
current_env[prefix + "TP_SIZE"] = str(args.parallelism_config_tp_size)
current_env[prefix + "CP_SIZE"] = str(args.parallelism_config_cp_size)
current_env[prefix + "DP_SHARD_SIZE"] = str(args.parallelism_config_dp_shard_size)
if args.parallelism_config_cp_size > 1:
current_env[prefix + "CP_COMM_STRATEGY"] = str(args.parallelism_config_cp_comm_strategy)
return current_env

View File

@ -873,7 +873,7 @@ def finish_mpu_init():
_set_random_seed(args.seed, args.data_parallel_random_init)
# intialize megatron setup
# initialize megatron setup
def initialize(accelerator, extra_args_provider=None, args_defaults={}):
accelerator.print("Initializing Megatron-LM")
assert torch.cuda.is_available(), "Megatron requires CUDA."
@ -1344,7 +1344,7 @@ class MegatronEngine(torch.nn.Module):
padding = torch.cuda.LongTensor([[tokenizer.eod] * max_new_tokens] * inputs.shape[0])
prompts_tokens_tensor = torch.concat([inputs.cuda(), padding], axis=-1)
# We need the sizes of these tensors for the boradcast
# We need the sizes of these tensors for the broadcast
sizes_list = [
prompts_tokens_tensor.size(0), # Batch size
prompts_tokens_tensor.size(1),
@ -1353,7 +1353,7 @@ class MegatronEngine(torch.nn.Module):
# First, broadcast the sizes.
sizes_tensor = broadcast_int_list(2, int_list=sizes_list, rank=0)
# Now that we have the sizes, we can boradcast the tokens
# Now that we have the sizes, we can broadcast the tokens
# and length tensors.
sizes = sizes_tensor.tolist()
context_tokens_tensor = broadcast_tensor(sizes, torch.int64, tensor=prompts_tokens_tensor, rank=0)

View File

@ -121,7 +121,7 @@ def find_executable_batch_size(
):
"""
A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
CUDNN, the batch size is cut in half and passed to `function`
CUDNN, the batch size is multiplied by 0.9 and passed to `function`
`function` must take in a `batch_size` parameter as its first argument.
@ -153,7 +153,7 @@ def find_executable_batch_size(
def reduce_batch_size_fn():
nonlocal batch_size
batch_size = batch_size // 2
batch_size = int(batch_size * 0.9)
return batch_size
def decorator(*args, **kwargs):
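A hedged usage sketch matching the updated docstring above (0.9x shrink instead of halving); the `starting_batch_size` keyword and the function body are assumptions for illustration:

@find_executable_batch_size(starting_batch_size=256)
def train(batch_size):
    # build dataloaders/optimizer with `batch_size`; an out-of-memory error here
    # retries the whole function with int(batch_size * 0.9)
    ...

train()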

View File

@ -169,7 +169,7 @@ def dtype_byte_size(dtype: torch.dtype):
return 1 / 2
elif dtype == CustomDtype.FP8:
return 1
elif is_torch_version(">=", "2.1.0") and dtype == torch.float8_e4m3fn:
elif is_torch_version(">=", "2.1.0") and dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
return 1
bit_search = re.search(r"[^\d](\d+)$", str(dtype))
if bit_search is None:
@ -222,6 +222,8 @@ def set_module_tensor_to_device(
dtype: Optional[Union[str, torch.dtype]] = None,
fp16_statistics: Optional[torch.HalfTensor] = None,
tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
non_blocking: bool = False,
clear_cache: bool = True,
):
"""
A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
@ -245,6 +247,10 @@ def set_module_tensor_to_device(
A map of current data pointers to dictionaries of devices to already dispatched tied weights. For a given
execution device, this parameter is useful to reuse the first available pointer of a shared weight on the
device for all others, instead of duplicating memory.
non_blocking (`bool`, *optional*, defaults to `False`):
If `True`, the device transfer will be asynchronous with respect to the host, if possible.
clear_cache (`bool`, *optional*, defaults to `True`):
Whether or not to clear the device cache after setting the tensor on the device.
"""
# Recurse if needed
if "." in tensor_name:
@ -295,9 +301,9 @@ def set_module_tensor_to_device(
if dtype is None:
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
value = value.to(old_value.dtype)
value = value.to(old_value.dtype, non_blocking=non_blocking)
elif not str(value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
value = value.to(dtype)
value = value.to(dtype, non_blocking=non_blocking)
device_quantization = None
with torch.no_grad():
@ -305,8 +311,8 @@ def set_module_tensor_to_device(
# # fix the case where the device is meta, we don't want to put it on cpu because there is no data =0
if (
param is not None
and param.device.type != "cuda"
and torch.device(device).type == "cuda"
and param.device.type not in ("cuda", "xpu")
and torch.device(device).type in ("cuda", "xpu")
and param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]
):
device_quantization = device
@ -326,15 +332,15 @@ def set_module_tensor_to_device(
if "xpu" in str(device) and not is_xpu_available():
raise ValueError(f'{device} is not available, you should use device="cpu" instead')
if value is None:
new_value = old_value.to(device)
new_value = old_value.to(device, non_blocking=non_blocking)
if dtype is not None and device in ["meta", torch.device("meta")]:
if not str(old_value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
new_value = new_value.to(dtype)
new_value = new_value.to(dtype, non_blocking=non_blocking)
if not is_buffer:
module._parameters[tensor_name] = param_cls(new_value, requires_grad=old_value.requires_grad)
elif isinstance(value, torch.Tensor):
new_value = value.to(device)
new_value = value.to(device, non_blocking=non_blocking)
else:
new_value = torch.tensor(value, device=device)
if device_quantization is not None:
@ -347,24 +353,30 @@ def set_module_tensor_to_device(
if param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]:
if param_cls.__name__ == "Int8Params" and new_value.dtype == torch.float32:
# downcast to fp16 if any - needed for 8bit serialization
new_value = new_value.to(torch.float16)
new_value = new_value.to(torch.float16, non_blocking=non_blocking)
# quantize module that are going to stay on the cpu so that we offload quantized weights
if device == "cpu" and param_cls.__name__ == "Int8Params":
new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(0).to("cpu")
new_value.CB = new_value.CB.to("cpu")
new_value.SCB = new_value.SCB.to("cpu")
else:
new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(device)
new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(
device, non_blocking=non_blocking
)
elif param_cls.__name__ in ["QTensor", "QBitsTensor"]:
new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to(device)
new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to(
device, non_blocking=non_blocking
)
elif param_cls.__name__ in ["AffineQuantizedTensor"]:
new_value = new_value.to(device)
new_value = new_value.to(device, non_blocking=non_blocking)
else:
new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device)
new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(
device, non_blocking=non_blocking
)
module._parameters[tensor_name] = new_value
if fp16_statistics is not None:
module._parameters[tensor_name].SCB = fp16_statistics.to(device)
module._parameters[tensor_name].SCB = fp16_statistics.to(device, non_blocking=non_blocking)
del fp16_statistics
# as we put the weight to meta, it doesn't have SCB attr anymore. make sure that it is not a meta weight
if (
@ -390,8 +402,9 @@ def set_module_tensor_to_device(
device_index = torch.device(device).index if torch.device(device).type == "cuda" else None
if not getattr(module.weight, "quant_state", None) and device_index is not None:
module.weight = module.weight.cuda(device_index)
# clean pre and post forward hook
if device != "cpu":
if clear_cache and device != "cpu":
clear_device_cache()
# When handling tied weights, we update tied_params_map to keep track of the tied weights that have already been allocated on the device in
@ -1594,6 +1607,14 @@ def check_device_map(model: nn.Module, device_map: dict[str, Union[int, str, tor
model (`torch.nn.Module`): The model to check the device map against.
device_map (`Dict[str, Union[int, str, torch.device]]`): The device map to check.
"""
all_module_names = dict(model.named_modules())
invalid_keys = [k for k in device_map if k != "" and k not in all_module_names]
if invalid_keys:
warnings.warn(
f"The following device_map keys do not match any submodules in the model: {invalid_keys}", UserWarning
)
all_model_tensors = [name for name, _ in model.state_dict().items()]
for module_name in device_map.keys():
if module_name == "":
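A hedged sketch of the new warning path (module names are made up):

# Keys that name no submodule in the model now trigger a UserWarning instead of failing silently.
device_map = {"encoder": 0, "decoder": 1, "classifier_head": "cpu"}  # "classifier_head" may not exist
check_device_map(model, device_map)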
@ -2076,7 +2097,6 @@ def get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwarg
DistributedType.MULTI_HPU,
DistributedType.FSDP,
DistributedType.XLA,
DistributedType.TP,
]:
return torch.autocast(device_type=device_type, dtype=torch.bfloat16, **autocast_kwargs)
else:
@ -2116,6 +2136,10 @@ def get_grad_scaler(distributed_type: DistributedType = None, **kwargs):
return torch.amp.GradScaler("hpu", **kwargs)
elif is_xpu_available():
return torch.amp.GradScaler("xpu", **kwargs)
elif is_mps_available():
if not is_torch_version(">=", "2.8.0"):
raise ValueError("Grad Scaler with MPS device requires a Pytorch >= 2.8.0")
return torch.amp.GradScaler("mps", **kwargs)
else:
if is_torch_version(">=", "2.3"):
return torch.amp.GradScaler("cuda", **kwargs)

View File

@ -32,6 +32,7 @@ from .imports import (
is_torch_distributed_available,
is_torch_xla_available,
)
from .versions import is_torch_version
if is_torch_xla_available():
@ -316,8 +317,8 @@ def _gpu_gather(tensor):
state = PartialState()
gather_op = torch.distributed.all_gather_into_tensor
# FIXME: the below 2 lines are added to work-aound a bug related to INT64 collectives in oneCCL. Remove them once pytorch-2.9 is released.
if state.device.type == "xpu":
# NOTE: need to manually synchronize to work around an INT64 collectives bug in oneCCL before torch 2.9.0
if state.device.type == "xpu" and is_torch_version("<=", "2.8"):
torch.xpu.synchronize()
def _gpu_gather_one(tensor):
@ -519,7 +520,7 @@ def gather_tensor_shape(tensor):
def copy_tensor_to_devices(tensor=None) -> torch.Tensor:
"""
Copys a tensor that only exists on a single device and broadcasts it to other devices. Differs from `broadcast` as
Copies a tensor that only exists on a single device and broadcasts it to other devices. Differs from `broadcast` as
each worker doesn't need to know its shape when used (and tensor can be `None`)
Args:
@ -731,7 +732,7 @@ def reduce(tensor, reduction="mean", scale=1.0):
reduction (`str`, *optional*, defaults to `"mean"`):
A reduction method. Can be of "mean", "sum", or "none"
scale (`float`, *optional*):
A default scaling value to be applied after the reduce, only valied on XLA.
A default scaling value to be applied after the reduce, only valid on XLA.
Returns:
The same data structure as `data` with all the tensors reduced.
@ -787,7 +788,7 @@ def convert_to_fp32(tensor):
class ConvertOutputsToFp32:
"""
Decorator to apply to a function outputing tensors (like a model forward pass) that ensures the outputs in FP16
Decorator to apply to a function outputting tensors (like a model forward pass) that ensures the outputs in FP16
precision will be convert back to FP32.
Args:

View File

@ -194,6 +194,26 @@ def compile_regions_deepspeed(module: torch.nn.Module, **compile_kwargs):
module.compile(**compile_kwargs)
def model_has_dtensor(model: torch.nn.Module) -> bool:
"""
Check if the model has DTensor parameters.
Args:
model (`torch.nn.Module`):
The model to check.
Returns:
`bool`: Whether the model has DTensor parameters.
"""
if is_torch_version(">=", "2.5.0"):
from torch.distributed.tensor import DTensor
else:
# from torch 2.0.0 (oldest supported accelerate torch version), DTensor is in torch.distributed._tensor
from torch.distributed._tensor import DTensor
return any(isinstance(p, DTensor) for p in model.parameters())
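A brief, illustrative guard using the helper above (the skipped step is hypothetical):

if model_has_dtensor(model):
    # e.g. skip a re-wrapping step that assumes plain torch.Tensor parameters
    ...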
def extract_model_from_parallel(
model, keep_fp32_wrapper: bool = True, keep_torch_compile: bool = True, recursive: bool = False
):

View File

@ -16,7 +16,7 @@ from types import MethodType
import torch.nn as nn
from .imports import is_fp8_available, is_hpu_available
from .imports import is_hpu_available, is_transformer_engine_available
from .operations import GatheredParameters
@ -27,11 +27,15 @@ def convert_model(model, to_transformer_engine=True, _convert_linear=True, _conv
"""
Recursively converts the linear and layernorm layers of a model to their `transformers_engine` counterpart.
"""
if not is_fp8_available():
if not is_transformer_engine_available():
raise ImportError("Using `convert_model` requires transformer_engine to be installed.")
if is_hpu_available():
import intel_transformer_engine as te
if not hasattr(te, "LayerNorm"):
# HPU does not have a LayerNorm implementation in TE
te.LayerNorm = nn.LayerNorm
else:
import transformer_engine.pytorch as te
@ -56,9 +60,11 @@ def convert_model(model, to_transformer_engine=True, _convert_linear=True, _conv
# Note: @xrsrke (Phuc) found that te.LayerNorm doesn't have any real memory savings or speedups over nn.LayerNorm
elif isinstance(module, nn.LayerNorm) and to_transformer_engine and _convert_ln:
with GatheredParameters([module.weight, module.bias], modifier_rank=0):
has_bias = module.bias is not None
te_module = te.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype)
te_module.weight.copy_(module.weight)
te_module.bias.copy_(module.bias)
if has_bias:
te_module.bias.copy_(module.bias)
setattr(model, name, te_module)
elif isinstance(module, te.Linear) and not to_transformer_engine and _convert_linear:
@ -90,7 +96,7 @@ def has_transformer_engine_layers(model):
"""
Returns whether a given model has some `transformer_engine` layer or not.
"""
if not is_fp8_available():
if not is_transformer_engine_available():
raise ImportError("Using `has_transformer_engine_layers` requires transformer_engine to be installed.")
if is_hpu_available():
@ -114,7 +120,7 @@ def contextual_fp8_autocast(model_forward, fp8_recipe, use_during_eval=False):
Wrapper for a model's forward method to apply FP8 autocast. Is context aware, meaning that by default it will
disable FP8 autocast during eval mode, which is generally better for more accurate metrics.
"""
if not is_fp8_available():
if not is_transformer_engine_available():
raise ImportError("Using `contextual_fp8_autocast` requires transformer_engine to be installed.")
if is_hpu_available():
@ -137,19 +143,39 @@ def apply_fp8_autowrap(model, fp8_recipe_handler):
"""
Applies FP8 context manager to the model's forward method
"""
if not is_fp8_available():
if not is_transformer_engine_available():
raise ImportError("Using `apply_fp8_autowrap` requires transformer_engine to be installed.")
if is_hpu_available():
import intel_transformer_engine.recipe as te_recipe
is_fp8_block_scaling_available = False
message = "MXFP8 block scaling is not available on HPU."
else:
import transformer_engine.common.recipe as te_recipe
import transformer_engine.pytorch as te
is_fp8_block_scaling_available, message = te.fp8.check_mxfp8_support()
kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {}
if "fp8_format" in kwargs:
kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"])
use_during_eval = kwargs.pop("use_autocast_during_eval", False)
fp8_recipe = te_recipe.DelayedScaling(**kwargs)
use_mxfp8_block_scaling = kwargs.pop("use_mxfp8_block_scaling", False)
if use_mxfp8_block_scaling and not is_fp8_block_scaling_available:
raise ValueError(f"MXFP8 block scaling is not available: {message}")
if use_mxfp8_block_scaling:
if "amax_compute_algo" in kwargs:
raise ValueError("`amax_compute_algo` is not supported for MXFP8 block scaling.")
if "amax_history_len" in kwargs:
raise ValueError("`amax_history_len` is not supported for MXFP8 block scaling.")
fp8_recipe = te_recipe.MXFP8BlockScaling(**kwargs)
else:
fp8_recipe = te_recipe.DelayedScaling(**kwargs)
new_forward = contextual_fp8_autocast(model.forward, fp8_recipe, use_during_eval)
if hasattr(model.forward, "__func__"):

View File

@ -0,0 +1,240 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from transformers import AutoModel
from transformers.trainer_utils import set_seed
from accelerate.accelerator import Accelerator
from accelerate.test_utils.testing import AccelerateTestCase, require_deepspeed
from accelerate.test_utils.training import RegressionDataset
from accelerate.utils import patch_environment
from accelerate.utils.dataclasses import DeepSpeedPlugin
set_seed(42)
GPT2_TINY = "hf-internal-testing/tiny-random-gpt2"
ZERO2 = "zero2"
ZERO3 = "zero3"
FP16 = "fp16"
@require_deepspeed
class DeepSpeedGradientAccumulationTest(AccelerateTestCase):
def setUp(self):
super().setUp()
self._test_file_path = inspect.getfile(self.__class__)
path = Path(self._test_file_path).resolve()
self.test_file_dir_str = str(path.parents[0])
self.ds_config_file = dict(
zero2=f"{self.test_file_dir_str}/ds_config_zero2.json",
zero3=f"{self.test_file_dir_str}/ds_config_zero3.json",
)
# Load config files
with open(self.ds_config_file[ZERO2], encoding="utf-8") as f:
config_zero2 = json.load(f)
with open(self.ds_config_file[ZERO3], encoding="utf-8") as f:
config_zero3 = json.load(f)
config_zero3["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"] = False
self.ds_config_dict = dict(zero2=config_zero2, zero3=config_zero3)
self.dist_env = dict(
ACCELERATE_USE_DEEPSPEED="true",
MASTER_ADDR="localhost",
MASTER_PORT="10999",
RANK="0",
LOCAL_RANK="0",
WORLD_SIZE="1",
)
def test_gradient_accumulation_boundary_integration(self):
"""Test that gradient accumulation boundaries are automatically handled by DeepSpeed integration."""
gradient_accumulation_steps = 4
deepspeed_plugin = DeepSpeedPlugin(
gradient_accumulation_steps=gradient_accumulation_steps,
gradient_clipping=1.0,
zero_stage=2,
offload_optimizer_device="cpu",
offload_param_device="cpu",
zero3_save_16bit_model=False,
zero3_init_flag=False,
)
with patch_environment(**self.dist_env):
accelerator = Accelerator(mixed_precision="fp16", deepspeed_plugin=deepspeed_plugin)
# Setup simple training components
train_set = RegressionDataset(length=80)
train_dataloader = DataLoader(train_set, batch_size=16, shuffle=True)
model = AutoModel.from_pretrained(GPT2_TINY)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
model.train()
# Test gradient accumulation with accumulate context manager
batch_data = next(iter(train_dataloader))
# Create proper input format for GPT2 model (RegressionDataset returns {"x": scalar, "y": scalar})
# We need to create dummy input_ids for the GPT2 model
batch_size = batch_data["x"].shape[0] if isinstance(batch_data["x"], torch.Tensor) else 1
# Create dummy input_ids for GPT2 model and move to same device as model
device = next(model.parameters()).device
input_ids = torch.randint(0, 1000, (batch_size, 10), device=device) # batch_size x sequence_length
inputs = {"input_ids": input_ids}
# Track sync_gradients values to verify correct gradient accumulation behavior
sync_values = []
# Simulate gradient accumulation steps
for micro_step in range(gradient_accumulation_steps):
with accelerator.accumulate(model):
sync_values.append(accelerator.sync_gradients)
outputs = model(**inputs)
# Use the last hidden state and create a simple loss
prediction = outputs.last_hidden_state.mean()
loss = prediction.sum() # Simple scalar loss
# This should automatically handle gradient accumulation boundaries
accelerator.backward(loss)
if accelerator.sync_gradients:
optimizer.step()
optimizer.zero_grad()
# Verify gradient accumulation pattern was correct
# Should be False for first 3 steps, True for the last step
expected_sync = [False, False, False, True]
self.assertEqual(sync_values, expected_sync)
# Reset step counter for accelerator
accelerator.step = 0
def test_clip_grad_norm_returns_deepspeed_grad_norm(self):
"""Test that clip_grad_norm_ works with DeepSpeed and returns gradient norm when available."""
deepspeed_plugin = DeepSpeedPlugin(
gradient_accumulation_steps=1,
gradient_clipping=1.0,
zero_stage=2,
offload_optimizer_device="cpu",
offload_param_device="cpu",
zero3_save_16bit_model=False,
zero3_init_flag=False,
)
with patch_environment(**self.dist_env):
accelerator = Accelerator(mixed_precision="fp16", deepspeed_plugin=deepspeed_plugin)
# Setup simple model
model = AutoModel.from_pretrained(GPT2_TINY)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
# Create a simple dataloader for prepare to work
train_set = RegressionDataset(length=16)
train_dataloader = DataLoader(train_set, batch_size=16, shuffle=True)
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
# Perform a forward and backward pass to generate gradients
batch_data = next(iter(train_dataloader))
batch_size = len(batch_data["x"]) if isinstance(batch_data["x"], torch.Tensor) else 1
# Create dummy input_ids for GPT2 model and move to same device as model
device = next(model.parameters()).device
input_ids = torch.randint(0, 1000, (batch_size, 10), device=device)
inputs = {"input_ids": input_ids}
# Forward pass
outputs = model(**inputs)
prediction = outputs.last_hidden_state.mean()
loss = prediction.sum()
# Backward pass to generate gradients
accelerator.backward(loss)
# Test that gradient clipping works and returns a value
grad_norm = accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
# After backward pass, we should get a valid gradient norm (either from DeepSpeed or fallback)
self.assertIsInstance(grad_norm, (int, float, type(None)))
if grad_norm is not None:
self.assertGreaterEqual(grad_norm, 0.0)
def test_accelerator_backward_passes_sync_gradients(self):
"""Test that Accelerator.backward() passes sync_gradients to DeepSpeed wrapper."""
deepspeed_plugin = DeepSpeedPlugin(
gradient_accumulation_steps=2,
gradient_clipping=1.0,
zero_stage=2,
offload_optimizer_device="cpu",
offload_param_device="cpu",
zero3_save_16bit_model=False,
zero3_init_flag=False,
)
with patch_environment(**self.dist_env):
accelerator = Accelerator(mixed_precision="fp16", deepspeed_plugin=deepspeed_plugin)
# Setup simple model and data
model = AutoModel.from_pretrained(GPT2_TINY)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
train_set = RegressionDataset(length=16)
train_dataloader = DataLoader(train_set, batch_size=8, shuffle=True)
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
# Track sync_gradients values during backward calls
sync_values = []
# Test two gradient accumulation steps
batch_data = next(iter(train_dataloader))
# Create proper input format for GPT2 model
batch_size = len(batch_data["x"]) if isinstance(batch_data["x"], torch.Tensor) else 1
# Create dummy input_ids for GPT2 model and move to same device as model
device = next(model.parameters()).device
input_ids = torch.randint(0, 1000, (batch_size, 10), device=device)
inputs = {"input_ids": input_ids}
# First step - should have sync_gradients=False
with accelerator.accumulate(model):
sync_values.append(accelerator.sync_gradients)
outputs = model(**inputs)
prediction = outputs.last_hidden_state.mean()
loss = prediction # Simple loss
accelerator.backward(loss)
# Second step - should have sync_gradients=True
with accelerator.accumulate(model):
sync_values.append(accelerator.sync_gradients)
outputs = model(**inputs)
prediction = outputs.last_hidden_state.mean()
loss = prediction # Simple loss
accelerator.backward(loss)
# Verify sync_gradients pattern was correct
self.assertEqual(len(sync_values), 2)
self.assertFalse(sync_values[0]) # First step: not syncing
self.assertTrue(sync_values[1]) # Second step: syncing
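For reference, the accumulation pattern these DeepSpeed tests exercise looks roughly like the sketch below when used outside the test harness. The model, data, and hyperparameters are illustrative, and the snippet assumes DeepSpeed is installed and the process is started through a DeepSpeed-capable launcher (e.g. accelerate launch).
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

# Minimal sketch of the loop verified above (illustrative model and data).
plugin = DeepSpeedPlugin(gradient_accumulation_steps=4, gradient_clipping=1.0, zero_stage=2)
accelerator = Accelerator(mixed_precision="fp16", deepspeed_plugin=plugin)
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
dataloader = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=16)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
model.train()
for x, y in dataloader:
    with accelerator.accumulate(model):
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)      # DeepSpeed is informed of the accumulation boundary here
        if accelerator.sync_gradients:  # True only on the last micro-step of each accumulation window
            optimizer.step()
            optimizer.zero_grad()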

View File

@ -29,17 +29,15 @@ from accelerate.test_utils.testing import (
get_launch_command,
path_in_accelerate_package,
require_fp16,
require_fsdp2,
require_multi_device,
require_non_cpu,
require_non_torch_xla,
require_torch_min_version,
run_first,
slow,
)
from accelerate.utils import is_bf16_available, is_fp16_available, is_hpu_available, patch_environment, set_seed
from accelerate.utils.constants import (
CONTEXT_PARALLEL_PYTORCH_VERSION,
FSDP2_PYTORCH_VERSION,
FSDP2_STATE_DICT_TYPE,
FSDP_AUTO_WRAP_POLICY,
FSDP_BACKWARD_PREFETCH,
@ -48,7 +46,6 @@ from accelerate.utils.constants import (
)
from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin
from accelerate.utils.fsdp_utils import disable_fsdp_ram_efficient_loading, enable_fsdp_ram_efficient_loading
from accelerate.utils.versions import is_torch_version
set_seed(42)
@ -65,10 +62,6 @@ if is_fp16_available():
if is_bf16_available():
dtypes.append(BF16)
FSDP_VERSIONS = [1]
if is_torch_version(">=", FSDP2_PYTORCH_VERSION):
FSDP_VERSIONS.append(2)
@require_non_cpu
@require_non_torch_xla
@ -91,6 +84,7 @@ class FSDPPluginIntegration(AccelerateTestCase):
1: self.fsdp1_env,
2: self.fsdp2_env,
}
self.current_fsdp_version = 1
def test_sharding_strategy(self):
@ -322,6 +316,9 @@ class FSDPPluginIntegration(AccelerateTestCase):
AcceleratorState._reset_state(True)
env = self.fsdp_envs[fsdp_version].copy()
with patch_environment(**env):
plugin = FullyShardedDataParallelPlugin(mixed_precision_policy=mp_dtype)
assert plugin.mixed_precision_policy == mp_policy
with patch_environment(**env):
plugin = FullyShardedDataParallelPlugin(
mixed_precision_policy={"param_dtype": dtype, "reduce_dtype": dtype, **{extra_arg: dtype}}
@ -404,25 +401,26 @@ class FSDPPluginIntegration(AccelerateTestCase):
assert fsdp_plugin.cpu_ram_efficient_loading is False
assert os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING") == "False"
def test_cp(self):
if (fsdp_version := self.current_fsdp_version) != 2:
return
env = self.fsdp_envs[fsdp_version].copy()
for cp_comm_strategy in ["allgather", "alltoall"]:
env["FSDP_CP_COMM_STRATEGY"] = cp_comm_strategy
env["FSDP_CP_SIZE"] = "2"
with patch_environment(**env):
fsdp_plugin = FullyShardedDataParallelPlugin()
assert fsdp_plugin.cp_comm_strategy == cp_comm_strategy
env = self.fsdp_envs[fsdp_version].copy()
env["FSDP_CP_SIZE"] = "2"
with patch_environment(**env):
fsdp_plugin = FullyShardedDataParallelPlugin(cp_comm_strategy=cp_comm_strategy)
assert fsdp_plugin.cp_comm_strategy == cp_comm_strategy
def test_ignored_modules_regex(self):
# Check that FSDP's ignored_modules can be a string, in which case it is treated as a regex
env = self.fsdp_envs[1].copy()
env["FSDP_IGNORED_MODULES"] = ".*\\.q_proj$"
with patch_environment(**env):
accelerator = Accelerator()
model = AutoModel.from_pretrained(LLAMA_TESTING)
model = accelerator.prepare(model)
if self.current_fsdp_version == 1:
# model has 2 layers
layers_to_ignore = {model.layers[0].self_attn.q_proj, model.layers[1].self_attn.q_proj}
assert model._ignored_modules == layers_to_ignore
else:
params_to_ignore = {model.layers[0].self_attn.q_proj.weight, model.layers[1].self_attn.q_proj.weight}
assert model._ignored_params == params_to_ignore
@require_fsdp2
@require_non_cpu
@require_non_torch_xla
class FSDP2PluginIntegration(FSDPPluginIntegration):
def setUp(self):
super().setUp()
@ -469,6 +467,7 @@ class FSDPIntegrationTest(TempDirTestCase):
}
self.n_train = 160
self.n_val = 160
self.current_fsdp_version = 1
@require_fp16
@ -624,24 +623,13 @@ class FSDPIntegrationTest(TempDirTestCase):
with patch_environment(omp_num_threads=1):
execute_subprocess_async(cmd_config)
# TODO: Should probably be moved to a separate test file
@require_torch_min_version(version=CONTEXT_PARALLEL_PYTORCH_VERSION)
def test_dist_dataloader(self):
if (fsdp_version := self.current_fsdp_version) != 2:
return
self.test_file_path = self.test_scripts_folder / "test_distributed_dataloader.py"
cmd = get_launch_command(num_processes=2, num_machines=1, machine_rank=0, fsdp_version=fsdp_version)
cmd_config = cmd.copy()
cmd_config.extend(["--use_fsdp", "--fsdp_cp_size=2"])
cmd_config.append(self.test_file_path)
with patch_environment(omp_num_threads=1):
execute_subprocess_async(cmd_config)
@require_fsdp2
@run_first
# Skip this test when TorchXLA is available because accelerate.launch does not support TorchXLA FSDP.
@require_non_torch_xla
@require_multi_device
@slow
class FSDP2IntegrationTest(FSDPIntegrationTest):
def setUp(self):
super().setUp()

View File

@ -17,6 +17,7 @@ import os
import pickle
import tempfile
import time
from unittest import skip
from unittest.mock import patch
import psutil
@ -478,6 +479,7 @@ class AcceleratorTester(AccelerateTestCase):
@require_cuda_or_xpu
@slow
@require_bnb
@skip("Passing locally but not on CI. Also no one will try to train an offloaded bnb model")
def test_accelerator_bnb_cpu_error(self):
"""Tests that the accelerator can be used with the BNB library. This should fail as we are trying to load a model
that is loaded between cpu and gpu"""

View File

@ -625,7 +625,7 @@ class ToFSDP2Tester(unittest.TestCase):
with self.assertLogs(level="WARNING") as cm:
to_fsdp2_command(args)
assert "Config already specfies FSDP2, skipping conversion..." in cm.output[0]
assert "Config already specifies FSDP2, skipping conversion..." in cm.output[0]
# Has to be the last test because it overwrites the config file
def test_fsdp2_overwrite(self):

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from unittest import skip
import torch
from torch.utils.benchmark import Timer
@ -34,8 +35,8 @@ else:
backend = "inductor"
@require_non_hpu
@require_huggingface_suite
@skip("Don't work with torch 2.8")
class RegionalCompilationTester(unittest.TestCase):
def _get_model_and_inputs(self):
from transformers import AutoConfig, AutoModelForCausalLM
@ -109,6 +110,7 @@ class RegionalCompilationTester(unittest.TestCase):
release_memory(model, full_compilation_model, regional_compilation_model)
@slow
@require_non_hpu
@require_non_cpu
@require_huggingface_suite
def test_regional_compilation_inference_speedup(self):

View File

@ -15,6 +15,7 @@ fsdp_config:
fsdp_sync_module_states: true
fsdp_transformer_layer_cls_to_wrap: BertLayer
fsdp_use_orig_params: true
fsdp_ignored_modules: null
machine_rank: 0
main_training_function: main
mixed_precision: 'no'

250 tests/test_dataclasses.py Normal file
View File

@ -0,0 +1,250 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock, patch
import pytest
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import patch_environment
from accelerate.utils.constants import (
BETA_CP_AVAILABLE_PYTORCH_VERSION,
BETA_TP_AVAILABLE_PYTORCH_VERSION,
BETA_TP_AVAILABLE_TRANSFORMERS_VERSION,
)
from accelerate.utils.imports import is_transformers_available
from accelerate.utils.versions import compare_versions, is_torch_version
def _should_skip_cp_test(cp_size):
"""Check if CP test should be skipped based on cp_size and torch version."""
return cp_size > 1 and not is_torch_version(">=", BETA_CP_AVAILABLE_PYTORCH_VERSION)
def _should_skip_tp_test(tp_size):
"""Check if TP test should be skipped based on tp_size, torch version, and transformers availability."""
if tp_size <= 1:
return False
if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
return True
if not is_transformers_available():
return True
if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
return True
return False
class TestParallelismConfig:
@pytest.fixture(autouse=True)
def mock_init_device_mesh(self):
def mock_init_mesh(device_type, mesh_shape, mesh_dim_names):
mesh = Mock()
mesh.size.return_value = 1
for dim in mesh_shape:
mesh.size.return_value *= dim
mesh.shape = mesh_shape
mesh.mesh_dim_names = mesh_dim_names
# mock device_mesh._flatten
mesh.flattened_dims = []
def mock_getitem(key):
submesh = Mock()
def mock_flatten(name):
mesh.flattened_dims.append((key, name))
submesh._flatten = Mock(side_effect=mock_flatten)
return submesh
mesh.__getitem__ = Mock(side_effect=mock_getitem)
return mesh
with patch("torch.distributed.device_mesh.init_device_mesh", side_effect=mock_init_mesh):
yield mock_init_mesh
@pytest.mark.parametrize(
"dp_replicate_size, dp_shard_size, tp_size, cp_size, expected_shape, expected_dim_names",
[
(8, 1, 1, 1, (8,), ("dp_replicate",)), # DDP
(1, 8, 1, 1, (8,), ("dp_shard",)), # FSDP
(2, 4, 1, 1, (2, 4), ("dp_replicate", "dp_shard")), # HSDP
(1, 4, 2, 1, (4, 2), ("dp_shard", "tp")), # FSDP + TP
(2, 2, 2, 1, (2, 2, 2), ("dp_replicate", "dp_shard", "tp")), # HSDP + TP
(1, 1, 8, 1, (8,), ("tp",)), # TP only
(1, 1, 1, 4, (4,), ("cp",)), # CP only
(1, 4, 1, 2, (4, 2), ("dp_shard", "cp")), # FSDP + CP
(1, 2, 2, 2, (2, 2, 2), ("dp_shard", "cp", "tp")), # FSDP + CP + TP
(2, 2, 2, 2, (2, 2, 2, 2), ("dp_replicate", "dp_shard", "cp", "tp")), # HSDP + CP + TP
],
)
def test_get_mesh(
self,
dp_replicate_size,
dp_shard_size,
tp_size,
cp_size,
expected_shape,
expected_dim_names,
):
# Skip tests based on version requirements
if _should_skip_cp_test(cp_size):
pytest.skip(f"tests with `cp_size>1` require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}")
if _should_skip_tp_test(tp_size):
pytest.skip(
f"tests with `tp_size>1` require torch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}, transformers available and >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}"
)
config = ParallelismConfig(
dp_replicate_size=dp_replicate_size, dp_shard_size=dp_shard_size, tp_size=tp_size, cp_size=cp_size
)
mesh_dim_names, mesh_shape = config._get_mesh()
assert mesh_shape == expected_shape
assert mesh_dim_names == expected_dim_names
@pytest.mark.parametrize(
"dp_replicate_size, dp_shard_size, tp_size, cp_size, expected_shape, expected_dim_names",
[
(8, 1, 1, 1, (8,), ("dp_replicate",)),
(1, 8, 1, 1, (8,), ("dp_shard",)),
(2, 4, 1, 1, (2, 4), ("dp_replicate", "dp_shard")),
(1, 4, 2, 1, (4, 2), ("dp_shard", "tp")),
(2, 2, 2, 1, (2, 2, 2), ("dp_replicate", "dp_shard", "tp")),
(1, 1, 8, 1, (8,), ("tp",)),
(1, 1, 1, 4, (4,), ("cp",)),
(1, 4, 1, 2, (4, 2), ("dp_shard", "cp")),
(1, 2, 2, 2, (2, 2, 2), ("dp_shard", "cp", "tp")),
(2, 2, 2, 2, (2, 2, 2, 2), ("dp_replicate", "dp_shard", "cp", "tp")),
],
)
def test_build_device_mesh(
self,
dp_replicate_size,
dp_shard_size,
tp_size,
cp_size,
expected_shape,
expected_dim_names,
):
"""Test build_device_mesh creates correct mesh and applies flattening."""
# Skip tests based on version requirements
if _should_skip_cp_test(cp_size):
pytest.skip(f"tests with `cp_size>1` require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}")
if _should_skip_tp_test(tp_size):
pytest.skip(
f"tests with `tp_size>1` require torch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}, transformers available and >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}"
)
config = ParallelismConfig(
dp_replicate_size=dp_replicate_size, dp_shard_size=dp_shard_size, tp_size=tp_size, cp_size=cp_size
)
device_mesh = config.build_device_mesh("cpu")
# Check mesh shape and dimension names match expected
assert device_mesh.shape == expected_shape
assert device_mesh.mesh_dim_names == expected_dim_names
# Check that correct flattening operations were called
expected_flattened = []
if config.dp_dim_names:
expected_flattened.append((config.dp_dim_names, "dp"))
if config.dp_shard_cp_dim_names:
expected_flattened.append((config.dp_shard_cp_dim_names, "dp_shard_cp"))
if config.dp_cp_dim_names:
expected_flattened.append((config.dp_cp_dim_names, "dp_cp"))
assert device_mesh.flattened_dims == expected_flattened
@pytest.mark.parametrize(
"dp_replicate_size, dp_shard_size, tp_size, cp_size",
[
(8, 1, 1, 1),
(1, 8, 1, 1),
(2, 4, 1, 1),
(1, 4, 2, 1),
(2, 2, 2, 1),
(1, 1, 8, 1),
(1, 1, 1, 4),
(1, 4, 1, 2),
(1, 2, 2, 2),
(2, 2, 2, 2),
],
)
def test_from_env(
self,
dp_replicate_size,
dp_shard_size,
tp_size,
cp_size,
):
if _should_skip_cp_test(cp_size):
pytest.skip(f"tests with `cp_size>1` require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}")
if _should_skip_tp_test(tp_size):
pytest.skip(
f"tests with `tp_size>1` require torch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}, transformers available and >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}"
)
new_env = {
"PARALLELISM_CONFIG_DP_REPLICATE_SIZE": dp_replicate_size,
"PARALLELISM_CONFIG_DP_SHARD_SIZE": dp_shard_size,
"PARALLELISM_CONFIG_TP_SIZE": tp_size,
"PARALLELISM_CONFIG_CP_SIZE": cp_size,
}
with patch_environment(**new_env):
config = ParallelismConfig()
for key, value in new_env.items():
assert getattr(config, key.split("PARALLELISM_CONFIG_")[-1].lower()) == value
def test_cp_handler(self):
"""Test CP handler with various configurations."""
# Any cp_size > 1 requires torch >= BETA_CP_AVAILABLE_PYTORCH_VERSION; we use a placeholder size for this check since this test doesn't depend on a specific value
if _should_skip_cp_test(2):
pytest.skip(f"tests with `cp_size>1` require torch >= {BETA_CP_AVAILABLE_PYTORCH_VERSION}")
from accelerate.utils import TorchContextParallelConfig
for setting in ("allgather", "alltoall"):
cp_handler = TorchContextParallelConfig(cp_comm_strategy=setting)
pc = ParallelismConfig(cp_size=2, cp_handler=cp_handler)
assert pc.cp_handler is not None, "CP handler should be set"
assert pc.cp_handler.cp_comm_strategy == setting, (
f"CP handler strategy should be {setting} but got {pc.cp_handler.cp_comm_strategy}"
)
for setting in ("allgather", "alltoall"):
with patch_environment(PARALLELISM_CONFIG_CP_COMM_STRATEGY=setting):
pc = ParallelismConfig(cp_size=2)
assert pc.cp_handler is not None, "CP handler should be set from environment"
assert pc.cp_handler.cp_comm_strategy == setting, (
f"CP handler strategy should be {setting} but got {pc.cp_handler.cp_comm_strategy}"
)
for setting in ("invalid", "unsupported"):
with pytest.raises(ValueError, match=f"Invalid cp_comm_strategy: {setting}"):
TorchContextParallelConfig(cp_comm_strategy=setting)
with patch_environment(PARALLELISM_CONFIG_CP_COMM_STRATEGY=setting):
with pytest.raises(ValueError, match=f"Invalid cp_comm_strategy: {setting}"):
pc = ParallelismConfig(cp_size=2)
def test_tp_handler(self):
assert True, "Tensor parallelism handler doesn't hold any logic yet"

View File

@ -19,7 +19,7 @@ import shutil
import tempfile
import unittest
from pathlib import Path
from unittest import mock
from unittest import mock, skip
import torch
@ -239,7 +239,10 @@ class FeatureExamplesTests(TempDirTestCase):
run_command(self.launch_args + testargs)
@require_trackers
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
@mock.patch.dict(
os.environ,
{"WANDB_MODE": "offline", "DVCLIVE_TEST": "true", "SWANLAB_MODE": "offline"},
)
def test_tracking(self):
with tempfile.TemporaryDirectory() as tmpdir:
testargs = f"""
@ -294,12 +297,14 @@ class FeatureExamplesTests(TempDirTestCase):
@require_pippy
@require_multi_device
@skip("Will soon deprecate pippy")
def test_pippy_examples_bert(self):
testargs = ["examples/inference/pippy/bert.py"]
run_command(self.launch_args + testargs)
@require_pippy
@require_multi_device
@skip("Will soon deprecate pippy")
def test_pippy_examples_gpt2(self):
testargs = ["examples/inference/pippy/gpt2.py"]
run_command(self.launch_args + testargs)

View File

@ -12,9 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import tempfile
import textwrap
import unittest
from pathlib import Path
import torch
@ -27,24 +31,29 @@ from accelerate.test_utils import (
require_multi_device,
require_torchao,
require_transformer_engine,
require_transformer_engine_mxfp8,
run_first,
)
from accelerate.test_utils.testing import require_deepspeed, run_command
from accelerate.utils import (
AORecipeKwargs,
FP8RecipeKwargs,
TERecipeKwargs,
has_ao_layers,
has_transformer_engine_layers,
is_torchao_available,
is_transformer_engine_available,
)
def can_convert_te_model():
accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [FP8RecipeKwargs(backend="TE")]}
def can_convert_te_model(from_config=False):
if not from_config:
accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [TERecipeKwargs()]}
else:
accelerator_kwargs = {}
accelerator = Accelerator(**accelerator_kwargs)
assert accelerator.fp8_enabled, "FP8 is not enabled"
dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)
model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 16))
model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.LayerNorm(32, bias=False), torch.nn.Linear(32, 16))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
@ -58,10 +67,14 @@ def maintain_proper_deepspeed_config(expected_version):
)
def can_convert_ao_model():
def can_convert_ao_model(from_config=False):
from transformers import AutoModelForSequenceClassification
accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [AORecipeKwargs()]}
if not from_config:
accelerator_kwargs = {"mixed_precision": "fp8", "kwargs_handlers": [AORecipeKwargs()]}
else:
accelerator_kwargs = {}
accelerator = Accelerator(**accelerator_kwargs)
dataloader = torch.utils.data.DataLoader(torch.randn(10, 32), batch_size=2)
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")
@ -78,13 +91,51 @@ def can_convert_ao_model():
class TestTransformerEngine(unittest.TestCase):
def test_can_prepare_model_single_gpu(self):
command = get_launch_command(num_processes=1, monitor_interval=0.1)
command += ["-m", "tests.test_fp8"]
command += ["-m", "tests.test_fp8", "--test_te"]
run_command(command)
def test_can_prepare_model_single_gpu_from_config(self):
with tempfile.TemporaryDirectory() as dir_name:
config_file = Path(dir_name) / "config.yaml"
config_file.write_text(
textwrap.dedent(
"""
distributed_type: "NO"
num_processes: 1
mixed_precision: fp8
fp8_config:
backend: TE
"""
)
)
command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)
command += ["-m", "tests.test_fp8", "--test_te", "--from_config"]
run_command(command)
@require_transformer_engine_mxfp8
def test_can_prepare_model_with_mxfp8_block_scaling(self):
with tempfile.TemporaryDirectory() as dir_name:
config_file = Path(dir_name) / "config.yaml"
config_file.write_text(
textwrap.dedent(
"""
distributed_type: "NO"
num_processes: 1
mixed_precision: fp8
fp8_config:
backend: TE
use_mxfp8_block_scaling: true
"""
)
)
command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)
command += ["-m", "tests.test_fp8", "--test_te", "--from_config"]
run_command(command)
@require_multi_device
def test_can_prepare_model_multi_gpu(self):
command = get_launch_command(num_processes=2, monitor_interval=0.1)
command += ["-m", "tests.test_fp8"]
command += ["-m", "tests.test_fp8", "--test_te"]
run_command(command)
@require_deepspeed
@ -116,7 +167,36 @@ class TestTransformerEngine(unittest.TestCase):
command = get_launch_command(
num_processes=2, monitor_interval=0.1, use_deepspeed=True, deepspeed_config_file=ds_config
)
command += ["-m", "tests.test_fp8"]
command += ["-m", "tests.test_fp8", "--test_te"]
run_command(command)
@require_deepspeed
@require_multi_device
def test_can_prepare_model_multigpu_deepspeed_from_config(self):
os.environ["ZERO_STAGE"] = str(1)
with tempfile.TemporaryDirectory() as dir_name:
config_file = Path(dir_name) / "config.yaml"
config_file.write_text(
textwrap.dedent(
"""
distributed_type: "DEEPSPEED"
deepspeed_config:
gradient_clipping: 1.0
gradient_accumulation_steps: 1
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 1
deepspeed_multinode_launcher: standard
num_processes: 2
mixed_precision: fp8
fp8_config:
backend: TE
"""
)
)
command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)
command += ["-m", "tests.test_fp8", "--test_te", "--from_config"]
run_command(command)
@ -125,13 +205,31 @@ class TestTransformerEngine(unittest.TestCase):
class TestTorchAO(unittest.TestCase):
def test_can_prepare_model_single_accelerator(self):
command = get_launch_command(num_processes=1, monitor_interval=0.1)
command += ["-m", "tests.test_fp8"]
command += ["-m", "tests.test_fp8", "--test_ao"]
run_command(command)
def test_can_prepare_model_single_gpu_from_config(self):
with tempfile.TemporaryDirectory() as dir_name:
config_file = Path(dir_name) / "config.yaml"
config_file.write_text(
textwrap.dedent(
"""
distributed_type: "NO"
num_processes: 1
mixed_precision: fp8
fp8_config:
backend: AO
"""
)
)
command = get_launch_command(config_file=str(config_file), monitor_interval=0.1)
command += ["-m", "tests.test_fp8", "--test_ao", "--from_config"]
run_command(command)
@require_multi_device
def test_can_prepare_model_multi_accelerator(self):
command = get_launch_command(num_processes=2, monitor_interval=0.1)
command += ["-m", "tests.test_fp8"]
command += ["-m", "tests.test_fp8", "--test_ao"]
run_command(command)
@require_deepspeed
@ -163,16 +261,26 @@ class TestTorchAO(unittest.TestCase):
command = get_launch_command(
num_processes=2, monitor_interval=0.1, use_deepspeed=True, deepspeed_config_file=ds_config
)
command += ["-m", "tests.test_fp8"]
command += ["-m", "tests.test_fp8", "--test_ao"]
run_command(command)
if __name__ == "__main__":
# TE suite
if is_transformer_engine_available():
can_convert_te_model()
parser = argparse.ArgumentParser()
parser.add_argument("--test_te", action="store_true", default=False)
parser.add_argument("--test_ao", action="store_true", default=False)
parser.add_argument("--from_config", action="store_true", default=False)
args = parser.parse_args()
if not args.test_te and not args.test_ao:
raise ValueError("Must specify at least one of --test_te or --test_ao")
if args.test_te:
can_convert_te_model(args.from_config)
if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
maintain_proper_deepspeed_config(int(os.environ.get("ZERO_STAGE")))
# AO suite
if is_torchao_available():
can_convert_ao_model()
if args.test_ao:
can_convert_ao_model(args.from_config)
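When --from_config is not passed, can_convert_te_model boils down to the programmatic setup sketched below; it assumes Transformer Engine is installed and an FP8-capable GPU is available.
import torch
from accelerate import Accelerator
from accelerate.utils import TERecipeKwargs

# FP8 via the Transformer Engine backend, configured in code instead of a config.yaml.
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[TERecipeKwargs()])
model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.LayerNorm(32, bias=False), torch.nn.Linear(32, 16))
model = accelerator.prepare(model)  # linear layers should be swapped for their TE counterparts during prepare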

View File

@ -24,8 +24,10 @@ from torch.fx import symbolic_trace
from accelerate.big_modeling import attach_layerwise_casting_hooks
from accelerate.hooks import (
AlignDevicesHook,
CpuOffload,
ModelHook,
SequentialHook,
UserCpuOffloadHook,
add_hook_to_module,
attach_align_device_hook,
remove_hook_from_module,
@ -457,3 +459,58 @@ class HooksModelTester(unittest.TestCase):
with torch.no_grad():
_ = test_model(inputs)
def test_cpu_offload_hook_moves_model(self):
if not torch.cuda.is_available():
self.skipTest("CUDA not available for offload test.")
model = ModelForTest()
gpu_device = torch.device("cuda:0")
hook = CpuOffload(execution_device=gpu_device)
add_hook_to_module(model, hook)
x = torch.randn(2, 3).to(gpu_device)
output = model(x)
self.assertEqual(output.device, gpu_device)
remove_hook_from_module(model)
output2 = model(x)
self.assertEqual(output2.device, gpu_device)
# should be on the gpu
assert model.linear1.weight.device == gpu_device
assert model.batchnorm.weight.device == gpu_device
assert model.linear2.weight.device == gpu_device
def test_cpu_offload_hook_with_prev_module(self):
if not torch.cuda.is_available():
self.skipTest("CUDA not available for offload test.")
model1 = ModelForTest()
model2 = ModelForTest()
gpu_device = torch.device("cuda:0")
cpu_device = torch.device("cpu")
hook1 = CpuOffload(execution_device=gpu_device)
add_hook_to_module(model1, hook1)
user_hook1 = UserCpuOffloadHook(model1, hook1)
hook2 = CpuOffload(execution_device=gpu_device, prev_module_hook=user_hook1)
add_hook_to_module(model2, hook2)
x = torch.randn(2, 3).to(gpu_device)
output1 = model1(x)
self.assertEqual(output1.device, gpu_device)
output2 = model2(x)
self.assertEqual(output2.device, gpu_device)
# should be on the cpu
assert model1.linear1.weight.device == cpu_device
assert model1.batchnorm.weight.device == cpu_device
assert model1.linear2.weight.device == cpu_device
# should be on the gpu still
assert model2.linear1.weight.device == gpu_device
assert model2.batchnorm.weight.device == gpu_device
assert model2.linear2.weight.device == gpu_device
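The same CpuOffload/UserCpuOffloadHook pairing that these two tests build by hand is exposed through the public helper cpu_offload_with_hook: each call returns the model plus a user hook, and passing that hook as prev_module_hook to the next call makes the next forward offload the previous module. A minimal sketch (assumes a CUDA device; the modules are illustrative):
import torch
from accelerate import cpu_offload_with_hook

model1 = torch.nn.Linear(3, 3)
model2 = torch.nn.Linear(3, 3)
model1, hook1 = cpu_offload_with_hook(model1, execution_device="cuda:0")
model2, hook2 = cpu_offload_with_hook(model2, execution_device="cuda:0", prev_module_hook=hook1)
x = torch.randn(2, 3, device="cuda:0")
y = model2(model1(x))  # running model2 moves model1 back to the CPU first
hook2.offload()        # offload the last model explicitly once finished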

View File

@ -64,6 +64,7 @@ class TestPrepareMultiGpuEnv(unittest.TestCase):
num_cpu_threads_per_process=1,
enable_cpu_affinity=False,
same_network=False,
use_parallelism_config=False,
)
prepare_multi_gpu_env(args)

View File

@ -61,7 +61,31 @@ class MemoryTest(unittest.TestCase):
raise_fake_out_of_memory()
mock_training_loop_function()
assert batch_sizes == [128, 64, 32, 16, 8]
assert batch_sizes == [
128,
115,
103,
92,
82,
73,
65,
58,
52,
46,
41,
36,
32,
28,
25,
22,
19,
17,
15,
13,
11,
9,
8,
]
def test_memory_explicit(self):
batch_sizes = []
@ -75,7 +99,31 @@ class MemoryTest(unittest.TestCase):
return batch_size, arg1
bs, arg1 = mock_training_loop_function("hello")
assert batch_sizes == [128, 64, 32, 16, 8]
assert batch_sizes == [
128,
115,
103,
92,
82,
73,
65,
58,
52,
46,
41,
36,
32,
28,
25,
22,
19,
17,
15,
13,
11,
9,
8,
]
assert [bs, arg1] == [8, "hello"]
def test_start_zero(self):
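The new expected sequences above shrink by int(batch_size * 0.9) on each retry rather than halving. For reference, the decorator these tests drive is used as sketched below; the failing condition is a stand-in for real memory pressure.
from accelerate.utils import find_executable_batch_size

@find_executable_batch_size(starting_batch_size=128)
def train(batch_size):
    # Pretend every batch size above 16 runs out of memory.
    if batch_size > 16:
        raise RuntimeError("CUDA out of memory.")
    return batch_size

print(train())  # returns the first batch size that no longer raises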

View File

@ -349,6 +349,26 @@ class ModelingUtilsTester(unittest.TestCase):
check_device_map(model, {"linear1": 0, "linear2": 1, "batchnorm": 1})
def test_check_device_map_invalid_keys(self):
model = ModelForTest()
device_map = {
"linear1": "cpu", # Valid module
"batchnorm": "cpu", # Valid module
"linear2": "cpu", # Valid module
"invalid_module": 0, # Invalid - should trigger warning
"another_invalid": 1, # Invalid - should trigger warning
}
# Test for the warning about invalid keys
with self.assertWarns(UserWarning) as cm:
check_device_map(model, device_map)
warning_msg = str(cm.warning)
self.assertIn("device_map keys do not match any submodules", warning_msg)
self.assertIn("invalid_module", warning_msg)
self.assertIn("another_invalid", warning_msg)
def shard_test_model(self, model, tmp_dir):
module_index = {
"linear1": "checkpoint_part1.bin",

View File

@ -14,6 +14,7 @@
import inspect
import unittest
from unittest import skip
import torch
@ -28,7 +29,6 @@ from accelerate.test_utils import (
path_in_accelerate_package,
require_huggingface_suite,
require_multi_device,
require_non_hpu,
require_non_torch_xla,
require_pippy,
require_torchvision,
@ -70,7 +70,6 @@ class MultiDeviceTester(unittest.TestCase):
execute_subprocess_async(cmd)
@run_first
@require_non_hpu # Synapse detected a device critical error that requires a restart
@require_multi_device
def test_multi_device_merge_fsdp_weights(self):
print(f"Found {device_count} {torch_device} devices.")
@ -111,6 +110,7 @@ class MultiDeviceTester(unittest.TestCase):
@require_torchvision
@require_multi_device
@require_huggingface_suite
@skip("Will soon deprecate pippy")
def test_pippy(self):
"""
Checks the integration with the pippy framework

View File

@ -16,6 +16,7 @@ import csv
import json
import logging
import os
import random
import re
import subprocess
import tempfile
@ -42,7 +43,9 @@ from accelerate.test_utils.testing import (
require_matplotlib,
require_mlflow,
require_pandas,
require_swanlab,
require_tensorboard,
require_trackio,
require_wandb,
skip,
)
@ -53,7 +56,9 @@ from accelerate.tracking import (
DVCLiveTracker,
GeneralTracker,
MLflowTracker,
SwanLabTracker,
TensorBoardTracker,
TrackioTracker,
WandBTracker,
)
from accelerate.utils import (
@ -520,6 +525,123 @@ class ClearMLTest(TempDirTestCase, MockingTestCase):
self.assertCountEqual(plot["data"][0]["cells"]["values"], [[1, 2], [3, 4], [5, 6]])
@require_swanlab
@mock.patch.dict(os.environ, {"SWANLAB_MODE": "offline"})
class SwanLabTrackingTest(TempDirTestCase, MockingTestCase):
def setUp(self):
super().setUp()
# Set the path where SwanLab's parsed log files are saved via the SWANLAB_LOG_DIR env var
self.add_mocks(mock.patch.dict(os.environ, {"SWANLAB_LOG_DIR": self.tmpdir}))
@skip
def test_swanlab(self):
# Disable hardware monitoring to prevent errors in test mode.
import swanlab
from swanlab.log.backup import BackupHandler
from swanlab.log.backup.datastore import DataStore
from swanlab.log.backup.models import ModelsParser
swanlab.merge_settings(swanlab.Settings(hardware_monitor=False))
# Start a fake training session.
accelerator = Accelerator(log_with="swanlab")
project_name = "test_project_with_config"
experiment_name = "test"
description = "test project for swanlab"
tags = ["my_tag"]
config = {
"epochs": 10,
"learning_rate": 0.01,
"offset": 0.1,
}
kwargs = {
"swanlab": {
"experiment_name": experiment_name,
"description": description,
"tags": tags,
}
}
accelerator.init_trackers(project_name, config, kwargs)
record_metrics = []
record_scalars = []
record_images_count = 0
record_logs = []
for epoch in range(1, swanlab.config.epochs):
acc = 1 - 2**-epoch - random.random() / epoch - 0.1
loss = 2**-epoch + random.random() / epoch + 0.1
ll = swanlab.log(
{
"accuracy": acc,
"loss": loss,
"image": swanlab.Image(np.random.random((3, 3, 3))),
},
step=epoch,
)
log = f"epoch={epoch}, accuracy={acc}, loss={loss}"
print(log)
record_scalars.extend([acc, loss])
record_images_count += 1
record_logs.append(log)
record_metrics.extend([x for _, x in ll.items()])
accelerator.end_training()
# Load latest offline log
run_dir = swanlab.get_run().public.run_dir
assert os.path.exists(run_dir) is True
ds = DataStore()
ds.open_for_scan(os.path.join(run_dir.__str__(), BackupHandler.BACKUP_FILE).__str__())
with ModelsParser() as models_parser:
for record in ds:
if record is None:
continue
models_parser.parse_record(record)
header, project, experiment, logs, runtime, columns, scalars, medias, footer = models_parser.get_parsed()
# test file header
assert header.backup_type == "DEFAULT"
# test project info
assert project.name == project_name
assert project.workspace is None
assert project.public is None
# test experiment info
assert experiment.name is not None
assert experiment.description == description
assert experiment.tags == tags
# test log record
backup_logs = [log.message for log in logs]
for record_log in record_logs:
assert record_log in backup_logs, "Log not found in backup logs: " + record_log
# test runtime info
runtime_info = runtime.to_file_model(os.path.join(run_dir.__str__(), "files"))
assert runtime_info.conda is None, "Not using conda, should be None"
assert isinstance(runtime_info.requirements, str), "Requirements should be a string"
assert isinstance(runtime_info.metadata, dict), "Metadata should be a dictionary"
assert isinstance(runtime_info.config, dict), "Config should be a dictionary"
for key in runtime_info.config:
assert key in config, f"Config key {key} not found in original config"
assert runtime_info.config[key]["value"] == config[key], (
f"Config value for {key} does not match original value"
)
# test scalar
assert len(scalars) + len(medias) == len(record_metrics), "Total metrics count does not match"
backup_scalars = [
metric.metric["data"]
for metric in record_metrics
if metric.column_info.chart_type.value.column_type == "FLOAT"
]
assert len(backup_scalars) == len(scalars), "Total scalars count does not match"
for scalar in backup_scalars:
assert scalar in record_scalars, f"Scalar {scalar} not found in original scalars"
backup_images = [
metric for metric in record_metrics if metric.column_info.chart_type.value.column_type == "IMAGE"
]
assert len(backup_images) == record_images_count, "Total images count does not match"
class MyCustomTracker(GeneralTracker):
"Basic tracker that writes to a csv for testing"
@ -681,6 +803,15 @@ class TrackerDeferredInitializationTest(unittest.TestCase):
_ = Accelerator(log_with=tracker)
self.assertNotEqual(PartialState._shared_state, {})
@require_trackio
def test_trackio_deferred_init(self):
"""Test that trackio tracker initialization doesn't initialize distributed"""
PartialState._reset_state()
tracker = TrackioTracker(run_name="test_trackio")
self.assertEqual(PartialState._shared_state, {})
_ = Accelerator(log_with=tracker)
self.assertNotEqual(PartialState._shared_state, {})
@require_comet_ml
def test_comet_ml_deferred_init(self):
"""Test that CometML tracker initialization doesn't initialize distributed"""
@ -728,3 +859,12 @@ class TrackerDeferredInitializationTest(unittest.TestCase):
self.assertEqual(PartialState._shared_state, {})
_ = Accelerator(log_with=tracker)
self.assertNotEqual(PartialState._shared_state, {})
@require_swanlab
def test_swanlab_deferred_init(self):
"""Test that SwanLab tracker initialization doesn't initialize distributed"""
PartialState._reset_state()
tracker = SwanLabTracker(run_name="test_swanlab")
self.assertEqual(PartialState._shared_state, {})
_ = Accelerator(log_with=tracker)
self.assertNotEqual(PartialState._shared_state, {})
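For completeness, the tracker wiring that both the SwanLab logging test and the deferred-init tests rely on reduces to the usual Accelerator logging API. A minimal sketch (assumes the swanlab package is installed; project name, config, and metrics are illustrative):
import os
from accelerate import Accelerator

os.environ["SWANLAB_MODE"] = "offline"  # same offline mode the tests above use
accelerator = Accelerator(log_with="swanlab")
accelerator.init_trackers("my_project", config={"learning_rate": 0.01, "epochs": 10})
accelerator.log({"loss": 0.5}, step=1)
accelerator.end_training()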