Compare commits

...

23 Commits

Author SHA1 Message Date
67be9a69ba Tmp: create device mesh beforehand 2025-07-09 14:10:32 +00:00
8df21cf54a Fix: correct launch command
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-07-09 14:10:32 +00:00
199cbedb01 Feat: final? 2025-07-09 14:10:32 +00:00
50f60e3c6f Feat: final benchmarks 2025-07-09 14:10:32 +00:00
ed0fcaceb7 Final changes 2025-07-09 14:10:32 +00:00
7859437aaa Refactor: address comments 2025-07-09 14:10:31 +00:00
dc828eba05 Feat: minor refactor + tests 2025-07-09 14:10:28 +00:00
8cef8d4d26 Minor 2025-07-09 14:09:45 +00:00
7782a156fb Feat: some fixes 2025-07-09 14:09:45 +00:00
0d20c3b110 Feat: better tests 2025-07-09 14:09:42 +00:00
ad7e1ad349 Test: distributed dataloader 2025-07-09 14:09:01 +00:00
cc30dc60f3 Fix: correct batch dispatch 2025-07-09 14:08:35 +00:00
17cd32f616 Docs: start readme 2025-07-09 14:08:06 +00:00
b39e39f05e Tests: some cp tests 2025-07-09 14:08:05 +00:00
40999beba2 Feat: reword 2025-07-09 14:08:05 +00:00
3670c6d2a8 Reword 2025-07-09 14:08:04 +00:00
77bd1fab74 Feat: add to toctree 2025-07-09 14:08:04 +00:00
351d9890f2 Fix: add back commits lost in the rebase 2025-07-09 14:08:03 +00:00
27edf35212 Squashed commit of the following:
commit 2f8fd72e5112beb24082c252f8aa5e621bb10129
Author: Simon <80467011+sorgfresser@users.noreply.github.com>
Date:   Tue Jun 10 13:50:34 2025 +0100

    Remove device_count (#3587)

commit d2e6b0313d696be62fe69d19f15bf3098effbad2
Author: Matej Sirovatka <54212263+S1ro1@users.noreply.github.com>
Date:   Tue Jun 10 05:26:48 2025 -0700

    [FSDP2] Refactor + FP8 (#3585)

    * Fix double wrap

    * Clocking off, ~equal to torch baseline

    * works?

    * Working version

    * Partial rewrite

    * FSDP2 path works

    * Fix back prepare

    * Almost done, proper AC left

    * Feat: should work, cleanup + test more benchmarks left

    * Style+quality

    * Feat: fp8 example

    * Feat: better example

    * Feat: add readme

    * Docs + should be done

    * Fix: typos

    * Fix: protect imports

    * Feat: address comments

    * Feat: add flops image

commit b9fee48c85dc8b3c4db1e97258925660cdc6ee36
Author: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
Date:   Tue Jun 10 13:24:43 2025 +0100

    better handle FP8 with and without deepspeed (#3611)

    * use the state mixed precision which has undergone all preprocessing

    * Update src/accelerate/accelerator.py

    Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

    * Update src/accelerate/accelerator.py

    * accelerator state sets the mixed precision for deepspeed and fp8_enabled

    * fix

    * fix

    ---------

    Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

commit 3a82b056cf85b16976ca2760615897fe65ae5e64
Author: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Date:   Tue Jun 10 11:29:59 2025 +0200

    Fix bf16 training with TP  (#3610)

    * fix

    * Apply style fixes

    ---------

    Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

commit 6b61a373a2b4e72e3f003ba2277904ee31b9f7e0
Author: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
Date:   Fri Jun 6 13:48:43 2025 +0100

    fix deepspeed regional compilation (#3609)

commit 682691deaca2637e0a2efeaa5ebb6dd8bade8c30
Author: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
Date:   Tue Jun 3 12:36:56 2025 +0200

    Update Gaudi Runners (#3593)

    * test

    * fix

    * push

    * in the morning

    * fix backend

    * run first

    * set habana modules

    * dynamo backend

    * trigger

    * remove on pr

    * remove on file change

commit 791055b4848d2c11d3dfcd47faba79b524973756
Author: Matej Sirovatka <54212263+S1ro1@users.noreply.github.com>
Date:   Tue Jun 3 12:24:20 2025 +0200

    Fix: list object has no attribute keys (#3603)

commit 16bf1d89016e03f5b0d8545e9883df7fe9ab9b5f
Author: Yao Matrix <matrix.yao@intel.com>
Date:   Fri May 30 23:36:34 2025 +0800

    enable torchao and pippy test cases on XPU (#3599)

    * enable torchao and pippy test cases on XPU

    Signed-off-by: Matrix YAO <matrix.yao@intel.com>

    * fix style

    Signed-off-by: Matrix YAO <matrix.yao@intel.com>

    ---------

    Signed-off-by: Matrix YAO <matrix.yao@intel.com>

commit ab3c604e48619f7cd08cfac46a7c542414b6661f
Author: Yao Matrix <matrix.yao@intel.com>
Date:   Fri May 30 23:23:26 2025 +0800

    enable big_model_inference on xpu (#3595)

    * enable big_model_inference on XPU

    Signed-off-by: Matrix YAO <matrix.yao@intel.com>

    * fix style

    Signed-off-by: Matrix YAO <matrix.yao@intel.com>

    * fix quality

    Signed-off-by: Matrix YAO <matrix.yao@intel.com>

    ---------

    Signed-off-by: Matrix YAO <matrix.yao@intel.com>

commit 273799c85d849a1954a4f2e65767216eb37fa089
Author: Yao Matrix <matrix.yao@intel.com>
Date:   Tue May 27 20:08:59 2025 +0800

    enable fsdp2 benchmark on XPU (#3590)

    * enable fsdp2 benchmark on XPU

    Signed-off-by: Matrix YAO <matrix.yao@intel.com>

    * add deterministic

    Signed-off-by: Matrix YAO <matrix.yao@intel.com>

    ---------

    Signed-off-by: Matrix YAO <matrix.yao@intel.com>

commit 43526c5c089cc831530f42bbbe66a0cb0b0ea461
Author: Yao Matrix <matrix.yao@intel.com>
Date:   Tue May 27 17:44:50 2025 +0800

    add device-agnostic GradScaler (#3588)

    * add device-agnostic GradScaler

    Signed-off-by: Matrix YAO <matrix.yao@intel.com>

    * fix bug

    Signed-off-by: Matrix YAO <matrix.yao@intel.com>

    * fix review comments

    Signed-off-by: Matrix YAO <matrix.yao@intel.com>

    * fix

    Signed-off-by: Matrix YAO <matrix.yao@intel.com>

    * format

    Signed-off-by: Matrix YAO <matrix.yao@intel.com>

    * Apply style fixes

    ---------

    Signed-off-by: Matrix YAO <matrix.yao@intel.com>
    Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

commit 07f2392f40a92710b4fb7e51b2de1d40f24d44e2
Author: Yao Matrix <matrix.yao@intel.com>
Date:   Tue May 27 17:17:18 2025 +0800

    change to use torch.device (#3594)

    Signed-off-by: Matrix YAO <matrix.yao@intel.com>

commit ee2f48c2c3d393187408a0f2cce1ece973033809
Author: Fanli Lin <fanli.lin@intel.com>
Date:   Tue May 27 17:16:42 2025 +0800

    [docs] no hard-coded cuda in the ddp documentation (#3589)

    * make device-agnostic

    * refactor

commit 4f3abb73a722f6275197c060346dd2f385039afc
Author: jiqing-feng <jiqing.feng@intel.com>
Date:   Mon May 26 21:55:10 2025 +0800

    Set ccl and KMP param in simple launch (#3575)

    * Even 1 CPU mechine can also run multi process

    Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

    * fix ccl and kml param setting

    Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

    * set master addr only when processes > 1

    Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

    * fix num process check

    Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

    * fix ccl args check

    Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

    ---------

    Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

commit db536cbfeb61a92e642462a436b51104ab96bd2f
Author: Yuanzhou Cai <80858000+yuanjua@users.noreply.github.com>
Date:   Mon May 26 21:08:13 2025 +0800

    Fix: Defer Tracker Initialization to Prevent Premature Distributed Setup (#3581)

    * Fix tracker initialize distributed before InitProcessGroupKwargs

    * Fix tracker initialize distributed before InitProcessGroupKwargs

    * Add test for bug #3550

    * Improve test for #3550

    * Remove redundant code

    Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

    * fix style

    ---------

    Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

commit 4e9d0deba6fd759f5f503f9b1587e79c51032a68
Author: Yao Matrix <matrix.yao@intel.com>
Date:   Mon May 26 21:05:42 2025 +0800

    enable regional_compilation benchmark on xpu (#3592)

    * enable regional_compilation benchmark on xpu

    Signed-off-by: Matrix YAO <matrix.yao@intel.com>

    * Apply style fixes

    ---------

    Signed-off-by: Matrix YAO <matrix.yao@intel.com>
    Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

commit 8cb3ace89485af0488d93da6c080c36319cced9e
Author: Luiz F. G. dos Santos <luiz.fernando0992@gmail.com>
Date:   Thu May 22 10:21:54 2025 -0500

    Add kwargs to optimizer, scheduler and dataloader using function `accelerator().load_state()` (#3540)

    * Added artifacts and figure tracking at MLFlow tracker

    * Added `log_artifact` to the MLFlowTracker

    * Remove changes

    * Added kwargs when loading state.

    * added doc string

    * Adjusted correct default types of kwargs

    * Changed the load kwargs to a single one

    * removed None value from kwargs

    * fix kwargs for loading the model

    * removed load_kwargs from optimizer state dict

    * make load_kwargs a dictionary

    * revert last changes

    * reverted load_kwargs

    * fix docstring

    * added dict initiation

    * Fix quality error during PR

commit b6d97cb856ae0c9daa310ab8305340950ea8763a
Author: Emmanuel Ferdman <emmanuelferdman@gmail.com>
Date:   Thu May 22 17:26:31 2025 +0300

    Resolve logger warnings (#3582)

    Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>

commit 33967d4733ec5bf402d85462ec2bbbcd8e872ea9
Author: Francesco Laiti <25352428+laitifranz@users.noreply.github.com>
Date:   Tue May 20 12:29:53 2025 +0200

    Add support for standalone mode when default port is occupied on single node (#3576)

    * add standalone mode and replace ConnectionError with a warning when the main process port is in use, allowing for automatic port selection

    * address review feedback: warn on port conflict only for single-node; raise error for multi-node

    * Apply style fixes

    ---------

    Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

commit 5b1fcda371b049f76e1bd8536e114635d9eaf5b3
Author: Yao Matrix <matrix.yao@intel.com>
Date:   Tue May 20 18:04:24 2025 +0800

    enable test_cli & test_example cases on XPU (#3578)

    * enable test_cli & test_example cases on XPU

    Signed-off-by: Matrix Yao <matrix.yao@intel.com>

    * fix style

    Signed-off-by: Matrix Yao <matrix.yao@intel.com>

    * fix style

    Signed-off-by: Matrix Yao <matrix.yao@intel.com>

    * remove print

    Signed-off-by: Matrix Yao <matrix.yao@intel.com>

    * fix ci issue

    Signed-off-by: YAO Matrix <matrix.yao@intel.com>

    ---------

    Signed-off-by: Matrix Yao <matrix.yao@intel.com>
    Signed-off-by: YAO Matrix <matrix.yao@intel.com>

commit f55f0533b5726d85a62fb05760ec6a92d00e0099
Author: Yao Matrix <matrix.yao@intel.com>
Date:   Tue May 20 18:02:14 2025 +0800

    goodbye torch_ccl (#3580)

    Signed-off-by: Matrix Yao <matrix.yao@intel.com>

commit 1ec99f0b5842f2f246b6481248099920e74f6384
Author: Yao Matrix <yaoweifeng0301@126.com>
Date:   Mon May 19 17:27:40 2025 +0800

    enable test_load_checkpoint_and_dispatch_with_broadcast cases on XPU (#3579)

    * enable test_load_checkpoint_and_dispatch_with_broadcast cases on XPU

    Signed-off-by: Matrix Yao <matrix.yao@intel.com>

    * fix style

    Signed-off-by: Matrix Yao <matrix.yao@intel.com>

    * Update test_load_checkpoint_and_dispatch_with_broadcast.py

    ---------

    Signed-off-by: Matrix Yao <matrix.yao@intel.com>
2025-07-09 14:08:02 +00:00
f8bac5aaa1 Docs: add concept guide 2025-07-09 14:08:01 +00:00
deb42c105b Refactor: change fsdp2_fp8 example to use new utils 2025-07-09 14:07:56 +00:00
b816a6762f Feat: add context-parallel example 2025-07-09 14:07:06 +00:00
b1a48dc76f Feat: add context-parallel to accelerator 2025-07-09 14:07:05 +00:00
16 changed files with 831 additions and 135 deletions

View File

@ -82,6 +82,8 @@
title: Accelerate's internal mechanism
- local: concept_guides/big_model_inference
title: Loading big models into memory
- local: concept_guides/context_parallel
title: Context parallelism
- local: concept_guides/performance
title: Comparing performance across distributed setups
- local: concept_guides/deferring_execution

View File

@ -0,0 +1,156 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Context Parallel in 🤗`accelerate`
This guide covers the basics of using context parallelism in 🤗`accelerate`. For the more curious readers, we also cover some technicalities in the later sections.
## Why context parallelism?
With the advent of large language models, and recently reasoning models, sequence lengths have been growing rapidly. This, combined with the quadratic memory complexity of attention, has led to a need for more efficient ways to train models on long sequences.
With a sequence length of 128k, the memory requirement of the attention matrix is `128k * 128k * 2 bytes * num_heads = ~32 GB * num_heads` for `bf16` precision, given a vanilla attention implementation. Granted, with `flash attention` or `SDPA`, which do not materialize these attention weights, this decreases drastically, but the growth in memory requirements is still considerable.
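A quick back-of-the-envelope check of this arithmetic (a standalone sketch, not part of the guide's examples):
```python
# Memory of a full attention matrix per head and per batch element in bf16 (2 bytes),
# assuming a vanilla attention implementation that materializes the seq_len x seq_len matrix.
seq_len = 128 * 1024
bytes_per_element = 2  # bf16
attn_matrix_bytes = seq_len * seq_len * bytes_per_element
print(f"{attn_matrix_bytes / 2**30:.0f} GiB per attention head")  # -> 32 GiB
```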
Context parallelism allows us to shard the inputs to the attention computation along the sequence dimension and compute the attention in parallel on multiple GPUs. With this, we can train models with long sequences, potentially scaling to 1M+ tokens.
## How to use context parallelism?
As with any other feature in 🤗`accelerate`, enabling context parallelism is as simple as passing the corresponding flags to `accelerate launch`. In this case, the flags are part of the FSDP argument group:
```bash
accelerate launch --use_fsdp --fsdp_version 2 --fsdp_cp_size 8 --fsdp_cp_comm_strategy [allgather|alltoall] ...
```
Context parallelism is tightly coupled (for now) with `FSDP2`, which you can learn more about in the [FSDP2 introduction](fsdp1_vs_fsdp2.md). This means context parallelism is applied only if `FSDP2` is enabled.
You can also enable context parallelism programmatically, by passing the corresponding arguments to the `FullyShardedDataParallelPlugin` constructor:
```diff
  from accelerate import Accelerator
  from accelerate.utils import FullyShardedDataParallelPlugin

  plugin = FullyShardedDataParallelPlugin(
      ...
      fsdp_version=2,
+     cp_size=8,
+     cp_comm_strategy="allgather",
  )
  accelerator = Accelerator(fsdp_plugin=plugin)
```
After enabling context parallelism with one of the methods above, you can apply it to your training loop. We provide a thin wrapper around [`torch.distributed.tensor.experimental.context_parallel`](https://docs.pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel) that abstracts away some of the complexity of using it (more on this later).
You can use it in your training loop as follows:
```python
for batch in dataloader:
    with accelerator.context_parallel(
        buffers=[batch["input_ids"], batch["attention_mask"]],
        buffer_seq_dims=[1, 1],
        no_restore_buffers={batch["input_ids"]},
    ):
        outputs = model(**batch)
        ...
```
> [!Warning]
> This context manager has to be recreated with each training step, as shown in the example above. It's crucial to do so.
This can potentially scale your context size to 1M+ tokens. Below, we showcase the speed and memory usage of context parallelism for context sizes up to 256k. We can see that when we double both the context size and the number of GPUs, memory usage stays consistent, potentially enabling near-endless context length scaling.
<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_perf.png" alt="context parallelism memory usage" />
<br>
<em>Figure 1: Memory usage and speed of context parallelism for context sizes up to 256k.</em>
</p>
> [!Tip]
> These examples were created with a script you can find [in the examples folder](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/fsdp2_context_parallel.py). For instructions on how to run it, see the [README](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/README.md) in the same folder.
## Accelerate's interface
The context manager takes a few arguments that are used to configure context parallelism:
- `buffers`: This is a list of tensors to be sharded along the sequence dimension. These are usually the input IDs, labels and attention mask.
- `buffer_seq_dims`: This is a list of integers that specify the sequence dimension of each buffer, in the same order as the `buffers` list.
- `no_restore_buffers`: The implementation of context parallelism modifies the buffers in-place, converting them to `torch.distributed.tensor.DTensor`s. After the context manager exits, a communication kernel would need to be launched to restore the buffers to their original state (usually an all-gather). This takes some time, so it is recommended to pass the same tensors as in the `buffers` argument to avoid unnecessary communication, unless you are sure you need to use the buffers after the context manager exits. A short sketch of building these arguments from a batch is shown below.
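For illustration, here is a hypothetical helper (not part of Accelerate) that builds these three arguments from a batch dictionary, assuming every tensor in the batch has its sequence dimension at index 1 and reusing `accelerator`, `model` and `dataloader` from the earlier example:
```python
import torch


def cp_buffers_from_batch(batch: dict[str, torch.Tensor]):
    # shard every tensor in the batch along its sequence dimension (assumed to be dim 1)
    buffers = list(batch.values())
    buffer_seq_dims = [1] * len(buffers)
    # we don't need the unsharded buffers back after the step, so skip the restore communication
    no_restore_buffers = set(buffers)
    return buffers, buffer_seq_dims, no_restore_buffers


for batch in dataloader:
    buffers, seq_dims, no_restore = cp_buffers_from_batch(batch)
    with accelerator.context_parallel(
        buffers=buffers, buffer_seq_dims=seq_dims, no_restore_buffers=no_restore
    ):
        outputs = model(**batch)
```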
## Configurable options
Accelerate provides only a few options to configure context parallelism, which are:
- `cp_size`: The number of ranks across which the inputs to the attention computation are sharded along the sequence dimension.
- `cp_comm_strategy`: The rotation method to use for the shards. We strongly recommend keeping this as `"allgather"`, as it is very likely to outperform `"alltoall"` in most cases.
Context parallel size is rather self-explanatory: it's the number of ranks across which the inputs are sharded.
Context parallel shard rotation defines how the shards of the inputs are rotated across ranks. We'll cover the two options in more detail in the next section.
You can see an end-to-end example in the [FSDP2 context parallel example](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/fsdp2_context_parallel.py) file, where you can train an 8B model with 128k sequence length on 8x H100 SXM GPUs. Using multi-node training, you can scale this to 1M+ sequence length on 64x H100 SXM GPUs.
## Technical details
> [!Tip]
> This section is fairly technical, so if you don't need to learn the internals of context parallelism, you can skip it and start building 🚀
We're going to be using the word `shard` extensively in the following sections, so let's define it first. If we say a tensor of shape `[..., D, ...]` is `sharded` along that dimension across `N` ranks, we mean that the tensor is split into `N` parts, where each part has shape `[..., D // N, ...]` (see the short sketch below).
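A standalone sketch of this definition, assuming 4 ranks and a tensor whose sequence dimension is at index 1:
```python
import torch

N = 4                                  # number of ranks
x = torch.randn(2, 1024, 16)           # [batch, seq_len, hidden]
shards = torch.chunk(x, N, dim=1)      # shard along the sequence dimension
assert shards[0].shape == (2, 1024 // N, 16)  # each rank would hold one such part
```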
## So how does it work?
Context parallelism works by sharding the `Q`, `K` and `V` matrices across the sequence dimension. Each rank has its assigned shard of `Q`, let's call it `Q_i`. This matrix stays on its rank during the whole computation. Similarly, each rank has its own shard of `K` and `V`, let's call them `K_i` and `V_i`. Then, each rank calculates attention with its own shards `Q_i`, `K_i` and `V_i`, let's call it `attn_i`. During this computation, a communication kernel is launched to gather the `K`s and `V`s from all other ranks. Which communication primitive is used depends on the `cp_comm_strategy` option.
This way, each rank gets to calculate local attention, first with `Q_i`, `K_i` and `V_i`, then with the `K_j` and `V_j` from all other ranks. As each rank holds `Q`, `K` and `V` matrices that are sharded across the sequence dimension, the resulting matrices are smaller and can fit on a single GPU.
We can formalize this in the following pseudocode:
```python
# pseudocode: shard(), attn_fn(), combine() and all_other_ranks() are placeholders
comm_kernel = {"allgather": allgather, "alltoall": alltoall}[cp_comm_strategy]
Qi, Ki, Vi = shard(Q, K, V, seq_dim)  # each rank keeps only its own shard
attn = {}
attn[i] = attn_fn(Qi, Ki, Vi)  # local attention with this rank's own shards
for j in all_other_ranks():  # cp_size - 1 iterations
    Kj, Vj = comm_kernel()  # fetch the K/V shards of rank j
    attn[j] = attn_fn(Qi, Kj, Vj)  # [batch, num_heads, seq_len // cp_size, head_dim]
final_attn = combine(attn)  # combine the partial results into the final output
```
## all-to-all vs all-gather
### all-gather
So what's the difference between all-to-all and all-gather? With all-gather, the communication is very simple. After (or rather before, as the communication usually takes longer than the computation) we compute the local attention `attn_i`, we launch an all-gather to collect the `K` and `V` shards from all other ranks. Once this communication is done, each rank has all the `K`s and `V`s and can compute attention with them sequentially.
In an ideal scenario, the all-gather finishes at the exact moment the calculation of `attn_i` is done. This never happens in practice, though, so the best realistic overlap is when the full computation of `attn_i` hides part of the communication; we then wait for the all-gather to finish before starting the computation with `K_j` and `V_j`. See the sketch below.
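A minimal sketch of this pattern (not Accelerate's actual implementation), assuming `torch.distributed` is already initialized and `q`, `k`, `v` hold this rank's shards of shape `[batch, num_heads, seq_len // world_size, head_dim]`:
```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


def allgather_attention(q, k, v):
    world_size = dist.get_world_size()
    k_shards = [torch.empty_like(k) for _ in range(world_size)]
    v_shards = [torch.empty_like(v) for _ in range(world_size)]
    # launch the communication asynchronously so it can overlap with the local attention below
    k_work = dist.all_gather(k_shards, k, async_op=True)
    v_work = dist.all_gather(v_shards, v, async_op=True)
    # local attention with this rank's own K/V shard
    partials = [F.scaled_dot_product_attention(q, k, v)]
    k_work.wait()
    v_work.wait()
    for j in range(world_size):
        if j == dist.get_rank():
            continue  # already computed above
        partials.append(F.scaled_dot_product_attention(q, k_shards[j], v_shards[j]))
    # NOTE: correctly combining the partial results requires rescaling each block with its
    # softmax statistics (log-sum-exp), as flash-attention-style kernels do; omitted here.
    return partials
```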
### all-to-all
All-to-all, sometimes called `ring-rotation`, utilizes a ring-like communication pattern. After concluding the `attn_i` computation, an all-to-all is launched to send `K_i` and `V_i` to the neighbouring ranks. We repeat this `cp_size - 1` times, so that each rank sees each shard of `K` and `V` exactly once. In an ideal scenario, we prefetch the shards `K_{i+1}` and `V_{i+1}` from the neighbouring rank, and this communication is exactly overlapped with the computation of our current `attn_i`. Again, this perfect overlap never happens in practice, and given the nature of this approach, if we don't achieve perfect overlap, the penalty is much larger than with all-gather.
## How to choose the right rotation method?
In theory, all-to-all should be the better choice, though in practice it rarely is. Therefore, we default to all-gather, as it's more likely to achieve better performance. Extensive [benchmarks](https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082) from the `torchtitan` team also show that all-to-all rarely outperforms all-gather. Still, we provide both options, as one might work better for your use case.
You can directly see this issue in the profiler output in the image below:
<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_all_to_all.png" alt="all-to-all profiler output" />
<br>
<em>Figure 1: In red, you can see the idle time while we wait for the all-to-all kernel to finish. The first highlighted blue bar shows that a single all-to-all takes ~250µs, and it is repeated N-1 times for each attention call, where N is the context parallel size.</em>
</p>
## Why only FSDP2?
We only support context parallelism with `FSDP2` for now, as we create a joint mesh of `context_parallel_size` and `dp_shard_size` to
utilize its full potential. In the profiler output in the image below, you can see why this is the case.
<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_why_fsdp2.png" alt="why FSDP2+CP" />
<br>
<em>Figure 2: In blue rectangles (Stream 23), you can see that the pre-fetch of `FSDP` shard is fully overlapped with the computation of attention (Stream 7), while in red rectangles (Stream 24), you can see that the all-gather kernel results in a bubble of idle time, in which our compute stream (7) is idle.</em>
</p>
In the figure above, you can also note the difference between all-to-all and all-gather. While in all-to-all (Figure 1), we launch a communication kernel N-1 times for each attention call, in all-gather (Figure 2), we launch a communication kernel only once. This results in a bigger bubble, but it only happens once per attention call, while in all-to-all, it happens N-1 times.
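For reference, the joint mesh that Accelerate builds internally can be sketched as follows (a simplified sketch mirroring the logic added to `Accelerator` in this PR, not the exact implementation):
```python
from torch.distributed.device_mesh import init_device_mesh

world_size = 8                     # total number of ranks
cp_size = 2                        # context parallel group size
fsdp_size = world_size // cp_size  # remaining ranks are used for FSDP sharding

mesh = init_device_mesh(
    device_type="cuda",
    mesh_shape=(fsdp_size, cp_size),
    mesh_dim_names=("fsdp", "cp"),
)
# FSDP2's `fully_shard` gets the flattened ("fsdp_cp") mesh, while the
# context parallel context manager only uses the "cp" submesh.
mesh["fsdp", "cp"]._flatten("fsdp_cp")
```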

View File

@ -1,8 +1,8 @@
## FSDP2 Examples
# FSDP2 Examples
This folder contains examples of using FSDP2 with Accelerate, utilizing extra methods to improve training speed, performance or accuracy.
### FSDP2 + ao Float8Linear
## FSDP2 + ao Float8Linear (`fsdp2_fp8.py`)
In file `fsdp2_fp8.py` we use `Float8Linear` from `ao` to train a model partially in FP8 precision. We utilize `AORecipeKwargs` to pass the `Float8LinearConfig` to the accelerator,
which replaces the default `torch.nn.Linear` with `Float8Linear`. We also utilize `TorchDynamoPlugin` together with regional compilation to compile the model,
@ -34,3 +34,25 @@ The figures above were generated on 8x H100 SXM GPUs, with 8192 sequence length
```bash
accelerate launch fsdp2_fp8.py --sequence-length 8192 --num-steps 1000 --log_with wandb --precision [fp8 | bf16]
```
## FSDP2 + context parallelism (`fsdp2_context_parallel.py`)
In this file, we showcase the integration of context parallelism with FSDP2. Context parallelism is a technique that allows us to scale training to sequence lengths of up to a million tokens. With the `accelerator.context_parallel` context manager, we replace the attention implementation with a context parallel version, which enables us to train on a sequence length of up to 128k tokens on 8x H100 GPUs, with the possibility of further scaling if we have enough GPUs.
For a detailed explanation, please refer to [this guide](https://huggingface.co/docs/accelerate/concept_guides/context_parallel). You can run the example with the following command:
```bash
accelerate launch --num_processes 8 fsdp2_context_parallel.py --sequence-length 128000 --num-steps 1000 --log-with wandb --cp-size 8 --cp-comm-strategy allgather
```
You can see some results below:
<p align="center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_perf.png" alt="context parallelism memory usage" />
<br>
<em>Figure 1: Memory usage and speed of context parallelism for context sizes up to 256k.</em>
</p>

View File

@ -0,0 +1,185 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example of training with Context Parallel using FSDP2 via Accelerate.
This example demonstrates how to use Accelerate's context_parallel feature for efficient long sequence training.
"""
import argparse
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin, set_seed
from utils import PerformanceTracker, create_collate_fn, get_dataset, setup_tokenizer
MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--sequence-length", type=int, default=128_000, help="Sequence length for the dataset")
parser.add_argument("--num-steps", type=int, default=100, help="Number of training steps")
parser.add_argument("--log-with", type=str, default="wandb", help="Logging service to use")
parser.add_argument("--cp-size", type=int, default=8, help="Context parallel size")
parser.add_argument("--cp-comm-strategy", type=str, default="allgather", help="Context parallel shard rotation")
return parser.parse_args()
def training_step(batch, model, optimizer, accelerator: Accelerator):
"""
Perform a single training step with context parallel.
Args:
batch: Input batch containing input_ids and labels
model: The model to train
optimizer: Optimizer
accelerator: Accelerator instance
Returns:
loss: Training loss
"""
input_ids = batch["input_ids"]
labels = batch["labels"]
# Use context parallel for efficient long sequence processing
with accelerator.context_parallel(
buffers=[input_ids, labels],
buffer_seq_dims=[1, 1], # Sequence dimension is dimension 1 for both tensors
no_restore_buffers={input_ids, labels}, # Don't restore these buffers after forward pass
):
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG)
return loss
def main():
set_seed(42)
args = parse_args()
# Configure FSDP2 plugin
fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy="transformer_based_wrap",
transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
cpu_ram_efficient_loading=True,
activation_checkpointing=True,
fsdp_version=2,
cp_size=args.cp_size,
cp_comm_strategy=args.cp_comm_strategy,
)
# Initialize accelerator
accelerator = Accelerator(
log_with=args.log_with,
fsdp_plugin=fsdp_plugin,
mixed_precision="bf16",
)
accelerator.init_trackers(
project_name="FSDP2_context_parallel",
config={
"sequence_length": args.sequence_length,
"num_steps": args.num_steps,
"cp_size": args.cp_size,
"cp_comm_strategy": args.cp_comm_strategy,
},
)
# Prepare model and optimizer
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16,
use_cache=False,
)
tokenizer = setup_tokenizer(MODEL_ID)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
model, optimizer = accelerator.prepare(model, optimizer)
accelerator.print("Preparing dataset... this might take a while")
dataset = get_dataset(
accelerator,
tokenizer,
args.sequence_length,
processing_batch_size=args.sequence_length
// 20, # we need to override the default processing batch size to avoid empty packed sequences
)
dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())
dataloader = accelerator.prepare(dataloader)
model.train()
total_num_steps = min(args.num_steps, len(dataloader))
performance_tracker = PerformanceTracker(warmup_steps=10)
accelerator.print(f"Starting training with context parallel for {total_num_steps} steps...")
accelerator.print(f"Sequence length: {args.sequence_length}")
accelerator.print("Warming up for 10 steps...")
accelerator.print(
"Each step takes ~10 seconds with default settings on 8x H100 SXM GPUs, seeing logs takes a while"
)
for step, batch in enumerate(dataloader):
if step >= total_num_steps:
break
# get number of tokens before context_parallel shards the batch
batch_tokens = batch["input_ids"].shape[0] * batch["input_ids"].shape[1]
loss = training_step(batch, model, optimizer, accelerator)
# each batch gets the same data, we divide by the number of processes to get the number of tokens per process
metrics = performance_tracker.step(batch_tokens // accelerator.num_processes)
log_metrics = {"loss": loss.item()}
if "warmup_completed" in metrics:
accelerator.print("Warmup completed! Starting performance tracking...")
elif metrics:
log_metrics.update(
{
"tokens_per_second": int(metrics["tokens_per_second"]),
"steps_per_second": metrics["steps_per_second"],
}
)
if (step % 10 == 0 or step == total_num_steps - 1) and metrics:
accelerator.print(
f"Step {step}/{total_num_steps} | "
f"Loss: {loss.item():.4f} | "
f"Tokens/s: {int(metrics['tokens_per_second'])} | "
f"Steps/s: {metrics['steps_per_second']:.2f} | "
)
accelerator.log(log_metrics)
accelerator.wait_for_everyone()
accelerator.end_training()
accelerator.print("Training completed!")
if __name__ == "__main__":
main()

View File

@ -18,20 +18,17 @@ This example demonstrates how to use torchao's Float8LinearConfig with Accelerat
"""
import argparse
import time
import torch
from datasets import Dataset, load_dataset
from torch.utils.data import DataLoader
from torchao.float8 import Float8LinearConfig
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import AutoConfig, AutoModelForCausalLM
from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs, FullyShardedDataParallelPlugin, TorchDynamoPlugin, set_seed
from utils import PerformanceTracker, create_collate_fn, get_dataset, get_model_flops_per_token, setup_tokenizer
WARMUP_STEPS = 10
MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
@ -46,88 +43,6 @@ def parse_args():
return parser.parse_args()
def get_model_flops_per_token(model: AutoModelForCausalLM, args: argparse.Namespace) -> float:
"""
Get the number of flops per token for the model.
Args:
model (AutoModelForCausalLM): Model to get the flops for
"""
cfg = model.config
head_dim = cfg.hidden_size // cfg.num_attention_heads
# MLP: 3 matmuls
mlp_flops = 18 * cfg.hidden_size * cfg.intermediate_size
# Attn (w/o dotproduct)
attn_flops = 12 * head_dim * (cfg.num_attention_heads + cfg.num_key_value_heads)
# attn (dotproduct) - this scales quadratically with sequence length, therefore we have to account for it here too
attn_dotproduct_flops = 12 * cfg.num_attention_heads * head_dim * args.sequence_length
# we also ignore embeddings and layernorms, etc
return (mlp_flops + attn_flops + attn_dotproduct_flops) * cfg.num_hidden_layers
def get_dataset(accelerator: Accelerator, tokenizer: AutoTokenizer, seq_len: int) -> Dataset:
"""
Load and prepare TinyStories dataset.
Args:
accelerator (Accelerate): Accelerate accelerator instance
tokenizer (AutoTokenizer): Hugging Face tokenizer
seq_len (int): Sequence length for the dataset
Returns:
Dataset: Packed dataset
"""
raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:50%]")
def tokenize_function(examples):
tokenized_batch = tokenizer(
examples["text"],
padding=False,
truncation=True,
max_length=seq_len,
return_tensors=None,
)
tokenized_batch["labels"] = tokenized_batch["input_ids"].copy()
return tokenized_batch
with accelerator.main_process_first():
tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
def create_packed_sequences(examples):
all_tokens = []
for input_ids in examples["input_ids"]:
all_tokens.extend(input_ids)
num_sequences = len(all_tokens) // (seq_len + 1)
packed_input_ids = []
packed_labels = []
for i in range(num_sequences):
start_idx = i * (seq_len + 1)
end_idx = start_idx + (seq_len + 1)
full_sequence = all_tokens[start_idx:end_idx]
packed_input_ids.append(full_sequence[:-1])
packed_labels.append(full_sequence[1:])
return {"input_ids": packed_input_ids, "labels": packed_labels}
with accelerator.main_process_first():
packed_dataset = tokenized_dataset.map(
create_packed_sequences,
batched=True,
remove_columns=tokenized_dataset.column_names,
batch_size=1000,
)
return packed_dataset.shuffle(seed=42)
def main():
"""
Main function to train the model.
@ -174,44 +89,26 @@ def main():
torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer = setup_tokenizer(MODEL_ID)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
model, optimizer = accelerator.prepare(model, optimizer)
dataset = get_dataset(accelerator, tokenizer, args.sequence_length)
def collate_fn(batch):
input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
# Transformers expect `labels` to not be shifted, though we already shifted them, so we pass them both
# We need to pass both `shift_labels` and `labels` to the model, as the loss is calculated inside `if labels is not None`
# `shift_labels` take precedence over `labels` in this case
return {"input_ids": input_ids, "labels": labels, "shift_labels": labels}
# We keep batch size at 1, as it is basically the same as sequence length, which we use instead
dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)
dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())
dataloader = accelerator.prepare(dataloader)
model.train()
total_num_steps = min(args.num_steps, len(dataloader))
num_tokens = 0
is_in_warmup = True
model_flops_per_token = get_model_flops_per_token(model, args)
model_flops_per_token = get_model_flops_per_token(model, args.sequence_length)
performance_tracker = PerformanceTracker(warmup_steps=10)
accelerator.print(f"Warming up for {WARMUP_STEPS} steps...")
accelerator.print(f"Starting training with {args.precision} precision for {total_num_steps} steps...")
accelerator.print(f"Sequence length: {args.sequence_length}")
accelerator.print("Warming up for 10 steps...")
for step, batch in enumerate(dataloader):
if step == WARMUP_STEPS:
accelerator.print("Warm up completed! Starting training")
start_time = time.perf_counter()
num_tokens = 0
is_in_warmup = False
if step >= total_num_steps:
break
@ -222,32 +119,34 @@ def main():
optimizer.step()
optimizer.zero_grad()
steps_from_warmup = step - WARMUP_STEPS
print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
metrics = {"loss": loss.item()}
batch_tokens = batch["input_ids"].shape[1]
metrics = performance_tracker.step(batch_tokens)
if not is_in_warmup and steps_from_warmup > 0:
num_tokens += batch["input_ids"].shape[1]
total_time = time.perf_counter() - start_time
tps = num_tokens / total_time
tflops = num_tokens * model_flops_per_token / (total_time * 1e12)
print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
log_metrics = {"loss": loss.item()}
if "warmup_completed" in metrics:
accelerator.print("Warm up completed! Starting performance tracking...")
elif metrics:
tps = metrics["tokens_per_second"]
tflops = metrics["total_tokens"] * model_flops_per_token / (metrics["total_time"] * 1e12)
# it's rather hard to get a good estimate of MFU as we train with FP8, so both FP8 and BF16 tensor cores are used, therefore we just report TFLOPS (Tera floating point operations per second)
# Given H100 SXM, the theoretical peak flops are ~990 TFLOPS for bf16 and ~1980 TFLOPS for fp8 [https://resources.nvidia.com/en-us-gpu-resources/h100-datasheet-24306]
# This is WITH sparsity, so we divide by 2 to get the answer w/o sparsity
print_msg += f", Average steps/s: {steps_from_warmup / total_time:.2f}, TPS per device: {tps:.2f}, TFLOPS per device: {tflops:.2f}"
metrics.update(
print_msg += f" | Average steps/s: {metrics['steps_per_second']:.2f} | TPS per device: {tps:.2f} | TFLOPS per device: {tflops:.2f}"
log_metrics.update(
{
"steps_per_second": steps_from_warmup / total_time,
"steps_per_second": metrics["steps_per_second"],
"tps_per_device": tps,
"tflops_per_device": tflops,
}
)
if steps_from_warmup % 10 == 0 or step == total_num_steps:
if step % 10 == 0 or step == total_num_steps - 1:
accelerator.print(print_msg)
accelerator.log(metrics)
accelerator.log(log_metrics)
accelerator.wait_for_everyone()
accelerator.end_training()

examples/fsdp2/utils.py (new file, 181 lines)
View File

@ -0,0 +1,181 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Common utilities for FSDP2 examples.
"""
import time
import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator
def get_dataset(
accelerator: Accelerator,
tokenizer: AutoTokenizer,
seq_len: int,
processing_batch_size: int = 1000,
) -> Dataset:
"""
Load and prepare TinyStories dataset.
Args:
accelerator (Accelerator): Accelerate accelerator instance
tokenizer (AutoTokenizer): Hugging Face tokenizer
seq_len (int): Sequence length for the dataset
processing_batch_size (int): Batch size for processing the dataset
Returns:
Dataset: Packed dataset
"""
raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:50%]")
def tokenize_function(examples):
tokenized_batch = tokenizer(
examples["text"],
padding=False,
truncation=True,
max_length=seq_len,
return_tensors=None,
)
tokenized_batch["labels"] = tokenized_batch["input_ids"].copy()
return tokenized_batch
with accelerator.main_process_first():
tokenized_dataset = raw_dataset.map(
tokenize_function, batched=True, remove_columns=["text"], batch_size=processing_batch_size
)
def create_packed_sequences(examples):
all_tokens = []
for input_ids in examples["input_ids"]:
all_tokens.extend(input_ids)
num_sequences = len(all_tokens) // (seq_len + 1)
packed_input_ids = []
packed_labels = []
for i in range(num_sequences):
start_idx = i * (seq_len + 1)
end_idx = start_idx + (seq_len + 1)
full_sequence = all_tokens[start_idx:end_idx]
packed_input_ids.append(full_sequence[:-1])
packed_labels.append(full_sequence[1:])
return {"input_ids": packed_input_ids, "labels": packed_labels}
with accelerator.main_process_first():
packed_dataset = tokenized_dataset.map(
create_packed_sequences,
batched=True,
remove_columns=tokenized_dataset.column_names,
batch_size=processing_batch_size,
)
return packed_dataset.shuffle(seed=42)
def get_model_flops_per_token(model: AutoModelForCausalLM, seq_len: int) -> float:
"""
Get the number of flops per token for the model.
Args:
model (AutoModelForCausalLM): Model to get the flops for
seq_len (int): Sequence length
"""
cfg = model.config
head_dim = cfg.hidden_size // cfg.num_attention_heads
# MLP: 3 matmuls
mlp_flops = 18 * cfg.hidden_size * cfg.intermediate_size
# Attn (w/o dotproduct)
attn_flops = 12 * head_dim * (cfg.num_attention_heads + cfg.num_key_value_heads)
# attn (dotproduct) - this scales quadratically with sequence length
attn_dotproduct_flops = 12 * cfg.num_attention_heads * head_dim * seq_len
# we also ignore embeddings and layernorms, etc
return (mlp_flops + attn_flops + attn_dotproduct_flops) * cfg.num_hidden_layers
def create_collate_fn():
"""Create a collate function for batching."""
def collate_fn(batch):
input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
return {"input_ids": input_ids, "labels": labels}
return collate_fn
class PerformanceTracker:
"""Track training performance metrics."""
def __init__(self, warmup_steps: int = 10):
self.warmup_steps = warmup_steps
self.reset()
def reset(self):
"""Reset all tracking variables."""
self.start_time = None
self.num_tokens = 0
self.is_in_warmup = True
self.step_count = 0
def step(self, batch_tokens: int) -> dict:
"""
Update performance tracking with a new step.
Args:
batch_tokens (int): Number of tokens in current batch
Returns:
dict: Performance metrics if past warmup, empty dict otherwise
"""
self.step_count += 1
if self.step_count == self.warmup_steps:
self.start_time = time.perf_counter()
self.num_tokens = 0
self.is_in_warmup = False
return {"warmup_completed": True}
if not self.is_in_warmup and self.start_time is not None:
self.num_tokens += batch_tokens
total_time = time.perf_counter() - self.start_time
steps_from_warmup = self.step_count - self.warmup_steps
if total_time > 0 and steps_from_warmup > 0:
return {
"tokens_per_second": self.num_tokens / total_time,
"steps_per_second": steps_from_warmup / total_time,
"total_tokens": self.num_tokens,
"total_time": total_time,
}
return {}
def setup_tokenizer(model_id: str) -> AutoTokenizer:
"""Setup tokenizer with proper padding token."""
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
return tokenizer

View File

@ -521,6 +521,23 @@ class Accelerator:
gradient_accumulation_plugin=gradient_accumulation_plugin,
)
if self.is_fsdp2:
from torch.distributed.device_mesh import init_device_mesh
context_parallel_size = self.state.fsdp_plugin.cp_size
world_size = self.state.num_processes
fsdp_size = world_size // context_parallel_size
device_mesh = init_device_mesh(
device_type=self.device.type,
mesh_shape=(fsdp_size, context_parallel_size),
mesh_dim_names=("fsdp", "cp"),
)
device_mesh["fsdp", "cp"]._flatten("fsdp_cp")
self.state.torch_device_mesh = device_mesh
self.device_placement = device_placement
if dataloader_config is None:
dataloader_config = DataLoaderConfiguration()
@ -1260,6 +1277,66 @@ class Accelerator:
with contextlib.nullcontext(joinables):
yield
@contextmanager
def context_parallel(
self,
buffers: list[torch.Tensor] | None = None,
buffer_seq_dims: list[int] | None = None,
no_restore_buffers: set[torch.Tensor] | None = None,
):
"""
A context manager that enables context parallel training.
Args:
buffers (`list[torch.Tensor]`, `optional`):
Buffers, which are going to be sharded along the sequence dimension. Common examples are inputs, labels
or positional embedding buffers. This context manager will modify these buffers in-place, and after
exiting the context, the buffers will be restored to their original state. To avoid unnecessary
restores, you can use `no_restore_buffers` to specify which buffers don't need to be restored.
buffer_seq_dims (`list[int]`, `optional`):
Sequence dimensions of `buffers`.
no_restore_buffers (`set[torch.Tensor]`, `optional`):
This set must be a subset of `buffers`. Specifies which buffers from the `buffers` argument won't be
restored after the context exits. These buffers will then be kept in their sharded state.
<Tip warning={true}>
`context_parallel` is currently only supported together with FSDP2, and requires `cp_size` to be set. If either
of these conditions is not met, this context manager will have no effect.
</Tip>
<Tip warning={true}>
This context manager has to be recreated with each training step, as shown in the example below.
</Tip>
Example:
```python
>>> for batch in dataloader:
... with accelerator.context_parallel(
... buffers=[batch["input_ids"], batch["attention_mask"]],
... buffer_seq_dims=[1, 1],
... no_restore_buffers={batch["input_ids"]},
... ):
... outputs = model(**batch)
... ...
```
"""
if (
getattr(self.state, "fsdp_plugin", None) is None
or self.state.fsdp_plugin.cp_size == 1
or (cp_context := getattr(self, "_cp_context", None)) is None
):
logger.warning("Context parallel + FSDP2 is not configured, this context manager will have no effect.")
yield
else:
with cp_context(buffers=buffers, buffer_seq_dims=buffer_seq_dims, no_restore_buffers=no_restore_buffers):
yield
def print(self, *args, **kwargs):
"""
Drop in replacement of `print()` to only print once per server.
@ -1477,6 +1554,20 @@ class Accelerator:
# Needs to be done first, to make sure AC + fully_shard will work as expected
self.state.fsdp_plugin.set_auto_wrap_policy(model)
if (context_parallel_size := self.state.fsdp_plugin.cp_size) > 1:
if context_parallel_size > self.state.num_processes:
raise ValueError(
f"`cp_size` set to {context_parallel_size}, which is greater than the number of processes {self.state.num_processes}. Please set to 1 to disable context parallel or use a smaller value."
)
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
cp_comm_strategy = self.state.fsdp_plugin.cp_comm_strategy
set_rotate_method(cp_comm_strategy)
self._cp_context = functools.partial(context_parallel, mesh=self.state.torch_device_mesh["cp"])
# Apply AC if needed
if self.state.fsdp_plugin.activation_checkpointing:
model = fsdp2_apply_ac(self, model)
@ -2330,6 +2421,8 @@ class Accelerator:
return self.state.torch_tp_plugin.torch_device_mesh
elif self.distributed_type == DistributedType.DEEPSPEED and hasattr(self.state, "ds_device_mesh"):
return self.state.ds_device_mesh
elif self.is_fsdp2 and hasattr(self.state, "torch_device_mesh"):
return self.state.torch_device_mesh
return None
def _prepare_msamp(self, *args, device_placement):

View File

@ -505,6 +505,22 @@ def get_cluster_input():
error_message="Please enter yes or no.",
)
if fsdp_version == 2:
fsdp_config["fsdp_cp_size"] = _ask_field(
"What should be your FSDP's context parallel size? (Input 1 or leave blank for no context parallel) [1]: ",
int,
default=1,
error_message="Please enter an integer.",
)
if fsdp_version == 2 and fsdp_config.get("fsdp_cp_size", 1) != 1:
fsdp_config["fsdp_cp_comm_strategy"] = _ask_options(
"What should be your FSDP's context parallel communication strategy? [allgather]: ",
["allgather", "alltoall"],
lambda x: ["allgather", "alltoall"][int(x)],
default=0,
)
megatron_lm_config = {}
if distributed_type in [DistributedType.MULTI_GPU]:
use_megatron_lm = _ask_field(

View File

@ -610,6 +610,18 @@ def launch_command_parser(subparsers=None):
type=str,
help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).",
)
fsdp_args.add_argument(
"--fsdp_cp_size",
type=int,
default=1,
help="FSDP's context parallel size. (useful only when `use_fsdp` flag is passed and `fsdp_version` is 2). Defaults to 1 (CP not applied).",
)
fsdp_args.add_argument(
"--fsdp_cp_comm_strategy",
type=str,
default="allgather",
help="FSDP's context parallel communication strategy. (useful only when `use_fsdp` flag is passed and `fsdp_version` is 2). Defaults to `allgather`.",
)
# megatron_lm args
megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.")

View File

@ -1117,25 +1117,28 @@ def prepare_data_loader(
process_index = process_index // submesh_tp_size
num_processes = num_processes // submesh_tp_size
else:
# when device mesh is used, specifically with TP
# when device mesh is used, specifically with TP or CP
# then there is need to update process_index and num_processes
# to bring in the effect of generating same batch across TP ranks
# to bring in the effect of generating same batch across TP/CP ranks
# and different batch across FSDP and DP ranks.
# Example:
# if device mesh is (dp,fsdp,tp) = (2, 2, 3)
# ranks would range from 0...11
# from data angle ranks should look like 0 0 0 1 1 1 2 2 2 3 3 3
# if device mesh is (dp,fsdp,tp,cp) = (2, 2, 2, 3)
# ranks would range from 0...23
# from data angle ranks should look like 0 0 0 0 0 0 1 1 1 1 1 1 ...
# processes with same ranks/ids would receive the same batch
submesh_fsdp_size = 1
submesh_dp_size = 1
submesh_tp_size = 1
submesh_cp_size = 1
if "tp" in torch_device_mesh.mesh_dim_names:
submesh_tp_size = torch_device_mesh["tp"].size()
if "dp" in torch_device_mesh.mesh_dim_names:
submesh_dp_size = torch_device_mesh["dp"].size()
if "fsdp" in torch_device_mesh.mesh_dim_names:
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
process_index = process_index // submesh_tp_size
if "cp" in torch_device_mesh.mesh_dim_names:
submesh_cp_size = torch_device_mesh["cp"].size()
process_index = process_index // (submesh_tp_size * submesh_cp_size)
num_processes = submesh_fsdp_size * submesh_dp_size
# Sanity check

View File

@ -0,0 +1,46 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
def main():
accelerator = Accelerator()
B, S, D = 2, 3, 4
rank_data = torch.ones((B, S, D), device="cuda") * (accelerator.process_index + 1)
all_rank_data = [torch.empty_like(rank_data) for _ in range(accelerator.num_processes)]
torch.distributed.all_gather(all_rank_data, rank_data)
dataloader = DataLoader(all_rank_data, batch_size=B, shuffle=False)
dataloader = accelerator.prepare(dataloader)
for batch in dataloader:
all_rank_batch = [torch.empty_like(batch) for _ in range(accelerator.num_processes)]
torch.distributed.all_gather(all_rank_batch, batch)
if accelerator.is_main_process:
for rank_idx in range(accelerator.num_processes):
torch.testing.assert_close(
all_rank_batch[0],
all_rank_batch[rank_idx],
msg=f"Rank {rank_idx} batch {all_rank_batch[rank_idx]} differs from rank 0 batch {all_rank_batch[0]}",
)
accelerator.end_training()
if __name__ == "__main__":
main()

View File

@ -44,6 +44,7 @@ FSDP_PYTORCH_VERSION = (
"2.1.0.a0+32f93b1" # Technically should be 2.1.0, but MS-AMP uses this specific prerelease in their Docker image.
)
FSDP2_PYTORCH_VERSION = "2.6.0"
CONTEXT_PARALLEL_PYTORCH_VERSION = "2.7.0"
FSDP_MODEL_NAME = "pytorch_model_fsdp"
DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich", "mpich", "nossh", "slurm"]
TORCH_DYNAMO_MODES = ["default", "reduce-overhead", "max-autotune"]

View File

@ -33,6 +33,7 @@ import torch
from .constants import (
BETA_TP_AVAILABLE_PYTORCH_VERSION,
CONTEXT_PARALLEL_PYTORCH_VERSION,
FSDP2_PYTORCH_VERSION,
FSDP_AUTO_WRAP_POLICY,
FSDP_BACKWARD_PREFETCH,
@ -1548,6 +1549,11 @@ class FullyShardedDataParallelPlugin:
min_num_params (`Optional[int]`, defaults to `None`):
The minimum number of parameters a module must have to be wrapped. Only applicable when `auto_wrap_policy`
is `size_based_wrap`.
cp_size (`int`, defaults to `1`):
The size of the context parallel group. Only applicable when `fsdp_version` is set to 2, else error will be
raised. Defaults to 1 (CP not applied).
cp_comm_strategy (`str`, defaults to `allgather`):
The shard rotation strategy to use, only used when `cp_size` > 1 and `fsdp_version` is set to 2.
"""
fsdp_version: int = field(
@ -1693,6 +1699,18 @@ class FullyShardedDataParallelPlugin:
"help": "The minimum number of parameters a module must have to be wrapped. Only applicable when `auto_wrap_policy` is `size_based_wrap`."
},
)
cp_size: int = field(
default=None,
metadata={
"help": "The size of the context parallel group. Only applicable when `fsdp_version` is set to 2, else error will be raised. Defaults to 1 (CP not applied)"
},
)
cp_comm_strategy: str = field(
default=None,
metadata={
"help": "The shard rotation strategy to use, only used when `cp_size` > 1 and `fsdp_version` is set to 2. Defaults to `allgather`."
},
)
def __post_init__(self):
from torch.distributed.fsdp import (
@ -1855,6 +1873,28 @@ class FullyShardedDataParallelPlugin:
)
os.environ[env_var] = str(self.cpu_ram_efficient_loading)
if self.cp_size is None:
self.cp_size = int(os.environ.get(env_prefix + "CP_SIZE", "1"))
if self.cp_size > 1 and self.fsdp_version != 2:
raise ValueError(
f"cp_size set to {self.cp_size}. This is not supported with FSDP1, please set to 1 or use `fsdp_version=2`"
)
if self.cp_size > 1 and not is_torch_version(">=", CONTEXT_PARALLEL_PYTORCH_VERSION):
raise ValueError(
f"cp_size set to {self.cp_size}. This is not supported with PyTorch < {CONTEXT_PARALLEL_PYTORCH_VERSION}, please set to None or upgrade your PyTorch version."
)
if self.cp_comm_strategy is None:
self.cp_comm_strategy = os.environ.get(env_prefix + "CP_COMM_STRATEGY", "allgather")
# No need to further check versions, as that check is done in the `context_parallel_size` check
if self.cp_comm_strategy not in ["allgather", "alltoall"]:
raise ValueError(
f"cp_comm_strategy set to {self.cp_comm_strategy}. Must be one of ['allgather', 'alltoall']."
)
if isinstance(self.mixed_precision_policy, dict):
self.set_mixed_precision(self.mixed_precision_policy)
if self.mixed_precision_policy is not None:

View File

@ -616,11 +616,14 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
original_sd = model.state_dict()
mesh = getattr(accelerator.state, "torch_device_mesh", None)
fsdp2_kwargs = {
"reshard_after_forward": fsdp2_plugin.reshard_after_forward,
"offload_policy": fsdp2_plugin.cpu_offload,
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
"mesh": mesh["fsdp_cp"] if mesh else None,
}
model_has_params4bit = False

View File

@ -328,6 +328,8 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> dict[str, str]:
current_env["FSDP_CPU_RAM_EFFICIENT_LOADING"] = str(args.fsdp_cpu_ram_efficient_loading).lower()
current_env["FSDP_SYNC_MODULE_STATES"] = str(args.fsdp_sync_module_states).lower()
current_env["FSDP_ACTIVATION_CHECKPOINTING"] = str(args.fsdp_activation_checkpointing).lower()
current_env["FSDP_CP_SIZE"] = str(args.fsdp_cp_size)
current_env["FSDP_CP_COMM_STRATEGY"] = str(args.fsdp_cp_comm_strategy)
if args.use_megatron_lm:
prefix = "MEGATRON_LM_"

View File

@ -33,11 +33,13 @@ from accelerate.test_utils.testing import (
require_multi_device,
require_non_cpu,
require_non_torch_xla,
require_torch_min_version,
run_first,
slow,
)
from accelerate.utils import is_bf16_available, is_fp16_available, is_hpu_available, patch_environment, set_seed
from accelerate.utils.constants import (
CONTEXT_PARALLEL_PYTORCH_VERSION,
FSDP2_STATE_DICT_TYPE,
FSDP_AUTO_WRAP_POLICY,
FSDP_BACKWARD_PREFETCH,
@ -84,7 +86,6 @@ class FSDPPluginIntegration(AccelerateTestCase):
1: self.fsdp1_env,
2: self.fsdp2_env,
}
self.current_fsdp_version = 1
def test_sharding_strategy(self):
@ -398,6 +399,24 @@ class FSDPPluginIntegration(AccelerateTestCase):
assert fsdp_plugin.cpu_ram_efficient_loading is False
assert os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING") == "False"
def test_cp(self):
if (fsdp_version := self.current_fsdp_version) != 2:
return
env = self.fsdp_envs[fsdp_version].copy()
for cp_comm_strategy in ["allgather", "alltoall"]:
env["FSDP_CP_COMM_STRATEGY"] = cp_comm_strategy
env["FSDP_CP_SIZE"] = "2"
with patch_environment(**env):
fsdp_plugin = FullyShardedDataParallelPlugin()
assert fsdp_plugin.cp_comm_strategy == cp_comm_strategy
env = self.fsdp_envs[fsdp_version].copy()
env["FSDP_CP_SIZE"] = "2"
with patch_environment(**env):
fsdp_plugin = FullyShardedDataParallelPlugin(cp_comm_strategy=cp_comm_strategy)
assert fsdp_plugin.cp_comm_strategy == cp_comm_strategy
@require_fsdp2
@require_non_cpu
@ -448,7 +467,6 @@ class FSDPIntegrationTest(TempDirTestCase):
}
self.n_train = 160
self.n_val = 160
self.current_fsdp_version = 1
@require_fp16
@ -604,6 +622,23 @@ class FSDPIntegrationTest(TempDirTestCase):
with patch_environment(omp_num_threads=1):
execute_subprocess_async(cmd_config)
# TODO: Should probably be moved to a separate test file
@require_torch_min_version(version=CONTEXT_PARALLEL_PYTORCH_VERSION)
def test_dist_dataloader(self):
if (fsdp_version := self.current_fsdp_version) != 2:
return
self.test_file_path = self.test_scripts_folder / "test_distributed_dataloader.py"
cmd = get_launch_command(num_processes=2, num_machines=1, machine_rank=0, fsdp_version=fsdp_version)
cmd_config = cmd.copy()
cmd_config.extend(["--use_fsdp", "--fsdp_cp_size=2"])
cmd_config.append(self.test_file_path)
with patch_environment(omp_num_threads=1):
execute_subprocess_async(cmd_config)
@require_fsdp2
@run_first