Parallelism config + TP + HSDP + BYODM (Bring Your Own Device Mesh) (#3682)

* Feat: init

* Feat: add validation + init from kwargs

* Fix: minor fixes

* Feat: more cleanup

* Minor refactor

* remove import

* adding support for pre-configured device mesh

* adding device mesh to fsdp2

* moving mesh dim defn to parralismconfig

* tests

* WIP device mesh/accelerator validation

* WIP more tests

* Test Driven Development (TDD)

* fixing build_device_mesh

* FSDP dim names

* adding example

* WIP

* fixing HSDP

* Feat: add back old options

* working example

* debugging

* adding parallelism config to partialstate

* Feat: revert ddp changes

* Revert DDP

* Feat: (untested) update mesh dims and some minor tweaks

* adding dp_cp dims

* updating comments

* WIP

* wip 2

* reverting

* storing state in accelerator rather than acceleratorstate

* Fix: minor tweaks

* wip example update

* Fixes for non-fsdp2 case

* Feat: ensure ddp/tp only works

* updating example

* updating example

* updating examples, fixing state

* fixed state

* comments

* fixing partial state check

* linting

* comments

* removing fn

* WIP: fix tp

* comments

* removing return

* reverting upcast

* add guards

* guards for empty self.parallelism_config

* use len on tuple to check if empty

* Feat: cleanup example

* Feat: some cleanup of example

* Feat: add trackio

* Fix: improve trackio

* Feat: TP works

* Feat: some fsdp2 improv

* Feat: working examples

* handle clipping for tensor parallel

* Implicit replicate

* Refactor: move to separate file + cleanup + basic comments

* Fix: add unadded files, fix circular import

* Feat: better readme

* Feat: add blog + ultrascale links

* Tmp: should_save_model now returns only true

* Fix: remove implicit_replication and style

* Fix: remove optional

* add guard on parallelism_config.tp_enabled

* fix import

* fixing empty parallelism_config

* fix import path for test patch

* fixing patch

---------

Co-authored-by: S1ro1 <matej.sirovatka@gmail.com>
Co-authored-by: Salman Mohammadi <“salman.mohammadi@outlook.com”>
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
salman
2025-07-30 20:03:13 +01:00
committed by GitHub
parent 2f075c724c
commit 9359a0194f
15 changed files with 1237 additions and 122 deletions

View File

@ -2,6 +2,31 @@
This folder contains examples of using FSDP2 with Accelerate, utilizing extra methods to improve training speed, performance or accuracy.
### FSDP2 + ND Parallelism
With `ParallelismConfig`, you can use 🤗 accelerate to train models with n-dimensional parallelism. This builds on top of 🤗 transformers, which we utilize for tensor parallelism sharding.
Accelerate then takes care of everything else, such as data parallelism, FSDP, and more to come.
Script `nd_parallel.py` showcases just how you can do it. We enable you to configure 3 different parallel dimensions (for now 👀):
- dp_replicate_size: how many replicas of the model to create, each replica is trained on a different subset of the data and averaged at the end of each step, same as DDP in Torch
- dp_shard_size: across how many devices is the model sharded, this is utilizing FSDP2 to shard the model across devices, so each device has a different part of the model
- tp_size: how many devices to use for tensor parallelism, this is utilizing the tensor parallelism from 🤗 transformers
For example, with 8 nodes, you can run the script as such:
```bash
accelerate launch --num-processes 8 nd_parallel.py \
--dp-replicate-size 2 \
--dp-shard-size 2 \
--tp-size 2 \
```
<Tip>
Only use TP intra-node - therefore max TP size you should need is 8, you can also lower this as FSDP (`--dp-shard-size`) can be faster on smaller models with
shorter sequence lengths. If you still cannot fit into memory, utilize `--dp-shard-size` as much as you can. Then to scale up to utilize all your GPUs, fill the rest
with `--dp-replicate-size`. This is only a general guideline, you can (and should) experiment with different parallelism configurations to find the best one for your model and hardware. You can learn more about the general strategies for parallelism in our [blog](TODO) or if you wanna dive deep, read the [Ultra-Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook).
</Tip>
We plan to add more parallelisms in the future, with context parallelism coming soon and pipeline parallelism being planned.
### FSDP2 + ao Float8Linear
In file `fsdp2_fp8.py` we use `Float8Linear` from `ao` to train a model partially in FP8 precision. We utilize `AORecipeKwargs` to pass the `Float8LinearConfig` to the accelerator,
@ -34,3 +59,4 @@ The figures above were generated on 8x H100 SXM GPUs, with 8192 sequence length
```bash
accelerate launch fsdp2_fp8.py --sequence-length 8192 --num-steps 1000 --log_with wandb --precision [fp8 | bf16]
```

View File

@ -0,0 +1,155 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example of training with ND parallel using accelerate's ParallelismConfig
"""
import argparse
import warnings
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM
from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import FullyShardedDataParallelPlugin, set_seed
from utils import (
PerformanceTracker,
create_collate_fn,
get_dataset,
setup_tokenizer,
)
MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--dp-replicate-size", type=int, default=1)
parser.add_argument("--dp-shard-size", type=int, default=1)
parser.add_argument("--tp-size", type=int, default=1)
parser.add_argument("--sequence-length", type=int, default=1024)
parser.add_argument("--num-steps", type=int, default=1000)
parser.add_argument("--save-dir", type=str, default="./outputs")
parser.add_argument("--checkpoint-frequency", type=int, default=100)
parser.add_argument("--model-name", type=str, default=MODEL_ID)
return parser.parse_args()
def forward(model, batch, optimizer, accelerator):
# To get the proper loss value, we need to average across devices that are participating in data parallel/context parallel training
loss_reduce_grp = (
accelerator.torch_device_mesh["dp_cp"].get_group() if accelerator.parallelism_config.dp_cp_dim_names else None
)
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=loss_reduce_grp)
return loss
def train(args):
parallelism_config = ParallelismConfig(
dp_replicate_size=args.dp_replicate_size,
dp_shard_size=args.dp_shard_size,
tp_size=args.tp_size,
)
# FSDP needs extra configuration, so we properly shard the model
if parallelism_config.dp_shard_enabled:
fsdp2_plugin = FullyShardedDataParallelPlugin(
fsdp_version=2,
auto_wrap_policy="transformer_based_wrap",
transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
)
accelerator = Accelerator(
log_with=["wandb"],
mixed_precision="bf16",
parallelism_config=parallelism_config,
fsdp_plugin=fsdp2_plugin if parallelism_config.dp_shard_enabled else None,
)
accelerator.init_trackers("nd_parallel_training")
# If TP was enabled, we need to tell transformers to prepare the model for us
model_kwargs = (
{"tp_size": args.tp_size, "tp_plan": "auto", "device_mesh": accelerator.torch_device_mesh}
if args.tp_size > 1
else {}
)
model = AutoModelForCausalLM.from_pretrained(
args.model_name,
torch_dtype=torch.bfloat16,
use_cache=False,
**model_kwargs,
)
tokenizer = setup_tokenizer(args.model_name)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)
dataset = get_dataset(accelerator, tokenizer, args.sequence_length)
dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
total_num_steps = min(args.num_steps, len(dataloader))
performance_tracker = PerformanceTracker(warmup_steps=5)
accelerator.print("Starting training...")
for step, batch in enumerate(dataloader):
if step >= total_num_steps:
break
loss = forward(model, batch, optimizer, accelerator)
# We report TPS per device, so we divide by the number of devices in the non-data parallel dimension
metrics = performance_tracker.step(batch["input_ids"].shape[1] / parallelism_config.non_data_parallel_size)
print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
if "warmup_completed" in metrics:
accelerator.print("Warm up completed! Starting performance tracking...")
elif metrics:
print_msg += performance_tracker.get_print_message(metrics, with_memory=True)
if step % 10 == 0 or step == total_num_steps - 1:
accelerator.print(print_msg)
if step % args.checkpoint_frequency == 0 and step > 0 and parallelism_config.dp_shard_enabled:
accelerator.print(f"Saving checkpoint at step {step}...")
accelerator.save_state(args.save_dir + f"/checkpoint-{step}")
accelerator.log({"loss": loss.item()})
accelerator.print("Training completed!")
model.save_pretrained(args.save_dir + f"/{args.model_name}")
accelerator.print(f"Model saved to {args.save_dir}/{args.model_name}")
accelerator.end_training()
if __name__ == "__main__":
set_seed(42)
args = parse_args()
if args.dp_shard_size == 1:
# We currently don't support saving with `save_state` when using only
# tensor parallelism, fsdp must be enabled
warnings.warn(
"Accelerator.save_state() is not yet supported with pure tensor parallel training. Training will work, but intermediate checkpoints will not be saved."
)
train(args)

View File

@ -0,0 +1,187 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example of training with ND parallel using accelerate's ParallelismConfig
"""
import argparse
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM
from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig
from accelerate.state import PartialState
from accelerate.utils import FullyShardedDataParallelPlugin, set_seed
from accelerate.utils.fsdp_utils import save_fsdp_optimizer
from utils import PerformanceTracker, create_collate_fn, get_dataset, gpu_memory_usage_all, setup_tokenizer
MODEL_ID = "NousResearch/Llama-3.2-1B"
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str)
parser.add_argument("--fsdp2-cls-name-to-wrap", type=str, default="LlamaDecoderLayer")
parser.add_argument("--dp-replicate-size", type=int, default=1)
parser.add_argument("--dp-shard-size", type=int, default=1)
parser.add_argument("--tp-size", type=int, default=1)
parser.add_argument("--sequence-length", type=int, default=128)
parser.add_argument("--model-save-dir", type=str, default="./outputs")
parser.add_argument(
"--save-model", action="store_true", default=False, help="Whether to save the model after training."
)
parser.add_argument(
"--save-optimizer",
action="store_true",
default=False,
help="Whether to save the optimizer state after training.",
)
return parser.parse_args()
def main():
"""
Main function to train the model.
"""
args = parse_args()
set_seed(42)
if args.model:
model_id = args.model
else:
model_id = MODEL_ID
model_kwargs = {}
accelerator_kwargs = {}
parallelism_config = ParallelismConfig(
dp_replicate_size=args.dp_replicate_size,
dp_shard_size=args.dp_shard_size,
tp_size=args.tp_size,
)
device_mesh = parallelism_config.build_device_mesh("cuda")
if args.tp_size > 1:
model_kwargs["tp_size"] = args.tp_size
model_kwargs["tp_plan"] = "auto"
model_kwargs["device_mesh"] = device_mesh
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
use_cache=False,
**model_kwargs,
)
PartialState(device_mesh=device_mesh, parallelism_config=parallelism_config)
if parallelism_config.dp_shard_enabled:
fsdp2_plugin = FullyShardedDataParallelPlugin(
fsdp_version=2,
cpu_ram_efficient_loading=False,
auto_wrap_policy="transformer_based_wrap",
transformer_cls_names_to_wrap=[args.fsdp2_cls_name_to_wrap],
reshard_after_forward=True,
activation_checkpointing=True,
state_dict_type="FULL_STATE_DICT",
)
accelerator_kwargs["fsdp_plugin"] = fsdp2_plugin
accelerator = Accelerator(
mixed_precision="no",
**accelerator_kwargs,
)
accelerator.print("Memory usage after model load")
accelerator.print(gpu_memory_usage_all())
accelerator.print("=" * 20)
tokenizer = setup_tokenizer(model_id)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)
model, optimizer = accelerator.prepare(model, optimizer)
accelerator.print("Memory usage after model prepare")
accelerator.print(gpu_memory_usage_all())
accelerator.print("=" * 20)
dataset = get_dataset(accelerator, tokenizer, args.sequence_length)
dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())
dataloader = accelerator.prepare(dataloader)
model.train()
total_num_steps = min(100, len(dataloader))
performance_tracker = PerformanceTracker(warmup_steps=10)
accelerator.print("Starting training...")
for step, batch in enumerate(dataloader):
if step >= total_num_steps:
break
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
dist.all_reduce(loss, op=dist.ReduceOp.AVG)
batch_tokens = batch["input_ids"].shape[1]
metrics = performance_tracker.step(batch_tokens)
print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
log_metrics = {"loss": loss.item()}
if "warmup_completed" in metrics:
accelerator.print("Warm up completed! Starting performance tracking...")
elif metrics:
print_msg += f" | Average steps/s: {metrics['steps_per_second']:.2f} | Average tokens/s: {metrics['tokens_per_second']:.2f}\n"
print_msg += (
f"\tMemory (GB): active={metrics['peak_memory_active']:.1f}, "
f"alloc={metrics['peak_memory_alloc']:.1f}, "
f"reserved={metrics['peak_memory_reserved']:.1f}"
)
if step % 10 == 0 or step == total_num_steps - 1:
accelerator.print(print_msg)
accelerator.log(log_metrics)
accelerator.wait_for_everyone()
accelerator.end_training()
accelerator.print("Training completed!")
if parallelism_config.dp_shard_enabled and args.save_optimizer:
accelerator.print("Saving optimizer state...")
save_fsdp_optimizer(
fsdp2_plugin,
accelerator,
optimizer,
model,
args.model_save_dir + "/opt",
)
accelerator.print("Optimizer state saved.")
accelerator.print("Saving model state...")
if args.save_model:
model.save_pretrained(args.model_save_dir)
accelerator.print(f"Model saved to {args.model_save_dir}")
if __name__ == "__main__":
main()

