Mirror of https://github.com/huggingface/accelerate.git (synced 2025-11-19 00:54:29 +08:00)

Compare commits: v1.11.0 ... context-pa (23 commits)

Commits: 67be9a69ba, 8df21cf54a, 199cbedb01, 50f60e3c6f, ed0fcaceb7, 7859437aaa, dc828eba05, 8cef8d4d26, 7782a156fb, 0d20c3b110, ad7e1ad349, cc30dc60f3, 17cd32f616, b39e39f05e, 40999beba2, 3670c6d2a8, 77bd1fab74, 351d9890f2, 27edf35212, f8bac5aaa1, deb42c105b, b816a6762f, b1a48dc76f
@@ -82,6 +82,8 @@
    title: Accelerate's internal mechanism
  - local: concept_guides/big_model_inference
    title: Loading big models into memory
+ - local: concept_guides/context_parallel
+   title: Context parallelism
  - local: concept_guides/performance
    title: Comparing performance across distributed setups
  - local: concept_guides/deferring_execution
156  docs/source/concept_guides/context_parallel.md  Normal file
@@ -0,0 +1,156 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Context Parallel in 🤗`accelerate`

This guide covers the basics of using context parallelism in 🤗`accelerate`. For the more curious readers, we also cover some technical details in the later sections.
## Why context parallelism?

With the advent of large language models, and more recently reasoning models, sequence lengths have been growing rapidly. This, combined with the quadratic memory complexity of attention, has led to a need for more efficient ways to train models on long sequences.
With a sequence length of 128k, the memory requirement of the attention matrix is `128k * 128k * 2 bytes * num_heads = ~32 GB * num_heads` in `bf16` precision, for a vanilla attention implementation. Granted, with `flash attention` or `SDPA`, which do not materialize these attention weights, this decreases drastically, but the growth in memory requirements is still considerable.

Context parallelism allows us to shard the inputs to the attention computation along the sequence dimension and compute the attention in parallel on multiple GPUs. With this, we can train models with long sequences, potentially scaling to 1M+ tokens.
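To put these numbers in perspective, here is a small back-of-the-envelope helper (a sketch for illustration only, not part of 🤗`accelerate`) that estimates the memory of a fully materialized attention matrix:

```python
def attention_matrix_gib(seq_len: int, num_heads: int, bytes_per_elem: int = 2) -> float:
    """Memory (in GiB) of one materialized [num_heads, seq_len, seq_len] attention matrix (2 bytes/element, i.e. bf16, by default)."""
    return num_heads * seq_len * seq_len * bytes_per_elem / 1024**3


print(attention_matrix_gib(128_000, num_heads=1))   # ~30.5 GiB per head
print(attention_matrix_gib(128_000, num_heads=32))  # ~977 GiB for a 32-head layer
```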
## How to use context parallelism?

As with any other feature in 🤗`accelerate`, enabling context parallelism is as simple as passing the corresponding flags to `accelerate launch`.
In this case, it's no different:

```bash
accelerate launch --use_fsdp --fsdp_version 2 --fsdp_cp_size 8 --fsdp_cp_comm_strategy [allgather|alltoall] ...
```
Context parallelism is (for now) tightly coupled with `FSDP2`, which you can learn more about in the [FSDP2 introduction](fsdp1_vs_fsdp2.md). That means context parallelism is applied only if `FSDP2` is enabled.
You can also enable context parallelism programmatically, by passing the corresponding options to the `FullyShardedDataParallelPlugin` constructor:

```diff
from accelerate.utils import FullyShardedDataParallelPlugin

plugin = FullyShardedDataParallelPlugin(
    ...
    fsdp_version=2,
+   cp_size=8,
+   cp_comm_strategy="allgather",
)
accelerator = Accelerator(fsdp_plugin=plugin)
```
After enabling context parallelism with either of the methods above, you can apply it in your training loop. We provide a thin wrapper around [`torch.distributed.tensor.experimental.context_parallel`](https://docs.pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel) that abstracts away some of the complexity of using it (more on this later).
You can use it as follows:

```python
for batch in dataloader:
    with accelerator.context_parallel(
        buffers=[batch["input_ids"], batch["attention_mask"]],
        buffer_seq_dims=[1, 1],
        no_restore_buffers={batch["input_ids"]},
    ):
        outputs = model(batch)
        ...
```

> [!Warning]
> This context manager has to be recreated on each training step, as shown in the example above. It's crucial to do so.

This can potentially scale your context size to 1M+ tokens. Below, we showcase the speed and memory usage of context parallelism for up to a 256k context size. We can see that when we double the context size and the number of GPUs, memory usage stays roughly constant, potentially enabling endless context-length scaling.
<p align="center">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_perf.png" alt="context parallelism memory usage" />
    <br>
    <em>Figure 1: Memory usage and speed of context parallelism for up to 256k context size.</em>
</p>

> [!Tip]
> These examples were created with a script you can find [in the examples folder](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/fsdp2_context_parallel.py). For instructions on how to run it, see the [README](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/README.md) in the same folder.

## Accelerate's interface

The context manager takes a few arguments that are used to configure context parallelism:

- `buffers`: a list of tensors to be sharded along the sequence dimension. These are usually the input ids, labels and attention mask.
- `buffer_seq_dims`: a list of integers specifying the sequence dimension of each buffer, in the order of the `buffers` list.
- `no_restore_buffers`: the implementation of context parallelism modifies the buffers in-place, converting them to `torch.distributed.tensor.DTensor`s. After the context manager exits, a communication kernel would need to be launched to restore each buffer to its original state (usually an all-gather). This takes some time, so it is recommended to pass the same tensors as in the `buffers` argument to avoid unnecessary communication, unless you are sure that you need to use the buffers after the context manager exits.
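For illustration, here is how these arguments line up for a typical causal-LM batch (a hypothetical sketch; the `input_ids`/`labels` keys and their shapes depend on your collator):

```python
# batch["input_ids"] and batch["labels"] are both shaped [batch_size, seq_len]
with accelerator.context_parallel(
    buffers=[batch["input_ids"], batch["labels"]],              # tensors to shard along the sequence
    buffer_seq_dims=[1, 1],                                     # seq_len is dimension 1 for both buffers
    no_restore_buffers={batch["input_ids"], batch["labels"]},   # not needed after the forward/backward pass
):
    loss = model(**batch).loss
```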

## Configurable options

Accelerate provides only a few options to configure context parallelism:

- `cp_size`: the number of ranks across which the inputs to the attention computation are sharded along the sequence dimension.
- `cp_comm_strategy`: the rotation method to use for the shards. We strongly recommend keeping this as `"allgather"`, as it is very likely to outperform `"alltoall"` in most cases.

Context parallel size is rather self-explanatory: it's the number of ranks across which the inputs are sharded.
The shard rotation strategy defines how the shards of the inputs are rotated across ranks. We'll cover the two options in more detail in the next section.

You can see an end-to-end example in the [FSDP2 context parallel example](https://github.com/huggingface/accelerate/blob/main/examples/fsdp2/fsdp2_context_parallel.py) file, where you can train an 8B model with a 128k sequence length on 8x H100 SXM GPUs. With multi-node training, you can scale this to a 1M+ sequence length on 64x H100 SXM GPUs.

## Technical details

> [!Tip]
> This section is fairly technical, so if you don't need to learn the internals of context parallelism, you can skip it and start building 🚀

We're going to use the word `shard` extensively in the following sections, so let's define it first. If we say a tensor is `sharded` along a dimension of size `D` across `N` ranks, we mean that the tensor is split into `N` parts, where each part has shape `[..., D//N, ...]` along that dimension.
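For instance, a minimal sketch of what this means in plain PyTorch (for illustration only; the actual sharding and placement are handled by `torch.distributed`):

```python
import torch

N = 4                              # number of ranks
x = torch.randn(2, 1024, 64)       # [batch, seq_len, hidden], so D = 1024 on the sequence dimension
shards = torch.chunk(x, N, dim=1)  # N parts, each of shape [2, D // N, 64] = [2, 256, 64]

assert all(s.shape == (2, 1024 // N, 64) for s in shards)
```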

## So how does it work?

Context parallelism works by sharding the `Q`, `K` and `V` matrices across the sequence dimension. Each rank has its assigned shard of `Q`, let's call it `Q_i`; this matrix stays on that rank for the whole computation. Similarly, each rank has its own shards of `K` and `V`, let's call them `K_i` and `V_i`. Each rank first calculates attention with its own shards `Q_i`, `K_i` and `V_i`, let's call the result `attn_i`. During this computation, a communication kernel is launched to gather the `K`s and `V`s from the other ranks. Which communication primitive is used depends on the `cp_comm_strategy` option.
This way, each rank computes local attention first with `Q_i`, `K_i` and `V_i`, and then with each `K_j` and `V_j` received from the other ranks. As each rank only holds shards of `Q`, `K` and `V` along the sequence dimension, the resulting matrices are smaller and can fit on a single GPU.

We can formalize this in the following pseudocode:
```python
comm_kernel = {"allgather": allgather, "alltoall": alltoall}[cp_comm_strategy]

# local shards on rank i
Qi, Ki, Vi = shard(Q, K, V, seq_dim)
attn[i] = attn(Qi, Ki, Vi)

# gather/rotate the K and V shards of the other ranks and attend to them with the local Qi
for j in other_ranks:
    Kj, Vj = comm_kernel()
    attn[j] = attn(Qi, Kj, Vj)  # [batch, num_heads, seq_len // cp_size, head_dim]

final_attn = combine(attn)
```

## all-to-all vs all-gather

### all-gather

So what's the difference between all-to-all and all-gather? With all-gather, the communication is very simple: around the time we compute the local attention `attn_i` (the all-gather is usually launched beforehand, as it takes longer), we launch an all-gather to collect the `K`s and `V`s from all the other ranks. Once this communication is done, each rank has the `K`s and `V`s from every other rank and can compute attention with them sequentially.
In an ideal scenario, the all-gather finishes at the exact moment the calculation of `attn_i` is done. In practice this never happens, so the best achievable overlap is when the full `attn_i` computation hides part of the communication; to then start the computation with `K_j` and `V_j`, we wait for the all-gather to finish.
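This schedule can be sketched with plain PyTorch collectives. The following is only an illustrative sketch, assuming a process group is already initialized (e.g. via `torchrun`) and that `attn` is some attention function over local shards; it is not the kernel `accelerate`/PyTorch actually use, and it omits the final merge of the partial results (the `combine` step from the pseudocode above):

```python
import torch
import torch.distributed as dist


def allgather_overlap_step(q_i, k_i, v_i, attn):
    world_size, rank = dist.get_world_size(), dist.get_rank()
    k_all = [torch.empty_like(k_i) for _ in range(world_size)]
    v_all = [torch.empty_like(v_i) for _ in range(world_size)]

    # Launch the all-gather first so it overlaps with the local attention computation.
    k_work = dist.all_gather(k_all, k_i, async_op=True)
    v_work = dist.all_gather(v_all, v_i, async_op=True)

    partials = [attn(q_i, k_i, v_i)]  # local shard, computed while the gather is in flight

    # Wait for the communication, then attend to every remote K/V shard sequentially.
    k_work.wait()
    v_work.wait()
    for j in range(world_size):
        if j != rank:
            partials.append(attn(q_i, k_all[j], v_all[j]))
    return partials  # merging the partial results is intentionally left out here
```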

### all-to-all

All-to-all, sometimes called `ring rotation`, utilizes a ring-like communication pattern. After concluding the `attn_i` computation, an all-to-all is launched to send `K_i` and `V_i` to the neighbouring ranks. We then repeat this `cp_size - 1` times, so that each rank sees every shard of `K` and `V` from the other ranks exactly once. In an ideal scenario, we prefetch shards `K_i+1` and `V_i+1` from the neighbouring rank, and this communication is exactly overlapped with the computation of our current `attn_i`. Again, realistically, this perfect overlap never happens. Given the nature of this approach, if we don't achieve perfect overlap, the penalty is much larger than with all-gather.

## How to choose the right rotation method?

In theory, all-to-all should be the better choice, though in practice it rarely is. Therefore, we default to all-gather, as it is more likely to achieve better performance. Extensive [benchmarks](https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082) from the `torchtitan` team also show that all-to-all rarely outperforms all-gather. Still, we provide both options, as you might find one to be better for your use case.

You can directly see this issue in the profiler output in the image below:
<p align="center">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_all_to_all.png" alt="all-to-all profiler output" />
    <br>
    <em>Figure 2: In red you can see the idle time while we wait for the all-to-all kernel to finish. Highlighted in the first blue bar, you can see that it takes ~250us to finish, which is repeated N-1 times for each attention call, where N is the context parallel size.</em>
</p>

## Why only FSDP2?

We only support context parallelism with `FSDP2` for now, as we create a joint mesh of `context_parallel_size` and `dp_shard_size` to
utilize its full potential. In the profiler output in the image below, you can see why this is the case.

<p align="center">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_why_fsdp2.png" alt="why FSDP2+CP" />
    <br>
    <em>Figure 3: In the blue rectangles (Stream 23), you can see that the pre-fetch of the `FSDP` shard is fully overlapped with the computation of attention (Stream 7), while in the red rectangles (Stream 24), you can see that the all-gather kernel results in a bubble of idle time, in which our compute stream (7) is idle.</em>
</p>

In the figures above, you can also note the difference between all-to-all and all-gather. While in all-to-all (Figure 2) we launch a communication kernel N-1 times for each attention call, in all-gather (Figure 3) we launch a communication kernel only once. This results in a bigger bubble, but it happens only once per attention call, whereas in all-to-all it happens N-1 times.
@@ -1,8 +1,8 @@
-## FSDP2 Examples
+# FSDP2 Examples

This folder contains examples of using FSDP2 with Accelerate, utilizing extra methods to improve training speed, performance or accuracy.

-### FSDP2 + ao Float8Linear
+## FSDP2 + ao Float8Linear (`fsdp2_fp8.py`)

In file `fsdp2_fp8.py` we use `Float8Linear` from `ao` to train a model partially in FP8 precision. We utilize `AORecipeKwargs` to pass the `Float8LinearConfig` to the accelerator,
which replaces the default `torch.nn.Linear` with `Float8Linear`. We also utilize `TorchDynamoPlugin` together with regional compilation to compile the model,
@@ -34,3 +34,25 @@ The figures above were generated on 8x H100 SXM GPUs, with 8192 sequence length
```bash
accelerate launch fsdp2_fp8.py --sequence-length 8192 --num-steps 1000 --log_with wandb --precision [fp8 | bf16]
```

## FSDP2 + context parallelism (`fsdp2_context_parallel.py`)

In this file, we showcase the integration of context parallelism with FSDP2. Context parallelism is a technique that allows us to scale training to sequence lengths of up to a million tokens. With the `accelerator.context_parallel` context manager, we replace the attention implementation with a context parallel version, which enables us to train on a sequence length of up to 128k tokens on 8x H100 GPUs, with the possibility of near-endless scaling if we have enough GPUs.

For a detailed explanation, please refer to [this guide](https://huggingface.co/docs/accelerate/concept_guides/context_parallel). You can run the example with the following command:

```bash
accelerate launch --num_processes 8 fsdp2_context_parallel.py --sequence-length 128000 --num-steps 1000 --log-with wandb --cp-size 8 --cp-comm-strategy allgather
```

More details about context parallelism can be found in the [concept guide](https://huggingface.co/docs/accelerate/concept_guides/context_parallel). You can see some results below:

<p align="center">
    <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate/examples/fsdp2/cp_perf.png" alt="context parallelism memory usage" />
    <br>
    <em>Figure 1: Memory usage and speed of context parallelism for up to 256k context size.</em>
</p>
185  examples/fsdp2/fsdp2_context_parallel.py  Normal file
@@ -0,0 +1,185 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example of training with Context Parallel using FSDP2 via Accelerate.
This example demonstrates how to use Accelerate's context_parallel feature for efficient long sequence training.
"""

import argparse

import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM

from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin, set_seed
from utils import PerformanceTracker, create_collate_fn, get_dataset, setup_tokenizer


MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--sequence-length", type=int, default=128_000, help="Sequence length for the dataset")
    parser.add_argument("--num-steps", type=int, default=100, help="Number of training steps")
    parser.add_argument("--log-with", type=str, default="wandb", help="Logging service to use")
    parser.add_argument("--cp-size", type=int, default=8, help="Context parallel size")
    parser.add_argument("--cp-comm-strategy", type=str, default="allgather", help="Context parallel shard rotation")
    return parser.parse_args()


def training_step(batch, model, optimizer, accelerator: Accelerator):
    """
    Perform a single training step with context parallel.

    Args:
        batch: Input batch containing input_ids and labels
        model: The model to train
        optimizer: Optimizer
        accelerator: Accelerator instance

    Returns:
        loss: Training loss
    """
    input_ids = batch["input_ids"]
    labels = batch["labels"]

    # Use context parallel for efficient long sequence processing
    with accelerator.context_parallel(
        buffers=[input_ids, labels],
        buffer_seq_dims=[1, 1],  # Sequence dimension is dimension 1 for both tensors
        no_restore_buffers={input_ids, labels},  # Don't restore these buffers after forward pass
    ):
        outputs = model(**batch)
        loss = outputs.loss

        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

    torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG)

    return loss


def main():
    set_seed(42)
    args = parse_args()

    # Configure FSDP2 plugin
    fsdp_plugin = FullyShardedDataParallelPlugin(
        auto_wrap_policy="transformer_based_wrap",
        transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
        cpu_ram_efficient_loading=True,
        activation_checkpointing=True,
        fsdp_version=2,
        cp_size=args.cp_size,
        cp_comm_strategy=args.cp_comm_strategy,
    )

    # Initialize accelerator
    accelerator = Accelerator(
        log_with=args.log_with,
        fsdp_plugin=fsdp_plugin,
        mixed_precision="bf16",
    )

    accelerator.init_trackers(
        project_name="FSDP2_context_parallel",
        config={
            "sequence_length": args.sequence_length,
            "num_steps": args.num_steps,
            "cp_size": args.cp_size,
            "cp_comm_strategy": args.cp_comm_strategy,
        },
    )

    # Prepare model and optimizer
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        use_cache=False,
    )

    tokenizer = setup_tokenizer(MODEL_ID)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    model, optimizer = accelerator.prepare(model, optimizer)

    accelerator.print("Preparing dataset... this might take a while")
    dataset = get_dataset(
        accelerator,
        tokenizer,
        args.sequence_length,
        processing_batch_size=args.sequence_length
        // 20,  # we need to override the default processing batch size to avoid empty packed sequences
    )
    dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())
    dataloader = accelerator.prepare(dataloader)

    model.train()

    total_num_steps = min(args.num_steps, len(dataloader))
    performance_tracker = PerformanceTracker(warmup_steps=10)

    accelerator.print(f"Starting training with context parallel for {total_num_steps} steps...")
    accelerator.print(f"Sequence length: {args.sequence_length}")
    accelerator.print("Warming up for 10 steps...")

    accelerator.print(
        "Each step takes ~10 seconds with default settings on 8x H100 SXM GPUs, seeing logs takes a while"
    )
    for step, batch in enumerate(dataloader):
        if step >= total_num_steps:
            break

        # get number of tokens before context_parallel shards the batch
        batch_tokens = batch["input_ids"].shape[0] * batch["input_ids"].shape[1]

        loss = training_step(batch, model, optimizer, accelerator)

        # each batch gets the same data, we divide by the number of processes to get the number of tokens per process
        metrics = performance_tracker.step(batch_tokens // accelerator.num_processes)

        log_metrics = {"loss": loss.item()}

        if "warmup_completed" in metrics:
            accelerator.print("Warmup completed! Starting performance tracking...")
        elif metrics:
            log_metrics.update(
                {
                    "tokens_per_second": int(metrics["tokens_per_second"]),
                    "steps_per_second": metrics["steps_per_second"],
                }
            )

        if (step % 10 == 0 or step == total_num_steps - 1) and metrics:
            accelerator.print(
                f"Step {step}/{total_num_steps} | "
                f"Loss: {loss.item():.4f} | "
                f"Tokens/s: {int(metrics['tokens_per_second'])} | "
                f"Steps/s: {metrics['steps_per_second']:.2f} | "
            )

        accelerator.log(log_metrics)

    accelerator.wait_for_everyone()
    accelerator.end_training()
    accelerator.print("Training completed!")


if __name__ == "__main__":
    main()
@ -18,20 +18,17 @@ This example demonstrates how to use torchao's Float8LinearConfig with Accelerat
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from torchao.float8 import Float8LinearConfig
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import AORecipeKwargs, FullyShardedDataParallelPlugin, TorchDynamoPlugin, set_seed
|
||||
from utils import PerformanceTracker, create_collate_fn, get_dataset, get_model_flops_per_token, setup_tokenizer
|
||||
|
||||
|
||||
WARMUP_STEPS = 10
|
||||
|
||||
MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
|
||||
|
||||
@ -46,88 +43,6 @@ def parse_args():
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def get_model_flops_per_token(model: AutoModelForCausalLM, args: argparse.Namespace) -> float:
|
||||
"""
|
||||
Get the number of flops per token for the model.
|
||||
|
||||
Args:
|
||||
model (AutoModelForCausalLM): Model to get the flops for
|
||||
"""
|
||||
cfg = model.config
|
||||
|
||||
head_dim = cfg.hidden_size // cfg.num_attention_heads
|
||||
|
||||
# MLP: 3 matmuls
|
||||
mlp_flops = 18 * cfg.hidden_size * cfg.intermediate_size
|
||||
|
||||
# Attn (w/o dotproduct)
|
||||
attn_flops = 12 * head_dim * (cfg.num_attention_heads + cfg.num_key_value_heads)
|
||||
|
||||
# attn (dotproduct) - this scales quadratically with sequence length, therefore we have to account for it here too
|
||||
attn_dotproduct_flops = 12 * cfg.num_attention_heads * head_dim * args.sequence_length
|
||||
|
||||
# we also ignore embeddings and layernorms, etc
|
||||
return (mlp_flops + attn_flops + attn_dotproduct_flops) * cfg.num_hidden_layers
|
||||
|
||||
|
||||
def get_dataset(accelerator: Accelerator, tokenizer: AutoTokenizer, seq_len: int) -> Dataset:
|
||||
"""
|
||||
Load and prepare TinyStories dataset.
|
||||
|
||||
Args:
|
||||
accelerator (Accelerate): Accelerate accelerator instance
|
||||
tokenizer (AutoTokenizer): Hugging Face tokenizer
|
||||
seq_len (int): Sequence length for the dataset
|
||||
|
||||
Returns:
|
||||
Dataset: Packed dataset
|
||||
"""
|
||||
|
||||
raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:50%]")
|
||||
|
||||
def tokenize_function(examples):
|
||||
tokenized_batch = tokenizer(
|
||||
examples["text"],
|
||||
padding=False,
|
||||
truncation=True,
|
||||
max_length=seq_len,
|
||||
return_tensors=None,
|
||||
)
|
||||
tokenized_batch["labels"] = tokenized_batch["input_ids"].copy()
|
||||
return tokenized_batch
|
||||
|
||||
with accelerator.main_process_first():
|
||||
tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
|
||||
|
||||
def create_packed_sequences(examples):
|
||||
all_tokens = []
|
||||
for input_ids in examples["input_ids"]:
|
||||
all_tokens.extend(input_ids)
|
||||
|
||||
num_sequences = len(all_tokens) // (seq_len + 1)
|
||||
packed_input_ids = []
|
||||
packed_labels = []
|
||||
|
||||
for i in range(num_sequences):
|
||||
start_idx = i * (seq_len + 1)
|
||||
end_idx = start_idx + (seq_len + 1)
|
||||
full_sequence = all_tokens[start_idx:end_idx]
|
||||
packed_input_ids.append(full_sequence[:-1])
|
||||
packed_labels.append(full_sequence[1:])
|
||||
|
||||
return {"input_ids": packed_input_ids, "labels": packed_labels}
|
||||
|
||||
with accelerator.main_process_first():
|
||||
packed_dataset = tokenized_dataset.map(
|
||||
create_packed_sequences,
|
||||
batched=True,
|
||||
remove_columns=tokenized_dataset.column_names,
|
||||
batch_size=1000,
|
||||
)
|
||||
|
||||
return packed_dataset.shuffle(seed=42)
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main function to train the model.
|
||||
@ -174,44 +89,26 @@ def main():
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
tokenizer = setup_tokenizer(MODEL_ID)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
|
||||
|
||||
model, optimizer = accelerator.prepare(model, optimizer)
|
||||
|
||||
dataset = get_dataset(accelerator, tokenizer, args.sequence_length)
|
||||
|
||||
def collate_fn(batch):
|
||||
input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
|
||||
labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
|
||||
# Transformers expect `labels` to not be shifted, though we already shifted them, so we pass them both
|
||||
# We need to pass both `shift_labels` and `labels` to the model, as the loss is calculated inside `if labels is not None`
|
||||
# `shift_labels` take precedence over `labels` in this case
|
||||
return {"input_ids": input_ids, "labels": labels, "shift_labels": labels}
|
||||
|
||||
# We keep batch size at 1, as it is basically the same as sequence length, which we use instead
|
||||
dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)
|
||||
dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())
|
||||
dataloader = accelerator.prepare(dataloader)
|
||||
|
||||
model.train()
|
||||
|
||||
total_num_steps = min(args.num_steps, len(dataloader))
|
||||
num_tokens = 0
|
||||
is_in_warmup = True
|
||||
model_flops_per_token = get_model_flops_per_token(model, args)
|
||||
model_flops_per_token = get_model_flops_per_token(model, args.sequence_length)
|
||||
performance_tracker = PerformanceTracker(warmup_steps=10)
|
||||
|
||||
accelerator.print(f"Warming up for {WARMUP_STEPS} steps...")
|
||||
accelerator.print(f"Starting training with {args.precision} precision for {total_num_steps} steps...")
|
||||
accelerator.print(f"Sequence length: {args.sequence_length}")
|
||||
accelerator.print("Warming up for 10 steps...")
|
||||
|
||||
for step, batch in enumerate(dataloader):
|
||||
if step == WARMUP_STEPS:
|
||||
accelerator.print("Warm up completed! Starting training")
|
||||
start_time = time.perf_counter()
|
||||
num_tokens = 0
|
||||
is_in_warmup = False
|
||||
|
||||
if step >= total_num_steps:
|
||||
break
|
||||
|
||||
@ -222,32 +119,34 @@ def main():
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
steps_from_warmup = step - WARMUP_STEPS
|
||||
print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
|
||||
metrics = {"loss": loss.item()}
|
||||
batch_tokens = batch["input_ids"].shape[1]
|
||||
metrics = performance_tracker.step(batch_tokens)
|
||||
|
||||
if not is_in_warmup and steps_from_warmup > 0:
|
||||
num_tokens += batch["input_ids"].shape[1]
|
||||
total_time = time.perf_counter() - start_time
|
||||
tps = num_tokens / total_time
|
||||
tflops = num_tokens * model_flops_per_token / (total_time * 1e12)
|
||||
print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
|
||||
log_metrics = {"loss": loss.item()}
|
||||
|
||||
if "warmup_completed" in metrics:
|
||||
accelerator.print("Warm up completed! Starting performance tracking...")
|
||||
elif metrics:
|
||||
tps = metrics["tokens_per_second"]
|
||||
tflops = metrics["total_tokens"] * model_flops_per_token / (metrics["total_time"] * 1e12)
|
||||
|
||||
# it's rather hard to get a good estimate of MFU as we train with FP8, so both FP8 and BF16 tensor cores are used, therefore we just report TFLOPS (Tera floating point operations per second)
|
||||
# Given H100 SXM, the theoretical peak flops are ~990 TFLOPS for bf16 and ~1980 TFLOPS for fp8 [https://resources.nvidia.com/en-us-gpu-resources/h100-datasheet-24306]
|
||||
# This is WITH sparsity, so we divide by 2 to get the answer w/o sparsity
|
||||
print_msg += f", Average steps/s: {steps_from_warmup / total_time:.2f}, TPS per device: {tps:.2f}, TFLOPS per device: {tflops:.2f}"
|
||||
metrics.update(
|
||||
print_msg += f" | Average steps/s: {metrics['steps_per_second']:.2f} | TPS per device: {tps:.2f} | TFLOPS per device: {tflops:.2f}"
|
||||
log_metrics.update(
|
||||
{
|
||||
"steps_per_second": steps_from_warmup / total_time,
|
||||
"steps_per_second": metrics["steps_per_second"],
|
||||
"tps_per_device": tps,
|
||||
"tflops_per_device": tflops,
|
||||
}
|
||||
)
|
||||
|
||||
if steps_from_warmup % 10 == 0 or step == total_num_steps:
|
||||
if step % 10 == 0 or step == total_num_steps - 1:
|
||||
accelerator.print(print_msg)
|
||||
|
||||
accelerator.log(metrics)
|
||||
accelerator.log(log_metrics)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.end_training()
|
||||
|
||||
181  examples/fsdp2/utils.py  Normal file
@@ -0,0 +1,181 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Common utilities for FSDP2 examples.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from accelerate import Accelerator
|
||||
|
||||
|
||||
def get_dataset(
|
||||
accelerator: Accelerator,
|
||||
tokenizer: AutoTokenizer,
|
||||
seq_len: int,
|
||||
processing_batch_size: int = 1000,
|
||||
) -> Dataset:
|
||||
"""
|
||||
Load and prepare TinyStories dataset.
|
||||
|
||||
Args:
|
||||
accelerator (Accelerator): Accelerate accelerator instance
|
||||
tokenizer (AutoTokenizer): Hugging Face tokenizer
|
||||
seq_len (int): Sequence length for the dataset
|
||||
processing_batch_size (int): Batch size for processing the dataset
|
||||
|
||||
Returns:
|
||||
Dataset: Packed dataset
|
||||
"""
|
||||
raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:50%]")
|
||||
|
||||
def tokenize_function(examples):
|
||||
tokenized_batch = tokenizer(
|
||||
examples["text"],
|
||||
padding=False,
|
||||
truncation=True,
|
||||
max_length=seq_len,
|
||||
return_tensors=None,
|
||||
)
|
||||
tokenized_batch["labels"] = tokenized_batch["input_ids"].copy()
|
||||
return tokenized_batch
|
||||
|
||||
with accelerator.main_process_first():
|
||||
tokenized_dataset = raw_dataset.map(
|
||||
tokenize_function, batched=True, remove_columns=["text"], batch_size=processing_batch_size
|
||||
)
|
||||
|
||||
def create_packed_sequences(examples):
|
||||
all_tokens = []
|
||||
for input_ids in examples["input_ids"]:
|
||||
all_tokens.extend(input_ids)
|
||||
|
||||
num_sequences = len(all_tokens) // (seq_len + 1)
|
||||
packed_input_ids = []
|
||||
packed_labels = []
|
||||
|
||||
for i in range(num_sequences):
|
||||
start_idx = i * (seq_len + 1)
|
||||
end_idx = start_idx + (seq_len + 1)
|
||||
full_sequence = all_tokens[start_idx:end_idx]
|
||||
packed_input_ids.append(full_sequence[:-1])
|
||||
packed_labels.append(full_sequence[1:])
|
||||
|
||||
return {"input_ids": packed_input_ids, "labels": packed_labels}
|
||||
|
||||
with accelerator.main_process_first():
|
||||
packed_dataset = tokenized_dataset.map(
|
||||
create_packed_sequences,
|
||||
batched=True,
|
||||
remove_columns=tokenized_dataset.column_names,
|
||||
batch_size=processing_batch_size,
|
||||
)
|
||||
|
||||
return packed_dataset.shuffle(seed=42)
|
||||
|
||||
|
||||
def get_model_flops_per_token(model: AutoModelForCausalLM, seq_len: int) -> float:
|
||||
"""
|
||||
Get the number of flops per token for the model.
|
||||
|
||||
Args:
|
||||
model (AutoModelForCausalLM): Model to get the flops for
|
||||
seq_len (int): Sequence length
|
||||
"""
|
||||
cfg = model.config
|
||||
head_dim = cfg.hidden_size // cfg.num_attention_heads
|
||||
|
||||
# MLP: 3 matmuls
|
||||
mlp_flops = 18 * cfg.hidden_size * cfg.intermediate_size
|
||||
|
||||
# Attn (w/o dotproduct)
|
||||
attn_flops = 12 * head_dim * (cfg.num_attention_heads + cfg.num_key_value_heads)
|
||||
|
||||
# attn (dotproduct) - this scales quadratically with sequence length
|
||||
attn_dotproduct_flops = 12 * cfg.num_attention_heads * head_dim * seq_len
|
||||
|
||||
# we also ignore embeddings and layernorms, etc
|
||||
return (mlp_flops + attn_flops + attn_dotproduct_flops) * cfg.num_hidden_layers
|
||||
|
||||
|
||||
def create_collate_fn():
|
||||
"""Create a collate function for batching."""
|
||||
|
||||
def collate_fn(batch):
|
||||
input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
|
||||
labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
|
||||
return {"input_ids": input_ids, "labels": labels}
|
||||
|
||||
return collate_fn
|
||||
|
||||
|
||||
class PerformanceTracker:
|
||||
"""Track training performance metrics."""
|
||||
|
||||
def __init__(self, warmup_steps: int = 10):
|
||||
self.warmup_steps = warmup_steps
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""Reset all tracking variables."""
|
||||
self.start_time = None
|
||||
self.num_tokens = 0
|
||||
self.is_in_warmup = True
|
||||
self.step_count = 0
|
||||
|
||||
def step(self, batch_tokens: int) -> dict:
|
||||
"""
|
||||
Update performance tracking with a new step.
|
||||
|
||||
Args:
|
||||
batch_tokens (int): Number of tokens in current batch
|
||||
|
||||
Returns:
|
||||
dict: Performance metrics if past warmup, empty dict otherwise
|
||||
"""
|
||||
self.step_count += 1
|
||||
|
||||
if self.step_count == self.warmup_steps:
|
||||
self.start_time = time.perf_counter()
|
||||
self.num_tokens = 0
|
||||
self.is_in_warmup = False
|
||||
return {"warmup_completed": True}
|
||||
|
||||
if not self.is_in_warmup and self.start_time is not None:
|
||||
self.num_tokens += batch_tokens
|
||||
total_time = time.perf_counter() - self.start_time
|
||||
steps_from_warmup = self.step_count - self.warmup_steps
|
||||
|
||||
if total_time > 0 and steps_from_warmup > 0:
|
||||
return {
|
||||
"tokens_per_second": self.num_tokens / total_time,
|
||||
"steps_per_second": steps_from_warmup / total_time,
|
||||
"total_tokens": self.num_tokens,
|
||||
"total_time": total_time,
|
||||
}
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def setup_tokenizer(model_id: str) -> AutoTokenizer:
|
||||
"""Setup tokenizer with proper padding token."""
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
return tokenizer
|
||||
@ -521,6 +521,23 @@ class Accelerator:
|
||||
gradient_accumulation_plugin=gradient_accumulation_plugin,
|
||||
)
|
||||
|
||||
if self.is_fsdp2:
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
|
||||
context_parallel_size = self.state.fsdp_plugin.cp_size
|
||||
|
||||
world_size = self.state.num_processes
|
||||
|
||||
fsdp_size = world_size // context_parallel_size
|
||||
|
||||
device_mesh = init_device_mesh(
|
||||
device_type=self.device.type,
|
||||
mesh_shape=(fsdp_size, context_parallel_size),
|
||||
mesh_dim_names=("fsdp", "cp"),
|
||||
)
|
||||
device_mesh["fsdp", "cp"]._flatten("fsdp_cp")
|
||||
self.state.torch_device_mesh = device_mesh
|
||||
|
||||
self.device_placement = device_placement
|
||||
if dataloader_config is None:
|
||||
dataloader_config = DataLoaderConfiguration()
|
||||
@ -1260,6 +1277,66 @@ class Accelerator:
|
||||
with contextlib.nullcontext(joinables):
|
||||
yield
|
||||
|
||||
@contextmanager
|
||||
def context_parallel(
|
||||
self,
|
||||
buffers: list[torch.Tensor] | None = None,
|
||||
buffer_seq_dims: list[int] | None = None,
|
||||
no_restore_buffers: set[torch.Tensor] | None = None,
|
||||
):
|
||||
"""
|
||||
A context manager that enables context parallel training.
|
||||
|
||||
Args:
|
||||
buffers (`list[torch.Tensor]`, `optional`):
|
||||
Buffers, which are going to be sharded along the sequence dimension. Common examples are inputs, labels
|
||||
or positional embedding buffers. This context manager will modify these buffers in-place, and after
|
||||
exiting the context, the buffers will be restored to their original state. To avoid unnecessary
|
||||
restores, you can use `no_restore_buffers` to specify which buffers don't need to be restored.
|
||||
buffer_seq_dims (`list[int]`, `optional`):
|
||||
Sequence dimensions of `buffers`.
|
||||
no_restore_buffers (`set[torch.Tensor]`, `optional`):
|
||||
This set must be a subset of `buffers`. Specifies which buffers from `buffers` argument won't be
|
||||
restored after the context exits. These buffers will be then kept in sharded state.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
`context_parallel` is currently only supported together with FSDP2, and requires `cp_size` to be set. If either
of these conditions is not met, this context manager will have no effect.
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This context manager has to be recreated with each training step, as shown in the example below.
|
||||
|
||||
</Tip>
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> for batch in dataloader:
|
||||
... with accelerator.context_parallel(
|
||||
... buffers=[batch["input_ids"], batch["attention_mask"]],
|
||||
... buffer_seq_dims=[1, 1],
|
||||
... no_restore_buffers={batch["input_ids"]},
|
||||
... ):
|
||||
... outputs = model(batch)
|
||||
... ...
|
||||
```
|
||||
"""
|
||||
|
||||
if (
|
||||
getattr(self.state, "fsdp_plugin", None) is None
|
||||
or self.state.fsdp_plugin.cp_size == 1
|
||||
or (cp_context := getattr(self, "_cp_context", None)) is None
|
||||
):
|
||||
logger.warning("Context parallel + FSDP2 is not configured, this context manager will have no effect.")
|
||||
yield
|
||||
else:
|
||||
with cp_context(buffers=buffers, buffer_seq_dims=buffer_seq_dims, no_restore_buffers=no_restore_buffers):
|
||||
yield
|
||||
|
||||
def print(self, *args, **kwargs):
|
||||
"""
|
||||
Drop in replacement of `print()` to only print once per server.
|
||||
@ -1477,6 +1554,20 @@ class Accelerator:
|
||||
# Needs to be done first, to make sure AC + fully_shard will work as expected
|
||||
self.state.fsdp_plugin.set_auto_wrap_policy(model)
|
||||
|
||||
if (context_parallel_size := self.state.fsdp_plugin.cp_size) > 1:
|
||||
if context_parallel_size > self.state.num_processes:
|
||||
raise ValueError(
|
||||
f"`cp_size` set to {context_parallel_size}, which is greater than the number of processes {self.state.num_processes}. Please set to 1 to disable context parallel or use a smaller value."
|
||||
)
|
||||
|
||||
from torch.distributed.tensor.experimental import context_parallel
|
||||
from torch.distributed.tensor.experimental._attention import set_rotate_method
|
||||
|
||||
cp_comm_strategy = self.state.fsdp_plugin.cp_comm_strategy
|
||||
set_rotate_method(cp_comm_strategy)
|
||||
|
||||
self._cp_context = functools.partial(context_parallel, mesh=self.state.torch_device_mesh["cp"])
|
||||
|
||||
# Apply AC if needed
|
||||
if self.state.fsdp_plugin.activation_checkpointing:
|
||||
model = fsdp2_apply_ac(self, model)
|
||||
@ -2330,6 +2421,8 @@ class Accelerator:
|
||||
return self.state.torch_tp_plugin.torch_device_mesh
|
||||
elif self.distributed_type == DistributedType.DEEPSPEED and hasattr(self.state, "ds_device_mesh"):
|
||||
return self.state.ds_device_mesh
|
||||
elif self.is_fsdp2 and hasattr(self.state, "torch_device_mesh"):
|
||||
return self.state.torch_device_mesh
|
||||
return None
|
||||
|
||||
def _prepare_msamp(self, *args, device_placement):
|
||||
|
||||
@ -505,6 +505,22 @@ def get_cluster_input():
|
||||
error_message="Please enter yes or no.",
|
||||
)
|
||||
|
||||
if fsdp_version == 2:
|
||||
fsdp_config["fsdp_cp_size"] = _ask_field(
|
||||
"What should be your FSDP's context parallel size? (Input 1 or leave blank for no context parallel) [1]: ",
|
||||
int,
|
||||
default=1,
|
||||
error_message="Please enter an integer.",
|
||||
)
|
||||
|
||||
if fsdp_version == 2 and fsdp_config.get("fsdp_cp_size", 1) != 1:
|
||||
fsdp_config["fsdp_cp_comm_strategy"] = _ask_options(
|
||||
"What should be your FSDP's context parallel communication strategy? [allgather]: ",
|
||||
["allgather", "alltoall"],
|
||||
lambda x: ["allgather", "alltoall"][int(x)],
|
||||
default=0,
|
||||
)
|
||||
|
||||
megatron_lm_config = {}
|
||||
if distributed_type in [DistributedType.MULTI_GPU]:
|
||||
use_megatron_lm = _ask_field(
|
||||
|
||||
@ -610,6 +610,18 @@ def launch_command_parser(subparsers=None):
|
||||
type=str,
|
||||
help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).",
|
||||
)
|
||||
fsdp_args.add_argument(
|
||||
"--fsdp_cp_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="FSDP's context parallel size. (useful only when `use_fsdp` flag is passed and `fsdp_version` is 2). Defaults to 1 (CP not applied).",
|
||||
)
|
||||
fsdp_args.add_argument(
|
||||
"--fsdp_cp_comm_strategy",
|
||||
type=str,
|
||||
default="allgather",
|
||||
help="FSDP's context parallel communication strategy. (useful only when `use_fsdp` flag is passed and `fsdp_version` is 2). Defaults to `allgather`.",
|
||||
)
|
||||
|
||||
# megatron_lm args
|
||||
megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.")
|
||||
|
||||
@ -1117,25 +1117,28 @@ def prepare_data_loader(
|
||||
process_index = process_index // submesh_tp_size
|
||||
num_processes = num_processes // submesh_tp_size
|
||||
else:
|
||||
# when device mesh is used, specifically with TP
|
||||
# when device mesh is used, specifically with TP or CP
|
||||
# then there is need to update process_index and num_processes
|
||||
# to bring in the effect of generating same batch across TP ranks
|
||||
# to bring in the effect of generating same batch across TP/CP ranks
|
||||
# and different batch across FSDP and DP ranks.
|
||||
# Example:
|
||||
# if device mesh is (dp,fsdp,tp) = (2, 2, 3)
|
||||
# ranks would range from 0...11
|
||||
# from data angle ranks should look like 0 0 0 1 1 1 2 2 2 3 3 3
|
||||
# if device mesh is (dp,fsdp,tp,cp) = (2, 2, 2, 3)
|
||||
# ranks would range from 0...23
|
||||
# from data angle ranks should look like 0 0 0 0 0 0 1 1 1 1 1 1 ...
|
||||
# processes with same ranks/ids would receive the same batch
|
||||
submesh_fsdp_size = 1
|
||||
submesh_dp_size = 1
|
||||
submesh_tp_size = 1
|
||||
submesh_cp_size = 1
|
||||
if "tp" in torch_device_mesh.mesh_dim_names:
|
||||
submesh_tp_size = torch_device_mesh["tp"].size()
|
||||
if "dp" in torch_device_mesh.mesh_dim_names:
|
||||
submesh_dp_size = torch_device_mesh["dp"].size()
|
||||
if "fsdp" in torch_device_mesh.mesh_dim_names:
|
||||
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
|
||||
process_index = process_index // submesh_tp_size
|
||||
if "cp" in torch_device_mesh.mesh_dim_names:
|
||||
submesh_cp_size = torch_device_mesh["cp"].size()
|
||||
process_index = process_index // (submesh_tp_size * submesh_cp_size)
|
||||
num_processes = submesh_fsdp_size * submesh_dp_size
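
To make the rank-mapping comment above concrete, here is a tiny standalone sketch (illustrative only, not part of the diff) of how a global rank collapses to a data-loading rank when TP/CP peers must receive the same batch:

```python
def data_rank(global_rank: int, tp_size: int, cp_size: int) -> int:
    # Ranks that differ only in their TP/CP coordinates map to the same data rank,
    # mirroring `process_index // (submesh_tp_size * submesh_cp_size)` above.
    return global_rank // (tp_size * cp_size)


# mesh (dp, fsdp, tp, cp) = (2, 2, 2, 3) -> 24 ranks, data ranks 0 0 0 0 0 0 1 1 1 1 1 1 ...
print([data_rank(r, tp_size=2, cp_size=3) for r in range(24)])
```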
|
||||
|
||||
# Sanity check
|
||||
|
||||
@@ -0,0 +1,46 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch.utils.data import DataLoader

from accelerate import Accelerator


def main():
    accelerator = Accelerator()
    B, S, D = 2, 3, 4
    rank_data = torch.ones((B, S, D), device="cuda") * (accelerator.process_index + 1)
    all_rank_data = [torch.empty_like(rank_data) for _ in range(accelerator.num_processes)]
    torch.distributed.all_gather(all_rank_data, rank_data)

    dataloader = DataLoader(all_rank_data, batch_size=B, shuffle=False)
    dataloader = accelerator.prepare(dataloader)
    for batch in dataloader:
        all_rank_batch = [torch.empty_like(batch) for _ in range(accelerator.num_processes)]
        torch.distributed.all_gather(all_rank_batch, batch)

        if accelerator.is_main_process:
            for rank_idx in range(accelerator.num_processes):
                torch.testing.assert_close(
                    all_rank_batch[0],
                    all_rank_batch[rank_idx],
                    msg=f"Rank {rank_idx} batch {all_rank_batch[rank_idx]} differs from rank 0 batch {all_rank_batch[0]}",
                )

    accelerator.end_training()


if __name__ == "__main__":
    main()
||||
@ -44,6 +44,7 @@ FSDP_PYTORCH_VERSION = (
|
||||
"2.1.0.a0+32f93b1" # Technically should be 2.1.0, but MS-AMP uses this specific prerelease in their Docker image.
|
||||
)
|
||||
FSDP2_PYTORCH_VERSION = "2.6.0"
|
||||
CONTEXT_PARALLEL_PYTORCH_VERSION = "2.7.0"
|
||||
FSDP_MODEL_NAME = "pytorch_model_fsdp"
|
||||
DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich", "mpich", "nossh", "slurm"]
|
||||
TORCH_DYNAMO_MODES = ["default", "reduce-overhead", "max-autotune"]
|
||||
|
||||
@ -33,6 +33,7 @@ import torch
|
||||
|
||||
from .constants import (
|
||||
BETA_TP_AVAILABLE_PYTORCH_VERSION,
|
||||
CONTEXT_PARALLEL_PYTORCH_VERSION,
|
||||
FSDP2_PYTORCH_VERSION,
|
||||
FSDP_AUTO_WRAP_POLICY,
|
||||
FSDP_BACKWARD_PREFETCH,
|
||||
@ -1548,6 +1549,11 @@ class FullyShardedDataParallelPlugin:
|
||||
min_num_params (`Optional[int]`, defaults to `None`):
|
||||
The minimum number of parameters a module must have to be wrapped. Only applicable when `auto_wrap_policy`
|
||||
is `size_based_wrap`.
|
||||
cp_size (`int`, defaults to `1`):
|
||||
The size of the context parallel group. Only applicable when `fsdp_version` is set to 2; otherwise, an error will be
raised. Defaults to 1 (CP not applied).
|
||||
cp_comm_strategy (`str`, defaults to `allgather`):
|
||||
The shard rotation strategy to use, only used when `cp_size` > 1 and `fsdp_version` is set to 2.
|
||||
"""
|
||||
|
||||
fsdp_version: int = field(
|
||||
@ -1693,6 +1699,18 @@ class FullyShardedDataParallelPlugin:
|
||||
"help": "The minimum number of parameters a module must have to be wrapped. Only applicable when `auto_wrap_policy` is `size_based_wrap`."
|
||||
},
|
||||
)
|
||||
cp_size: int = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The size of the context parallel group. Only applicable when `fsdp_version` is set to 2, else error will be raised. Defaults to 1 (CP not applied)"
|
||||
},
|
||||
)
|
||||
cp_comm_strategy: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The shard rotation strategy to use, only used when `cp_size` > 1 and `fsdp_version` is set to 2. Defaults to `allgather`."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
from torch.distributed.fsdp import (
|
||||
@ -1855,6 +1873,28 @@ class FullyShardedDataParallelPlugin:
|
||||
)
|
||||
os.environ[env_var] = str(self.cpu_ram_efficient_loading)
|
||||
|
||||
if self.cp_size is None:
|
||||
self.cp_size = int(os.environ.get(env_prefix + "CP_SIZE", "1"))
|
||||
|
||||
if self.cp_size > 1 and self.fsdp_version != 2:
|
||||
raise ValueError(
|
||||
f"cp_size set to {self.cp_size}. This is not supported with FSDP1, please set to 1 or use `fsdp_version=2`"
|
||||
)
|
||||
|
||||
if self.cp_size > 1 and not is_torch_version(">=", CONTEXT_PARALLEL_PYTORCH_VERSION):
|
||||
raise ValueError(
|
||||
f"cp_size set to {self.cp_size}. This is not supported with PyTorch < {CONTEXT_PARALLEL_PYTORCH_VERSION}, please set to None or upgrade your PyTorch version."
|
||||
)
|
||||
|
||||
if self.cp_comm_strategy is None:
|
||||
self.cp_comm_strategy = os.environ.get(env_prefix + "CP_COMM_STRATEGY", "allgather")
|
||||
|
||||
# No need to further check versions, as that check is done in the `context_parallel_size` check
|
||||
if self.cp_comm_strategy not in ["allgather", "alltoall"]:
|
||||
raise ValueError(
|
||||
f"cp_comm_strategy set to {self.cp_comm_strategy}. Must be one of ['allgather', 'alltoall']."
|
||||
)
|
||||
|
||||
if isinstance(self.mixed_precision_policy, dict):
|
||||
self.set_mixed_precision(self.mixed_precision_policy)
|
||||
if self.mixed_precision_policy is not None:
|
||||
|
||||
@ -616,11 +616,14 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
||||
|
||||
original_sd = model.state_dict()
|
||||
|
||||
mesh = getattr(accelerator.state, "torch_device_mesh", None)
|
||||
|
||||
fsdp2_kwargs = {
|
||||
"reshard_after_forward": fsdp2_plugin.reshard_after_forward,
|
||||
"offload_policy": fsdp2_plugin.cpu_offload,
|
||||
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
|
||||
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
|
||||
"mesh": mesh["fsdp_cp"] if mesh else None,
|
||||
}
|
||||
|
||||
model_has_params4bit = False
|
||||
|
||||
@ -328,6 +328,8 @@ def prepare_multi_gpu_env(args: argparse.Namespace) -> dict[str, str]:
|
||||
current_env["FSDP_CPU_RAM_EFFICIENT_LOADING"] = str(args.fsdp_cpu_ram_efficient_loading).lower()
|
||||
current_env["FSDP_SYNC_MODULE_STATES"] = str(args.fsdp_sync_module_states).lower()
|
||||
current_env["FSDP_ACTIVATION_CHECKPOINTING"] = str(args.fsdp_activation_checkpointing).lower()
|
||||
current_env["FSDP_CP_SIZE"] = str(args.fsdp_cp_size)
|
||||
current_env["FSDP_CP_COMM_STRATEGY"] = str(args.fsdp_cp_comm_strategy)
|
||||
|
||||
if args.use_megatron_lm:
|
||||
prefix = "MEGATRON_LM_"
|
||||
|
||||
@ -33,11 +33,13 @@ from accelerate.test_utils.testing import (
|
||||
require_multi_device,
|
||||
require_non_cpu,
|
||||
require_non_torch_xla,
|
||||
require_torch_min_version,
|
||||
run_first,
|
||||
slow,
|
||||
)
|
||||
from accelerate.utils import is_bf16_available, is_fp16_available, is_hpu_available, patch_environment, set_seed
|
||||
from accelerate.utils.constants import (
|
||||
CONTEXT_PARALLEL_PYTORCH_VERSION,
|
||||
FSDP2_STATE_DICT_TYPE,
|
||||
FSDP_AUTO_WRAP_POLICY,
|
||||
FSDP_BACKWARD_PREFETCH,
|
||||
@ -84,7 +86,6 @@ class FSDPPluginIntegration(AccelerateTestCase):
|
||||
1: self.fsdp1_env,
|
||||
2: self.fsdp2_env,
|
||||
}
|
||||
|
||||
self.current_fsdp_version = 1
|
||||
|
||||
def test_sharding_strategy(self):
|
||||
@ -398,6 +399,24 @@ class FSDPPluginIntegration(AccelerateTestCase):
|
||||
assert fsdp_plugin.cpu_ram_efficient_loading is False
|
||||
assert os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING") == "False"
|
||||
|
||||
def test_cp(self):
|
||||
if (fsdp_version := self.current_fsdp_version) != 2:
|
||||
return
|
||||
|
||||
env = self.fsdp_envs[fsdp_version].copy()
|
||||
for cp_comm_strategy in ["allgather", "alltoall"]:
|
||||
env["FSDP_CP_COMM_STRATEGY"] = cp_comm_strategy
|
||||
env["FSDP_CP_SIZE"] = "2"
|
||||
with patch_environment(**env):
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin()
|
||||
assert fsdp_plugin.cp_comm_strategy == cp_comm_strategy
|
||||
|
||||
env = self.fsdp_envs[fsdp_version].copy()
|
||||
env["FSDP_CP_SIZE"] = "2"
|
||||
with patch_environment(**env):
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin(cp_comm_strategy=cp_comm_strategy)
|
||||
assert fsdp_plugin.cp_comm_strategy == cp_comm_strategy
|
||||
|
||||
|
||||
@require_fsdp2
|
||||
@require_non_cpu
|
||||
@ -448,7 +467,6 @@ class FSDPIntegrationTest(TempDirTestCase):
|
||||
}
|
||||
self.n_train = 160
|
||||
self.n_val = 160
|
||||
|
||||
self.current_fsdp_version = 1
|
||||
|
||||
@require_fp16
|
||||
@ -604,6 +622,23 @@ class FSDPIntegrationTest(TempDirTestCase):
|
||||
with patch_environment(omp_num_threads=1):
|
||||
execute_subprocess_async(cmd_config)
|
||||
|
||||
# TODO: Should probably be moved to a separate test file
|
||||
@require_torch_min_version(version=CONTEXT_PARALLEL_PYTORCH_VERSION)
|
||||
def test_dist_dataloader(self):
|
||||
if (fsdp_version := self.current_fsdp_version) != 2:
|
||||
return
|
||||
|
||||
self.test_file_path = self.test_scripts_folder / "test_distributed_dataloader.py"
|
||||
cmd = get_launch_command(num_processes=2, num_machines=1, machine_rank=0, fsdp_version=fsdp_version)
|
||||
|
||||
cmd_config = cmd.copy()
|
||||
cmd_config.extend(["--use_fsdp", "--fsdp_cp_size=2"])
|
||||
|
||||
cmd_config.append(self.test_file_path)
|
||||
|
||||
with patch_environment(omp_num_threads=1):
|
||||
execute_subprocess_async(cmd_config)
|
||||
|
||||
|
||||
@require_fsdp2
|
||||
@run_first
|
||||
|
||||