201
examples/fsdp2/utils.py Normal file
View File

@ -0,0 +1,201 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Common utilities for FSDP2 examples.
"""
import time
import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator
def get_dataset(accelerator: Accelerator, tokenizer: AutoTokenizer, seq_len: int) -> Dataset:
"""
Load and prepare TinyStories dataset.
Args:
accelerator (Accelerator): Accelerate accelerator instance
tokenizer (AutoTokenizer): Hugging Face tokenizer
seq_len (int): Sequence length for the dataset
Returns:
Dataset: Packed dataset
"""
raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:50%]")
def tokenize_function(examples):
tokenized_batch = tokenizer(
examples["text"],
padding=False,
truncation=True,
max_length=seq_len,
return_tensors=None,
)
tokenized_batch["labels"] = tokenized_batch["input_ids"].copy()
return tokenized_batch
with accelerator.main_process_first():
tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
def create_packed_sequences(examples):
all_tokens = []
for input_ids in examples["input_ids"]:
all_tokens.extend(input_ids)
num_sequences = len(all_tokens) // (seq_len + 1)
packed_input_ids = []
packed_labels = []
for i in range(num_sequences):
start_idx = i * (seq_len + 1)
end_idx = start_idx + (seq_len + 1)
full_sequence = all_tokens[start_idx:end_idx]
packed_input_ids.append(full_sequence[:-1])
packed_labels.append(full_sequence[1:])
return {"input_ids": packed_input_ids, "labels": packed_labels}
with accelerator.main_process_first():
packed_dataset = tokenized_dataset.map(
create_packed_sequences,
batched=True,
remove_columns=tokenized_dataset.column_names,
batch_size=1000,
)
return packed_dataset.shuffle(seed=42)
def get_model_flops_per_token(model: AutoModelForCausalLM, seq_len: int) -> float:
"""
Get the number of flops per token for the model.
Args:
model (AutoModelForCausalLM): Model to get the flops for
seq_len (int): Sequence length
"""
cfg = model.config
head_dim = cfg.hidden_size // cfg.num_attention_heads
# MLP: 3 matmuls
mlp_flops = 18 * cfg.hidden_size * cfg.intermediate_size
# Attn (w/o dotproduct)
attn_flops = 12 * head_dim * (cfg.num_attention_heads + cfg.num_key_value_heads)
# attn (dotproduct) - this scales quadratically with sequence length
attn_dotproduct_flops = 12 * cfg.num_attention_heads * head_dim * seq_len
# we also ignore embeddings and layernorms, etc
return (mlp_flops + attn_flops + attn_dotproduct_flops) * cfg.num_hidden_layers
def create_collate_fn():
"""Create a collate function for batching."""
def collate_fn(batch):
input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
return {"input_ids": input_ids, "labels": labels}
return collate_fn
class PerformanceTracker:
"""Track training performance metrics."""
def __init__(self, warmup_steps: int = 10):
self.warmup_steps = warmup_steps
self.reset()
def reset(self):
"""Reset all tracking variables."""
self.start_time = None
self.num_tokens = 0
self.is_in_warmup = True
self.step_count = 0
def step(self, batch_tokens: int) -> dict:
"""
Update performance tracking with a new step.
Args:
batch_tokens (int): Number of tokens in current batch
Returns:
dict: Performance metrics if past warmup, empty dict otherwise
"""
self.step_count += 1
if self.step_count == self.warmup_steps:
self.start_time = time.perf_counter()
self.num_tokens = 0
self.is_in_warmup = False
return {"warmup_completed": True}
if not self.is_in_warmup and self.start_time is not None:
self.num_tokens += batch_tokens
total_time = time.perf_counter() - self.start_time
steps_from_warmup = self.step_count - self.warmup_steps
if total_time > 0 and steps_from_warmup > 0:
memory_stats = gpu_memory_usage_all()
return {
"tokens_per_second": self.num_tokens / total_time,
"steps_per_second": steps_from_warmup / total_time,
"total_tokens": self.num_tokens,
"total_time": total_time,
**memory_stats,
}
return {}
def get_print_message(self, metrics: dict, with_memory: bool = False) -> str:
print_msg = f" | Average steps/s: {metrics['steps_per_second']:.2f} | Average tokens/s: {metrics['tokens_per_second']:.2f}\n"
if with_memory:
print_msg += (
f"\tMemory (GB): active={metrics['peak_memory_active']:.1f}, "
f"alloc={metrics['peak_memory_alloc']:.1f}, "
f"reserved={metrics['peak_memory_reserved']:.1f}"
)
return print_msg
def setup_tokenizer(model_id: str) -> AutoTokenizer:
"""Setup tokenizer with proper padding token."""
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
def gpu_memory_usage_all(device=0):
device = torch.device(f"cuda:{device}")
_BYTES_IN_GIB = 1024**3
peak_memory_active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
peak_memory_alloc = torch.cuda.max_memory_allocated(device) / _BYTES_IN_GIB
peak_memory_reserved = torch.cuda.max_memory_reserved(device) / _BYTES_IN_GIB
memory_stats = {
"peak_memory_active": peak_memory_active,
"peak_memory_alloc": peak_memory_alloc,
"peak_memory_reserved": peak_memory_reserved,
}
torch.cuda.reset_peak_memory_stats(device)
return memory_stats