Parallelism config + TP + HSDP + BYODM (Bring Your Own Device Mesh) (#3682)

* Feat: init

* Feat: add validation + init from kwargs

* Fix: minor fixes

* Feat: more cleanup

* Minor refactor

* remove import

* adding support for pre-configured device mesh

* adding device mesh to fsdp2

* moving mesh dim definition to ParallelismConfig

* tests

* WIP device mesh/accelerator validation

* WIP more tests

* Test Driven Development (TDD)

* fixing build_device_mesh

* FSDP dim names

* adding example

* WIP

* fixing HSDP

* Feat: add back old options

* working example

* debugging

* adding parallelism config to partialstate

* Feat: revert ddp changes

* Revert DDP

* Feat: (untested) update mesh dims and some minor tweaks

* adding dp_cp dims

* updating comments

* WIP

* wip 2

* reverting

* storing state in accelerator rather than acceleratorstate

* Fix: minor tweaks

* wip example update

* Fixes for non-fsdp2 case

* Feat: ensure ddp/tp only works

* updating example

* updating example

* updating examples, fixing state

* fixed state

* comments

* fixing partial state check

* linting

* comments

* removing fn

* WIP: fix tp

* comments

* removing return

* reverting upcast

* add guards

* guards for empty self.parallelism_config

* use len on tuple to check if empty

* Feat: cleanup example

* Feat: some cleanup of example

* Feat: add trackio

* Fix: improve trackio

* Feat: TP works

* Feat: some fsdp2 improv

* Feat: working examples

* handle clipping for tensor parallel

* Implicit replicate

* Refactor: move to separate file + cleanup + basic comments

* Fix: add unadded files, fix circular import

* Feat: better readme

* Feat: add blog + ultrascale links

* Tmp: should_save_model now returns only true

* Fix: remove implicit_replication and style

* Fix: remove optional

* add guard on parallelism_config.tp_enabled

* fix import

* fixing empty parallelism_config

* fix import path for test patch

* fixing patch

---------

Co-authored-by: S1ro1 <matej.sirovatka@gmail.com>
Co-authored-by: Salman Mohammadi <salman.mohammadi@outlook.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
Author: salman
Date: 2025-07-30 20:03:13 +01:00
Committed by: GitHub
Parent: 2f075c724c
Commit: 9359a0194f
15 changed files with 1237 additions and 122 deletions


@ -2,6 +2,31 @@
This folder contains examples of using FSDP2 with Accelerate, utilizing extra methods to improve training speed, performance or accuracy.
### FSDP2 + ND Parallelism
With `ParallelismConfig`, you can use 🤗 accelerate to train models with n-dimensional parallelism. This builds on top of 🤗 transformers, which we utilize for tensor parallelism sharding.
Accelerate then takes care of everything else, such as data parallelism, FSDP, and more to come.
The script `nd_parallel.py` showcases how you can do this. You can configure 3 different parallel dimensions (for now 👀):
- `dp_replicate_size`: how many replicas of the model to create; each replica is trained on a different subset of the data, and gradients are averaged at the end of each step, the same as DDP in PyTorch
- `dp_shard_size`: across how many devices the model is sharded; this uses FSDP2 to shard the model, so each device holds a different part of it
- `tp_size`: how many devices to use for tensor parallelism; this uses the tensor parallel implementation from 🤗 transformers
For example, with 8 GPUs, you can run the script as follows:
```bash
accelerate launch --num-processes 8 nd_parallel.py \
    --dp-replicate-size 2 \
    --dp-shard-size 2 \
    --tp-size 2
```
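Under the hood, `nd_parallel.py` turns these flags into a `ParallelismConfig`. The snippet below is a minimal sketch of what the script sets up (see the full script for data loading, logging, and checkpointing):
```python
from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import FullyShardedDataParallelPlugin

# 2 replicas x 2 shards x 2 tensor-parallel ranks = 8 processes
pc = ParallelismConfig(dp_replicate_size=2, dp_shard_size=2, tp_size=2)
fsdp2_plugin = FullyShardedDataParallelPlugin(
    fsdp_version=2,
    auto_wrap_policy="transformer_based_wrap",
    transformer_cls_names_to_wrap=["LlamaDecoderLayer"],  # adjust to your model's decoder layer class
)
accelerator = Accelerator(mixed_precision="bf16", parallelism_config=pc, fsdp_plugin=fsdp2_plugin)
# model, optimizer and dataloader then go through accelerator.prepare(...) as usual
```
When TP is enabled, the model itself must be loaded with `tp_plan="auto"` and the accelerator's device mesh, as shown in `nd_parallel.py`.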
<Tip>
Only use TP intra-node; the maximum TP size you should need is therefore 8. You can also lower it, as FSDP (`--dp-shard-size`) can be faster for smaller models with
shorter sequence lengths. If you still cannot fit the model into memory, increase `--dp-shard-size` as much as you can. Then, to scale up and utilize all of your GPUs, fill the rest
with `--dp-replicate-size`. This is only a general guideline; you can (and should) experiment with different parallelism configurations to find the best one for your model and hardware. You can learn more about the general strategies for parallelism in our [blog](TODO) or, if you want to dive deep, read the [Ultra-Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook).
</Tip>
We plan to add more parallelism strategies in the future: context parallelism is coming soon, and pipeline parallelism is planned.
### FSDP2 + ao Float8Linear
In the file `fsdp2_fp8.py`, we use `Float8Linear` from `ao` to train a model partially in FP8 precision. We utilize `AORecipeKwargs` to pass the `Float8LinearConfig` to the accelerator,
@ -34,3 +59,4 @@ The figures above were generated on 8x H100 SXM GPUs, with 8192 sequence length
```bash
accelerate launch fsdp2_fp8.py --sequence-length 8192 --num-steps 1000 --log_with wandb --precision [fp8 | bf16]
```


@ -0,0 +1,155 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example of training with ND parallel using accelerate's ParallelismConfig
"""
import argparse
import warnings
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM
from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import FullyShardedDataParallelPlugin, set_seed
from utils import (
PerformanceTracker,
create_collate_fn,
get_dataset,
setup_tokenizer,
)
MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--dp-replicate-size", type=int, default=1)
parser.add_argument("--dp-shard-size", type=int, default=1)
parser.add_argument("--tp-size", type=int, default=1)
parser.add_argument("--sequence-length", type=int, default=1024)
parser.add_argument("--num-steps", type=int, default=1000)
parser.add_argument("--save-dir", type=str, default="./outputs")
parser.add_argument("--checkpoint-frequency", type=int, default=100)
parser.add_argument("--model-name", type=str, default=MODEL_ID)
return parser.parse_args()
def forward(model, batch, optimizer, accelerator):
# To get the proper loss value, we need to average across devices that are participating in data parallel/context parallel training
loss_reduce_grp = (
accelerator.torch_device_mesh["dp_cp"].get_group() if accelerator.parallelism_config.dp_cp_dim_names else None
)
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=loss_reduce_grp)
return loss
def train(args):
parallelism_config = ParallelismConfig(
dp_replicate_size=args.dp_replicate_size,
dp_shard_size=args.dp_shard_size,
tp_size=args.tp_size,
)
# FSDP needs extra configuration, so we properly shard the model
if parallelism_config.dp_shard_enabled:
fsdp2_plugin = FullyShardedDataParallelPlugin(
fsdp_version=2,
auto_wrap_policy="transformer_based_wrap",
transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
)
accelerator = Accelerator(
log_with=["wandb"],
mixed_precision="bf16",
parallelism_config=parallelism_config,
fsdp_plugin=fsdp2_plugin if parallelism_config.dp_shard_enabled else None,
)
accelerator.init_trackers("nd_parallel_training")
# If TP was enabled, we need to tell transformers to prepare the model for us
model_kwargs = (
{"tp_size": args.tp_size, "tp_plan": "auto", "device_mesh": accelerator.torch_device_mesh}
if args.tp_size > 1
else {}
)
model = AutoModelForCausalLM.from_pretrained(
args.model_name,
torch_dtype=torch.bfloat16,
use_cache=False,
**model_kwargs,
)
tokenizer = setup_tokenizer(args.model_name)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)
dataset = get_dataset(accelerator, tokenizer, args.sequence_length)
dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
total_num_steps = min(args.num_steps, len(dataloader))
performance_tracker = PerformanceTracker(warmup_steps=5)
accelerator.print("Starting training...")
for step, batch in enumerate(dataloader):
if step >= total_num_steps:
break
loss = forward(model, batch, optimizer, accelerator)
# We report TPS per device, so we divide by the number of devices in the non-data parallel dimension
metrics = performance_tracker.step(batch["input_ids"].shape[1] / parallelism_config.non_data_parallel_size)
print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
if "warmup_completed" in metrics:
accelerator.print("Warm up completed! Starting performance tracking...")
elif metrics:
print_msg += performance_tracker.get_print_message(metrics, with_memory=True)
if step % 10 == 0 or step == total_num_steps - 1:
accelerator.print(print_msg)
if step % args.checkpoint_frequency == 0 and step > 0 and parallelism_config.dp_shard_enabled:
accelerator.print(f"Saving checkpoint at step {step}...")
accelerator.save_state(args.save_dir + f"/checkpoint-{step}")
accelerator.log({"loss": loss.item()})
accelerator.print("Training completed!")
model.save_pretrained(args.save_dir + f"/{args.model_name}")
accelerator.print(f"Model saved to {args.save_dir}/{args.model_name}")
accelerator.end_training()
if __name__ == "__main__":
set_seed(42)
args = parse_args()
if args.dp_shard_size == 1:
# We currently don't support saving with `save_state` when using only
# tensor parallelism, fsdp must be enabled
warnings.warn(
"Accelerator.save_state() is not yet supported with pure tensor parallel training. Training will work, but intermediate checkpoints will not be saved."
)
train(args)


@ -0,0 +1,187 @@
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example of training with ND parallel using accelerate's ParallelismConfig
"""
import argparse
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM
from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig
from accelerate.state import PartialState
from accelerate.utils import FullyShardedDataParallelPlugin, set_seed
from accelerate.utils.fsdp_utils import save_fsdp_optimizer
from utils import PerformanceTracker, create_collate_fn, get_dataset, gpu_memory_usage_all, setup_tokenizer
MODEL_ID = "NousResearch/Llama-3.2-1B"
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str)
parser.add_argument("--fsdp2-cls-name-to-wrap", type=str, default="LlamaDecoderLayer")
parser.add_argument("--dp-replicate-size", type=int, default=1)
parser.add_argument("--dp-shard-size", type=int, default=1)
parser.add_argument("--tp-size", type=int, default=1)
parser.add_argument("--sequence-length", type=int, default=128)
parser.add_argument("--model-save-dir", type=str, default="./outputs")
parser.add_argument(
"--save-model", action="store_true", default=False, help="Whether to save the model after training."
)
parser.add_argument(
"--save-optimizer",
action="store_true",
default=False,
help="Whether to save the optimizer state after training.",
)
return parser.parse_args()
def main():
"""
Main function to train the model.
"""
args = parse_args()
set_seed(42)
if args.model:
model_id = args.model
else:
model_id = MODEL_ID
model_kwargs = {}
accelerator_kwargs = {}
parallelism_config = ParallelismConfig(
dp_replicate_size=args.dp_replicate_size,
dp_shard_size=args.dp_shard_size,
tp_size=args.tp_size,
)
device_mesh = parallelism_config.build_device_mesh("cuda")
if args.tp_size > 1:
model_kwargs["tp_size"] = args.tp_size
model_kwargs["tp_plan"] = "auto"
model_kwargs["device_mesh"] = device_mesh
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
use_cache=False,
**model_kwargs,
)
PartialState(device_mesh=device_mesh, parallelism_config=parallelism_config)
if parallelism_config.dp_shard_enabled:
fsdp2_plugin = FullyShardedDataParallelPlugin(
fsdp_version=2,
cpu_ram_efficient_loading=False,
auto_wrap_policy="transformer_based_wrap",
transformer_cls_names_to_wrap=[args.fsdp2_cls_name_to_wrap],
reshard_after_forward=True,
activation_checkpointing=True,
state_dict_type="FULL_STATE_DICT",
)
accelerator_kwargs["fsdp_plugin"] = fsdp2_plugin
accelerator = Accelerator(
mixed_precision="no",
**accelerator_kwargs,
)
accelerator.print("Memory usage after model load")
accelerator.print(gpu_memory_usage_all())
accelerator.print("=" * 20)
tokenizer = setup_tokenizer(model_id)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)
model, optimizer = accelerator.prepare(model, optimizer)
accelerator.print("Memory usage after model prepare")
accelerator.print(gpu_memory_usage_all())
accelerator.print("=" * 20)
dataset = get_dataset(accelerator, tokenizer, args.sequence_length)
dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())
dataloader = accelerator.prepare(dataloader)
model.train()
total_num_steps = min(100, len(dataloader))
performance_tracker = PerformanceTracker(warmup_steps=10)
accelerator.print("Starting training...")
for step, batch in enumerate(dataloader):
if step >= total_num_steps:
break
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
dist.all_reduce(loss, op=dist.ReduceOp.AVG)
batch_tokens = batch["input_ids"].shape[1]
metrics = performance_tracker.step(batch_tokens)
print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
log_metrics = {"loss": loss.item()}
if "warmup_completed" in metrics:
accelerator.print("Warm up completed! Starting performance tracking...")
elif metrics:
print_msg += f" | Average steps/s: {metrics['steps_per_second']:.2f} | Average tokens/s: {metrics['tokens_per_second']:.2f}\n"
print_msg += (
f"\tMemory (GB): active={metrics['peak_memory_active']:.1f}, "
f"alloc={metrics['peak_memory_alloc']:.1f}, "
f"reserved={metrics['peak_memory_reserved']:.1f}"
)
if step % 10 == 0 or step == total_num_steps - 1:
accelerator.print(print_msg)
accelerator.log(log_metrics)
accelerator.wait_for_everyone()
accelerator.end_training()
accelerator.print("Training completed!")
if parallelism_config.dp_shard_enabled and args.save_optimizer:
accelerator.print("Saving optimizer state...")
save_fsdp_optimizer(
fsdp2_plugin,
accelerator,
optimizer,
model,
args.model_save_dir + "/opt",
)
accelerator.print("Optimizer state saved.")
accelerator.print("Saving model state...")
if args.save_model:
model.save_pretrained(args.model_save_dir)
accelerator.print(f"Model saved to {args.model_save_dir}")
if __name__ == "__main__":
main()

examples/fsdp2/utils.py

@ -0,0 +1,201 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Common utilities for FSDP2 examples.
"""
import time
import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator
def get_dataset(accelerator: Accelerator, tokenizer: AutoTokenizer, seq_len: int) -> Dataset:
"""
Load and prepare TinyStories dataset.
Args:
accelerator (Accelerator): Accelerate accelerator instance
tokenizer (AutoTokenizer): Hugging Face tokenizer
seq_len (int): Sequence length for the dataset
Returns:
Dataset: Packed dataset
"""
raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:50%]")
def tokenize_function(examples):
tokenized_batch = tokenizer(
examples["text"],
padding=False,
truncation=True,
max_length=seq_len,
return_tensors=None,
)
tokenized_batch["labels"] = tokenized_batch["input_ids"].copy()
return tokenized_batch
with accelerator.main_process_first():
tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
def create_packed_sequences(examples):
all_tokens = []
for input_ids in examples["input_ids"]:
all_tokens.extend(input_ids)
num_sequences = len(all_tokens) // (seq_len + 1)
packed_input_ids = []
packed_labels = []
for i in range(num_sequences):
start_idx = i * (seq_len + 1)
end_idx = start_idx + (seq_len + 1)
full_sequence = all_tokens[start_idx:end_idx]
packed_input_ids.append(full_sequence[:-1])
packed_labels.append(full_sequence[1:])
return {"input_ids": packed_input_ids, "labels": packed_labels}
with accelerator.main_process_first():
packed_dataset = tokenized_dataset.map(
create_packed_sequences,
batched=True,
remove_columns=tokenized_dataset.column_names,
batch_size=1000,
)
return packed_dataset.shuffle(seed=42)
def get_model_flops_per_token(model: AutoModelForCausalLM, seq_len: int) -> float:
"""
Get the number of flops per token for the model.
Args:
model (AutoModelForCausalLM): Model to get the flops for
seq_len (int): Sequence length
"""
cfg = model.config
head_dim = cfg.hidden_size // cfg.num_attention_heads
# MLP: 3 matmuls
mlp_flops = 18 * cfg.hidden_size * cfg.intermediate_size
# Attn (w/o dotproduct)
attn_flops = 12 * head_dim * (cfg.num_attention_heads + cfg.num_key_value_heads)
# attn (dotproduct) - this scales quadratically with sequence length
attn_dotproduct_flops = 12 * cfg.num_attention_heads * head_dim * seq_len
# we also ignore embeddings and layernorms, etc
return (mlp_flops + attn_flops + attn_dotproduct_flops) * cfg.num_hidden_layers
def create_collate_fn():
"""Create a collate function for batching."""
def collate_fn(batch):
input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
return {"input_ids": input_ids, "labels": labels}
return collate_fn
class PerformanceTracker:
"""Track training performance metrics."""
def __init__(self, warmup_steps: int = 10):
self.warmup_steps = warmup_steps
self.reset()
def reset(self):
"""Reset all tracking variables."""
self.start_time = None
self.num_tokens = 0
self.is_in_warmup = True
self.step_count = 0
def step(self, batch_tokens: int) -> dict:
"""
Update performance tracking with a new step.
Args:
batch_tokens (int): Number of tokens in current batch
Returns:
dict: Performance metrics if past warmup, empty dict otherwise
"""
self.step_count += 1
if self.step_count == self.warmup_steps:
self.start_time = time.perf_counter()
self.num_tokens = 0
self.is_in_warmup = False
return {"warmup_completed": True}
if not self.is_in_warmup and self.start_time is not None:
self.num_tokens += batch_tokens
total_time = time.perf_counter() - self.start_time
steps_from_warmup = self.step_count - self.warmup_steps
if total_time > 0 and steps_from_warmup > 0:
memory_stats = gpu_memory_usage_all()
return {
"tokens_per_second": self.num_tokens / total_time,
"steps_per_second": steps_from_warmup / total_time,
"total_tokens": self.num_tokens,
"total_time": total_time,
**memory_stats,
}
return {}
def get_print_message(self, metrics: dict, with_memory: bool = False) -> str:
print_msg = f" | Average steps/s: {metrics['steps_per_second']:.2f} | Average tokens/s: {metrics['tokens_per_second']:.2f}\n"
if with_memory:
print_msg += (
f"\tMemory (GB): active={metrics['peak_memory_active']:.1f}, "
f"alloc={metrics['peak_memory_alloc']:.1f}, "
f"reserved={metrics['peak_memory_reserved']:.1f}"
)
return print_msg
def setup_tokenizer(model_id: str) -> AutoTokenizer:
"""Setup tokenizer with proper padding token."""
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
return tokenizer
def gpu_memory_usage_all(device=0):
device = torch.device(f"cuda:{device}")
_BYTES_IN_GIB = 1024**3
peak_memory_active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
peak_memory_alloc = torch.cuda.max_memory_allocated(device) / _BYTES_IN_GIB
peak_memory_reserved = torch.cuda.max_memory_reserved(device) / _BYTES_IN_GIB
memory_stats = {
"peak_memory_active": peak_memory_active,
"peak_memory_alloc": peak_memory_alloc,
"peak_memory_reserved": peak_memory_reserved,
}
torch.cuda.reset_peak_memory_stats(device)
return memory_stats


@ -26,6 +26,7 @@ from .big_modeling import (
from .data_loader import skip_first_batches
from .inference import prepare_pippy
from .launchers import debug_launcher, notebook_launcher
from .parallelism_config import ParallelismConfig
from .state import PartialState
from .utils import (
AutocastKwargs,


@ -39,6 +39,7 @@ from .checkpointing import load_accelerator_state, load_custom_state, save_accel
from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
from .logging import get_logger
from .optimizer import AcceleratedOptimizer
from .parallelism_config import ParallelismConfig
from .scheduler import AcceleratedScheduler
from .state import AcceleratorState, GradientState, PartialState
from .tracking import LOGGER_TYPE_TO_CLASS, GeneralTracker, filter_trackers
@ -122,8 +123,6 @@ from .utils import (
wait_for_everyone,
)
from .utils.constants import (
BETA_TP_AVAILABLE_PYTORCH_VERSION,
BETA_TP_AVAILABLE_TRANSFORMERS_VERSION,
FSDP2_PYTORCH_VERSION,
FSDP_PYTORCH_VERSION,
PROFILE_PATTERN_NAME,
@ -211,8 +210,7 @@ class Accelerator:
Tweak your FSDP related args using this argument. This argument is optional and can be configured directly
using *accelerate config*
torch_tp_plugin ([`~utils.TorchTensorParallelPlugin`], *optional*):
Tweak your torch tensor parallel. This argument is optional and can be configured directly using
*accelerate config*
Deprecated: use `parallelism_config` with `tp_size` instead.
megatron_lm_plugin ([`~utils.MegatronLMPlugin`], *optional*):
Tweak your MegatronLM related args using this argument. This argument is optional and can be configured
directly using *accelerate config*
@ -287,7 +285,7 @@ class Accelerator:
dataloader_config: DataLoaderConfiguration | None = None,
deepspeed_plugin: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
torch_tp_plugin: TorchTensorParallelPlugin | None = None,
torch_tp_plugin: TorchTensorParallelPlugin | None = None, # Deprecate later, warning in `post_init`
megatron_lm_plugin: MegatronLMPlugin | None = None,
rng_types: list[str | RNGType] | None = None,
log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,
@ -299,6 +297,7 @@ class Accelerator:
dynamo_backend: DynamoBackend | str | None = None,
dynamo_plugin: TorchDynamoPlugin | None = None,
deepspeed_plugins: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
parallelism_config: ParallelismConfig | None = None,
):
self.trackers = []
if project_config is not None:
@ -314,6 +313,12 @@ class Accelerator:
raise ValueError(
f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}"
)
if torch_tp_plugin is not None:
warnings.warn(
"`TorchTensorParallelPlugin` is deprecated and will be removed in a future version of Accelerate. "
"Please use the `ParallelismConfig` with `tp_size` instead.",
FutureWarning,
)
if dynamo_plugin is not None and dynamo_backend is not None:
raise ValueError("You cannot pass in both `dynamo_plugin` and `dynamo_backend`, please only pass in one.")
@ -376,13 +381,6 @@ class Accelerator:
if not is_torch_version(">=", FSDP_PYTORCH_VERSION):
raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")
if isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
raise ValueError(f"TP requires PyTorch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}")
if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
if fsdp_plugin is None: # init from env variables
fsdp_plugin = (
FullyShardedDataParallelPlugin() if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" else None
@ -396,9 +394,6 @@ class Accelerator:
if not is_torch_version(">=", FSDP2_PYTORCH_VERSION):
raise ImportError(f"FSDP2 requires PyTorch >= {FSDP2_PYTORCH_VERSION}")
if torch_tp_plugin is not None and not isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.")
if megatron_lm_plugin is None: # init from env variables
megatron_lm_plugin = (
MegatronLMPlugin() if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" else None
@ -451,19 +446,25 @@ class Accelerator:
if "recipe_handler" in handler_attr and not self.has_fp8_handler:
self.has_fp8_handler = True
parallelism_config = self._setup_parallelism_config(parallelism_config, torch_tp_plugin)
kwargs = self.init_handler.to_kwargs() if self.init_handler is not None else {}
kwargs["parallelism_config"] = parallelism_config
self.state = AcceleratorState(
mixed_precision=mixed_precision,
cpu=cpu,
dynamo_plugin=dynamo_plugin,
deepspeed_plugin=deepspeed_plugins,
fsdp_plugin=fsdp_plugin,
torch_tp_plugin=torch_tp_plugin,
megatron_lm_plugin=megatron_lm_plugin,
_from_accelerator=True,
**kwargs,
)
if self.parallelism_config:
self._build_torch_device_mesh(self.parallelism_config)
self.parallelism_config._validate_accelerator(self)
self.fp8_enabled = self.state.mixed_precision == "fp8" or mixed_precision == "fp8"
# Check for automatic FP8 recipe creation
@ -646,6 +647,18 @@ class Accelerator:
"""
return self.state.use_distributed
@property
def multi_device(self):
return self.use_distributed and self.distributed_type in (
DistributedType.MULTI_GPU,
DistributedType.MULTI_MLU,
DistributedType.MULTI_SDAA,
DistributedType.MULTI_MUSA,
DistributedType.MULTI_NPU,
DistributedType.MULTI_XPU,
DistributedType.MULTI_HPU,
)
@property
def distributed_type(self):
return self.state.distributed_type
@ -730,6 +743,55 @@ class Accelerator:
def is_fsdp2(self):
return self.state.is_fsdp2
@property
def is_composable_parallelism_enabled(self):
return self.is_fsdp2
@property
def parallelism_config(self) -> Union[ParallelismConfig, None]:
return self.state.parallelism_config
@property
def torch_device_mesh(self):
return self.state.device_mesh
@property
def should_save_model(self):
if (pc := self.parallelism_config) is None:
# shouldn't even happen
return self.state.is_local_main_process
_non_model_shard_dims = {
pc.dp_replicate_enabled: "dp_replicate",
pc.cp_enabled: "cp",
}
# return all(
# self.torch_device_mesh[dim].get_local_rank() == 0 for key, dim in non_model_shard_dims.items() if key
# )
# TODO: S1ro - this is a temporary solution until we figure out why `save_safe_file` is slow when not all processes
return True
def _setup_parallelism_config(
self, parallelism_config: ParallelismConfig | None, torch_tp_plugin: TorchTensorParallelPlugin | None
):
if parallelism_config is None:
if PartialState._shared_state != {} and PartialState().parallelism_config is not None:
parallelism_config = PartialState().parallelism_config
else:
# TODO: Remove after deprecating tp_plugin
tp_size = 1 if torch_tp_plugin is None else torch_tp_plugin.tp_size
parallelism_config = ParallelismConfig(tp_size=tp_size)
return parallelism_config
def _build_torch_device_mesh(self, parallelism_config):
if PartialState._shared_state != {} and getattr(PartialState(), "device_mesh", None) is not None:
device_mesh = PartialState().device_mesh
else:
device_mesh = parallelism_config.build_device_mesh(self.device.type)
self.state.device_mesh = device_mesh
PartialState().device_mesh = device_mesh
@contextmanager
def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
"""
@ -1235,15 +1297,7 @@ class Accelerator:
... optimizer.zero_grad()
```
"""
if self.distributed_type in (
DistributedType.MULTI_GPU,
DistributedType.MULTI_NPU,
DistributedType.MULTI_MLU,
DistributedType.MULTI_SDAA,
DistributedType.MULTI_MUSA,
DistributedType.MULTI_XPU,
DistributedType.MULTI_HPU,
):
if self.multi_device:
dl_even_batches_values = []
if even_batches is not None:
@ -1440,6 +1494,9 @@ class Accelerator:
"You are using lower version of PyTorch(< 2.7.0) with ipex acceleration on Intel CPU or XPU, Intel has upstreamed most of the optimizations into stock PyTorch from 2.7.0, we enourage you to install the latest stock PyTorch and enjoy the out-of-experience on Intel CPU/XPU."
)
args = self._prepare_ipex(*args)
if self.parallelism_config and self.parallelism_config.tp_enabled:
args = self._prepare_tp(*args)
if self.fp8_backend == FP8BackendType.TE:
args = self._prepare_te(*args)
elif self.fp8_backend == FP8BackendType.AO:
@ -1476,6 +1533,34 @@ class Accelerator:
return result if len(result) > 1 else result[0]
def _prepare_tp(self, *args):
device_mesh = self.torch_device_mesh
for arg in args:
if not isinstance(arg, torch.nn.Module):
continue
from torch.distributed.tensor import DTensor, Replicate
from transformers.integrations.tensor_parallel import ReplicateParallel
model: torch.nn.Module = arg
tp_plan = ReplicateParallel
for name, param in model.named_parameters():
if isinstance(param, DTensor):
continue
dp = DTensor.from_local(param, device_mesh=device_mesh["tp"], placements=[Replicate()])
param_name, param_type = name.rsplit(".", 1)
module_to_tp = model.get_submodule(param_name)
tp_plan().prepare_module_tp(module_to_tp, device_mesh["tp"])
if not isinstance(dp, torch.nn.Parameter):
dp = torch.nn.Parameter(dp, requires_grad=param.requires_grad)
setattr(module_to_tp, param_type, dp)
return args
def _prepare_fsdp2(self, *args):
# First pass: prepare everything except schedulers (and model, which is prepared separately below)
result = [
@ -1514,14 +1599,18 @@ class Accelerator:
old_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*tuple(result), drop_refs=True))
# Swap the optimizer parameters with empty, so `fully_shard` after will not allocate too much memory
from torch.distributed.tensor import DTensor
for obj in result:
if isinstance(obj, torch.optim.Optimizer):
for param_group in obj.param_groups:
for i, p in enumerate(param_group["params"]):
# We drop a reference to the original param here, so that _move_states_to_device triggers a reallocation
# We reassign the data_ptr to the original param, so that we preserve the mapping to the new ones
param_group["params"][i] = torch.empty_like(p)
param_group["params"][i].data_ptr = p.data_ptr()
param_group["params"][i] = torch.empty(1, dtype=p.dtype, device=p.device)
param_group["params"][i].data_ptr = (
p._local_tensor.data_ptr() if isinstance(p, DTensor) else p.data_ptr()
)
self._models.append(model)
@ -1645,15 +1734,7 @@ class Accelerator:
elif device_placement and not self.verify_device_map(model):
model = model.to(self.device)
if not evaluation_mode:
if self.distributed_type in (
DistributedType.MULTI_GPU,
DistributedType.MULTI_MLU,
DistributedType.MULTI_SDAA,
DistributedType.MULTI_MUSA,
DistributedType.MULTI_NPU,
DistributedType.MULTI_XPU,
DistributedType.MULTI_HPU,
):
if self.multi_device and not (self.parallelism_config and self.parallelism_config.tp_enabled):
if model_has_dtensor(model):
raise ValueError(
"Your model contains `DTensor` parameters, which is incompatible with DDP. Maybe you loaded your model with `device_map='auto'`? Specify `device_map='cuda'` or 'cpu' instead."
@ -1668,23 +1749,20 @@ class Accelerator:
device_ids, output_device = [self.local_process_index], self.local_process_index
else:
device_ids, output_device = None, None
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=device_ids, output_device=output_device, **kwargs
)
if self.ddp_handler is not None:
self.ddp_handler.register_comm_hook(model)
elif self.distributed_type == DistributedType.TP:
if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
elif self.parallelism_config and self.parallelism_config.tp_enabled:
if not hasattr(model, "tp_size"):
raise NotImplementedError(
"Model should undergo tensor parallel before passing it to accelerate."
"You can use .from_pretrained(..., tp_plan='auto') if the model supports"
)
if model.tp_size != self.state.torch_tp_plugin.tp_size:
if model.tp_size != self.parallelism_config.tp_size:
raise ValueError(
f"tp_size in the plugin {self.state.torch_tp_plugin.tp_size} should be same as model's tp size {model.tp_size}"
f"tp_size in the plugin {self.parallelism_config.tp_size} should be same as model's tp size {model.tp_size}"
)
elif self.is_fsdp2:
raise ValueError(
@ -1820,7 +1898,7 @@ class Accelerator:
del self._models[-2]
self._models[-1] = model
elif self.distributed_type == DistributedType.MULTI_CPU:
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler else {}
model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
if self.ddp_handler is not None:
self.ddp_handler.register_comm_hook(model)
@ -2347,11 +2425,10 @@ class Accelerator:
Prepare the device mesh for distributed training. The dataloader will determine how to load data based on the
device mesh.
"""
if self.state.torch_tp_plugin:
return self.state.torch_tp_plugin.torch_device_mesh
elif self.distributed_type == DistributedType.DEEPSPEED and hasattr(self.state, "ds_device_mesh"):
if self.distributed_type == DistributedType.DEEPSPEED and hasattr(self.state, "ds_device_mesh"):
return self.state.ds_device_mesh
return None
else:
return self.torch_device_mesh
def _prepare_msamp(self, *args, device_placement):
if not is_msamp_available():
@ -3587,14 +3664,7 @@ class Accelerator:
map_location = load_model_func_kwargs.pop("map_location", None)
if map_location is None:
if self.num_processes > 1 and self.distributed_type in (
DistributedType.MULTI_GPU,
DistributedType.MULTI_MLU,
DistributedType.MULTI_SDAA,
DistributedType.MULTI_MUSA,
DistributedType.MULTI_NPU,
DistributedType.MULTI_HPU,
):
if self.num_processes > 1 and self.multi_device and self.distributed_type != DistributedType.MULTI_XPU:
map_location = "on_device"
else:
map_location = "cpu"
@ -3693,6 +3763,11 @@ class Accelerator:
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
accessor_mapping[WeightWithDynamicFloat8CastTensor] = "_tensor"
# we know we're in FSDP2 so DTensor is available
if self.is_fsdp2:
from torch.distributed.tensor import DTensor
accessor_mapping[DTensor] = "_local_tensor"
named_parameters.update(
{


@ -1128,16 +1128,20 @@ def prepare_data_loader(
# ranks would range from 0...11
# from data angle ranks should look like 0 0 0 1 1 1 2 2 2 3 3 3
# processes with same ranks/ids would receive the same batch
# for CP the same as TP applies
submesh_fsdp_size = 1
submesh_dp_size = 1
submesh_tp_size = 1
submesh_cp_size = 1
if "tp" in torch_device_mesh.mesh_dim_names:
submesh_tp_size = torch_device_mesh["tp"].size()
if "dp" in torch_device_mesh.mesh_dim_names:
submesh_dp_size = torch_device_mesh["dp"].size()
if "fsdp" in torch_device_mesh.mesh_dim_names:
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
process_index = process_index // submesh_tp_size
if "cp" in torch_device_mesh.mesh_dim_names:
submesh_cp_size = torch_device_mesh["cp"].size()
if "dp_replicate" in torch_device_mesh.mesh_dim_names:
submesh_dp_size = torch_device_mesh["dp_replicate"].size()
if "dp_shard" in torch_device_mesh.mesh_dim_names:
submesh_fsdp_size = torch_device_mesh["dp_shard"].size()
process_index = process_index // (submesh_tp_size * submesh_cp_size)
num_processes = submesh_fsdp_size * submesh_dp_size
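# Illustrative example (added comment, not part of the original diff): with dp_replicate=2, dp_shard=2, tp=2 on 8 ranks,
# global ranks 0..7 collapse to data-loading indices 0 0 1 1 2 2 3 3 and num_processes becomes 2 * 2 = 4.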
# Sanity check


@ -0,0 +1,268 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Union
from torch.distributed.device_mesh import init_device_mesh
from accelerate.utils.dataclasses import TorchTensorParallelConfig
if TYPE_CHECKING:
from accelerate import Accelerator
@dataclass
class ParallelismConfig:
"""
A dataclass to configure parallelisms applied to the model. Inspired by torchtitan's `ParallelDims`
https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/parallel_dims.py
Args:
dp_replicate_size (`int`, defaults to `1`):
The size of the data parallel group. If `dp_replicate_size` is set to 1, the data parallel replication
group will not be used.
dp_shard_size (`int`, defaults to `1`):
The size of the model shard group. If `dp_replicate_size > 1` and `tp_size > 1`, `dp_shard_size` must also
be greater than 1, as composing DDP + TP is currently not supported.
tp_size (`int`, defaults to `1`):
The size of the tensor parallel group. If `tp_size` is set to `1`, the tensor parallel group will not be
used.
cp_size (`int`, defaults to `1`):
The size of the context parallel group. Currently not supported, but reserved for future use and enabled
for downstream libraries.
tp_handler (`~utils.TorchTensorParallelConfig`, defaults to `None`):
The handler for the tensor parallel group.
You may obtain different distributed data parallel paradigms by configuring `dp_replicate_size` and `dp_shard_size`
together:
- `dp_replicate_size == 1` and `dp_shard_size > 1` gives Fully Sharded Data Parallel (FSDP).
- `dp_replicate_size > 1` and `dp_shard_size > 1` gives Hybrid Sharded Data Parallel (HSDP).
- `dp_replicate_size > 1` and `dp_shard_size == 1` is an invalid configuration; to use pure DP, use
`DistributedDataParallelKwargs` instead.
"""
dp_replicate_size: int = 1
dp_shard_size: int = 1
tp_size: int = 1
cp_size: int = 1
# we use Union because we might support other x parallel plugins (i.e. deepspeed, etc)
tp_handler: Union[None, TorchTensorParallelConfig] = None
def __repr__(self):
return (
"ParallelismConfig(\n "
f"\tdp_replicate_size={self.dp_replicate_size},\n"
f"\tdp_shard_size={self.dp_shard_size},\n"
f"\ttp_size={self.tp_size},\n"
f"\tcp_size={self.cp_size},\n"
f"\ttotal_size={self.total_size}\n)"
)
@property
def dp_dim_names(self):
"""Names of enabled dimensions across which data parallelism is applied."""
dims = []
if self.dp_replicate_enabled:
dims += ["dp_replicate"]
if self.dp_shard_enabled:
dims += ["dp_shard"]
return dims
@property
def non_dp_dim_names(self):
"""Names of enabled dimensions which will receive the same batch (non-data parallel dimensions)."""
dims = []
if self.tp_enabled:
dims += ["tp"]
if self.cp_enabled:
dims += ["cp"]
return dims
@property
def dp_shard_cp_dim_names(self):
"""Names of enabled dimensions which will be flattened into a joint mesh across which is model sharded in FSDP."""
dims = []
if self.dp_shard_enabled:
dims += ["dp_shard"]
if self.cp_enabled:
dims += ["cp"]
return dims
@property
def dp_cp_dim_names(self):
"""Names of enabled dimensions across which loss should be averaged"""
dims = []
if self.dp_replicate_enabled:
dims += ["dp_replicate"]
if self.dp_shard_enabled:
dims += ["dp_shard"]
if self.cp_enabled:
dims += ["cp"]
return dims
@property
def fsdp_dim_names(self):
"""Names of enabled dimensions across which FSDP is applied, including data parallel replication."""
dims = []
if self.dp_replicate_enabled:
dims += ["dp_replicate"]
dims += ["dp_shard_cp"]
return dims
@property
def total_size(self):
"""The total size of the parallelism configuration, which is the product of all sizes."""
return self.dp_replicate_size * self.dp_shard_size * self.tp_size * self.cp_size
@property
def non_data_parallel_size(self):
"""The size of the non-data parallel dimensions, which is the product of tensor and context parallel sizes."""
return self.tp_size * self.cp_size
@property
def data_parallel_size(self):
"""The size of the data parallel dimensions, which is the product of data parallel replication and"""
return self.dp_replicate_size * self.dp_shard_size
@property
def dp_replicate_enabled(self):
"""True if data parallel replication is enabled, i.e. `dp_replicate_size > 1`."""
return self.dp_replicate_size > 1
@property
def dp_shard_enabled(self):
"""True if data parallel sharding is enabled, i.e. `dp_shard_size > 1`."""
return self.dp_shard_size > 1
@property
def tp_enabled(self):
"""True if tensor parallelism is enabled, i.e. `tp_size > 1`."""
return self.tp_size > 1
@property
def cp_enabled(self):
"""True if context parallelism is enabled, i.e. `cp_size > 1`."""
return self.cp_size > 1
@property
def active_mesh_dims(self):
"""Names of all active mesh dimensions."""
return self.dp_dim_names + self.non_dp_dim_names
def build_device_mesh(self, device_type: str):
"""Builds a device mesh for the given device type based on the parallelism configuration.
This method will also create required joint meshes (e.g. `dp_shard_cp`, `dp_cp`, `dp`).
Args:
device_type (`str`): The type of device for which to build the mesh, e.g. `"cuda"`.
"""
mesh = self._get_mesh()
if len(mesh) == 0:
return
mesh_dim_names, mesh_shape = mesh
device_mesh = init_device_mesh(
device_type,
mesh_shape,
mesh_dim_names=mesh_dim_names,
)
if self.dp_dim_names:
device_mesh[self.dp_dim_names]._flatten("dp")
if self.dp_shard_cp_dim_names:
device_mesh[self.dp_shard_cp_dim_names]._flatten("dp_shard_cp")
if self.dp_cp_dim_names:
device_mesh[self.dp_cp_dim_names]._flatten("dp_cp")
return device_mesh
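# Example (added comment, not part of the original diff): dp_replicate_size=2, dp_shard_size=2, tp_size=2 builds a
# ("dp_replicate", "dp_shard", "tp") mesh of shape (2, 2, 2), then flattens the joint "dp", "dp_shard_cp" and "dp_cp" meshes.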
def _get_mesh(self) -> tuple[tuple[int, ...], tuple[str, ...]]:
"""Generate mesh shape and dimension names for torch.distributed.init_device_mesh()."""
# Build mesh dimensions dictionary
mesh_dims = {parallelism: self._sizes[parallelism] for parallelism in self.active_mesh_dims}
# Apply canonical ordering
mesh_order = ["dp_replicate", "dp_shard", "cp", "tp"]
sorted_items = sorted(
mesh_dims.items(),
key=lambda x: (mesh_order.index(x[0])),
)
return tuple(zip(*sorted_items))
def __post_init__(self):
# Basic size validation
if self.dp_replicate_size < 1:
raise ValueError(f"dp_replicate_size must be at least 1, but got {self.dp_replicate_size}")
if self.dp_shard_size < 1:
raise ValueError(f"dp_shard_size must be at least 1, but got {self.dp_shard_size}")
if self.tp_size < 1:
raise ValueError(f"tp_size must be at least 1, but got {self.tp_size}")
if self.cp_size < 1:
raise ValueError(f"cp_size must be at least 1, but got {self.cp_size}")
if (self.tp_size > 1 or self.cp_size > 1) and self.dp_replicate_size > 1 and self.dp_shard_size == 1:
raise ValueError(
"Tensor/Context parallelism (tp/cp_size > 1) cannot be used with pure data parallelism (dp_replicate_size > 1 and dp_shard_size == 1). "
"Please set dp_shard_size > 1 and dp_replicate_size == 1 to compose FSDP + TP/CP for 2D parallel, "
"or set dp_replicate_size == 1 and dp_shard_size > 1 to compose HSDP + TP/CP for 3D parallel."
)
self._sizes = {
"dp_replicate": self.dp_replicate_size,
"dp_shard": self.dp_shard_size,
"tp": self.tp_size,
"cp": self.cp_size,
}
def _set_size(self, parallelism: str, size: int):
assert parallelism in self._sizes.keys(), f"Parallelism must be one of {self._sizes.keys()}"
self._sizes[parallelism] = size
setattr(self, f"{parallelism}_size", size)
def _validate_accelerator(self, accelerator: "Accelerator"):
_warnings = set()
if not accelerator.multi_device and self.total_size == 1:
# No distributed setup, valid parallelism config
return
# We need this to ensure DDP works
if self.total_size == 1:
self._set_size("dp_replicate", accelerator.num_processes)
if self.total_size != accelerator.num_processes:
raise ValueError(
f"ParallelismConfig total_size ({self.total_size}) does not match "
f"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ "
f"dp_shard_size/tp_size/cp_size."
)
if self.total_size > 1 and not (accelerator.is_fsdp2 or accelerator.multi_device):
raise ValueError(
f"ParallelismConfig is only compatible DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}}, but got {accelerator.distributed_type}."
)
for parallelism, size in self._sizes.items():
if size == 1 and getattr(self, f"{parallelism}_handler", None) is not None:
_warnings.add(
f"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. This handler will be ignored."
)
if _warnings and accelerator.is_main_process:
warnings.warn(
"ParallelismConfig has the following warnings:\n" + "\n".join(_warnings),
UserWarning,
)


@ -180,6 +180,8 @@ class PartialState:
if not self.initialized:
self._cpu = cpu
self.backend = None
self.parallelism_config = kwargs.pop("parallelism_config", None)
self.device_mesh = kwargs.pop("device_mesh", None)
env_device = os.environ.get("ACCELERATE_TORCH_DEVICE", None)
self.device = torch.device(env_device) if env_device is not None else None
self.debug = parse_flag_from_env("ACCELERATE_DEBUG_MODE")
@ -869,6 +871,8 @@ class AcceleratorState:
- **device** (`torch.device`) -- The device to use.
- **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
in use.
- **parallelism_config** ([`~accelerate.utils.ParallelismConfig`]) -- The parallelism configuration for the
current training environment. This is used to configure the distributed training environment.
- **initialized** (`bool`) -- Whether or not the `AcceleratorState` has been initialized from `Accelerator`.
- **local_process_index** (`int`) -- The index of the current process on the current server.
- **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type
@ -988,8 +992,6 @@ class AcceleratorState:
self.distributed_type = DistributedType.MEGATRON_LM
megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
self.megatron_lm_plugin = megatron_lm_plugin
if self.torch_tp_plugin is not None:
self.distributed_type = DistributedType.TP
elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
if is_ipex_available():
# check if user disables it explicitly


@ -25,7 +25,8 @@ from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup
from accelerate import Accelerator, DistributedType
from accelerate.utils import SAFE_WEIGHTS_NAME, TorchTensorParallelPlugin, set_seed
from accelerate.parallelism_config import ParallelismConfig
from accelerate.utils import SAFE_WEIGHTS_NAME, set_seed
from accelerate.utils.deepspeed import DummyOptim, DummyScheduler
@ -83,7 +84,7 @@ def training_function(config, args):
accelerator_kwargs = {}
# need this for DeepSpeed tests as `args.tp_size` would be None and `torch.distributed.init_device_mesh` would fail
if args.tp_size is not None:
accelerator_kwargs["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=args.tp_size)
accelerator_kwargs["parallelism_config"] = ParallelismConfig(tp_size=args.tp_size)
# Initialize accelerator
accelerator = Accelerator(**accelerator_kwargs)


@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..parallelism_config import ParallelismConfig
from .ao import convert_model_to_fp8_ao, filter_first_and_last_linear_layers, has_ao_layers
from .constants import (
MITA_PROFILING_AVAILABLE_PYTORCH_VERSION,
@ -61,6 +62,7 @@ from .dataclasses import (
TensorInformation,
TERecipeKwargs,
TorchDynamoPlugin,
TorchTensorParallelConfig,
TorchTensorParallelPlugin,
add_model_config_to_megatron_parser,
)


@ -33,6 +33,7 @@ import torch
from .constants import (
BETA_TP_AVAILABLE_PYTORCH_VERSION,
BETA_TP_AVAILABLE_TRANSFORMERS_VERSION,
FSDP2_PYTORCH_VERSION,
FSDP_AUTO_WRAP_POLICY,
FSDP_BACKWARD_PREFETCH,
@ -58,6 +59,7 @@ if TYPE_CHECKING:
# Mock imports for type checking
from torchao.float8 import Float8LinearConfig
logger = logging.getLogger(__name__)
@ -184,7 +186,9 @@ class DistributedDataParallelKwargs(KwargsHandler):
comm_hook: DDPCommunicationHookType = DDPCommunicationHookType.NO
comm_wrapper: Literal[
DDPCommunicationHookType.NO, DDPCommunicationHookType.FP16, DDPCommunicationHookType.BF16
DDPCommunicationHookType.NO,
DDPCommunicationHookType.FP16,
DDPCommunicationHookType.BF16,
] = DDPCommunicationHookType.NO
comm_state_option: dict = field(default_factory=dict)
@ -192,7 +196,10 @@ class DistributedDataParallelKwargs(KwargsHandler):
return {k: v for k, v in super().to_dict().items() if k not in ignore_keys}
def register_comm_hook(self, model):
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks, powerSGD_hook
from torch.distributed.algorithms.ddp_comm_hooks import (
default_hooks,
powerSGD_hook,
)
hook_map: dict[DDPCommunicationHookType, Callable] = {
DDPCommunicationHookType.FP16: default_hooks.fp16_compress_hook,
@ -215,7 +222,11 @@ class DistributedDataParallelKwargs(KwargsHandler):
if hook:
state = (
powerSGD_hook.PowerSGDState(None, **self.comm_state_option)
if self.comm_hook in (DDPCommunicationHookType.POWER_SGD, DDPCommunicationHookType.BATCHED_POWER_SGD)
if self.comm_hook
in (
DDPCommunicationHookType.POWER_SGD,
DDPCommunicationHookType.BATCHED_POWER_SGD,
)
else None
)
model.register_comm_hook(
@ -582,7 +593,6 @@ class DistributedType(str, enum.Enum):
MULTI_XPU = "MULTI_XPU"
DEEPSPEED = "DEEPSPEED"
FSDP = "FSDP"
TP = "TP"
XLA = "XLA"
MEGATRON_LM = "MEGATRON_LM"
MULTI_HPU = "MULTI_HPU"
@ -955,7 +965,10 @@ class GradientAccumulationPlugin(KwargsHandler):
```
"""
num_steps: int = field(default=None, metadata={"help": "The number of steps to accumulate gradients for."})
num_steps: int = field(
default=None,
metadata={"help": "The number of steps to accumulate gradients for."},
)
adjust_scheduler: bool = field(
default=True,
metadata={
@ -1006,12 +1019,22 @@ class TorchDynamoPlugin(KwargsHandler):
metadata={"help": f"Possible options are {[b.value.lower() for b in DynamoBackend]}"},
)
mode: str = field(
default=None, metadata={"help": "Possible options are 'default', 'reduce-overhead' or 'max-autotune'"}
default=None,
metadata={"help": "Possible options are 'default', 'reduce-overhead' or 'max-autotune'"},
)
fullgraph: bool = field(
default=None,
metadata={"help": "Whether it is ok to break model into several subgraphs"},
)
fullgraph: bool = field(default=None, metadata={"help": "Whether it is ok to break model into several subgraphs"})
dynamic: bool = field(default=None, metadata={"help": "Whether to use dynamic shape for tracing"})
options: Any = field(default=None, metadata={"help": "A dictionary of options to pass to the backend."})
disable: bool = field(default=False, metadata={"help": "Turn torch.compile() into a no-op for testing"})
options: Any = field(
default=None,
metadata={"help": "A dictionary of options to pass to the backend."},
)
disable: bool = field(
default=False,
metadata={"help": "Turn torch.compile() into a no-op for testing"},
)
use_regional_compilation: bool = field(
default=None,
@ -1243,13 +1266,13 @@ class DeepSpeedPlugin:
"stage": self.zero_stage,
"offload_optimizer": {
"device": self.offload_optimizer_device,
"nvme_path": self.offload_optimizer_nvme_path
if self.offload_optimizer_device == "nvme"
else None,
"nvme_path": (
self.offload_optimizer_nvme_path if self.offload_optimizer_device == "nvme" else None
),
},
"offload_param": {
"device": self.offload_param_device,
"nvme_path": self.offload_param_nvme_path if self.offload_param_device == "nvme" else None,
"nvme_path": (self.offload_param_nvme_path if self.offload_param_device == "nvme" else None),
},
"stage3_gather_16bit_weights_on_model_save": self.zero3_save_16bit_model,
},
@ -1262,7 +1285,13 @@ class DeepSpeedPlugin:
self.deepspeed_config["steps_per_print"] = float("inf") # this will stop deepspeed from logging @ stdout
if self.zero3_init_flag is None:
self.zero3_init_flag = (
str_to_bool(os.environ.get("ACCELERATE_DEEPSPEED_ZERO3_INIT", str(self.hf_ds_config.is_zero3()))) == 1
str_to_bool(
os.environ.get(
"ACCELERATE_DEEPSPEED_ZERO3_INIT",
str(self.hf_ds_config.is_zero3()),
)
)
== 1
)
if self.zero3_init_flag and not self.hf_ds_config.is_zero3():
warnings.warn("DeepSpeed Zero3 Init flag is only applicable for ZeRO Stage 3. Setting it to False.")
@ -1279,7 +1308,10 @@ class DeepSpeedPlugin:
)
if self.msamp_opt_level not in ["O1", "O2"]:
raise ValueError("Invalid optimization level for MS-AMP. Please use one of ['O1' or'O2'].")
self.deepspeed_config["msamp"] = {"enabled": True, "opt_level": self.msamp_opt_level}
self.deepspeed_config["msamp"] = {
"enabled": True,
"opt_level": self.msamp_opt_level,
}
def fill_match(self, ds_key_long, mismatches=None, must_match=True, **kwargs):
mismatches = [] if mismatches is None else mismatches
@ -1324,7 +1356,11 @@ class DeepSpeedPlugin:
for key, value in config.items():
if isinstance(value, dict):
self.deepspeed_config_process(
prefix=prefix + key + ".", mismatches=mismatches, config=value, must_match=must_match, **kwargs
prefix=prefix + key + ".",
mismatches=mismatches,
config=value,
must_match=must_match,
**kwargs,
)
else:
self.fill_match(prefix + key, mismatches, must_match=must_match, **kwargs)
@ -1351,7 +1387,10 @@ class DeepSpeedPlugin:
if mixed_precision == "fp8" and self.enable_msamp:
if "msamp" not in ds_config:
ds_config["msamp"] = {"enabled": True, "opt_level": self.msamp_opt_level}
ds_config["msamp"] = {
"enabled": True,
"opt_level": self.msamp_opt_level,
}
if mixed_precision != "no":
diff_dtype = "bf16" if mixed_precision == "fp16" else "fp16"
@ -1383,9 +1422,15 @@ class DeepSpeedPlugin:
del ds_config["train_batch_size"]
if compare_versions("transformers", "<", "4.46"):
from transformers.deepspeed import HfDeepSpeedConfig, unset_hf_deepspeed_config
from transformers.deepspeed import (
HfDeepSpeedConfig,
unset_hf_deepspeed_config,
)
else:
from transformers.integrations import HfDeepSpeedConfig, unset_hf_deepspeed_config
from transformers.integrations import (
HfDeepSpeedConfig,
unset_hf_deepspeed_config,
)
unset_hf_deepspeed_config()
self.dschf = HfDeepSpeedConfig(ds_config) # keep this object alive # noqa
@ -1583,7 +1628,11 @@ class FullyShardedDataParallelPlugin:
},
)
mixed_precision_policy: Optional[
Union[dict, "torch.distributed.fsdp.MixedPrecision", "torch.distributed.fsdp.MixedPrecisionPolicy"]
Union[
dict,
"torch.distributed.fsdp.MixedPrecision",
"torch.distributed.fsdp.MixedPrecisionPolicy",
]
] = field(
default=None,
metadata={
@ -1601,7 +1650,11 @@ class FullyShardedDataParallelPlugin:
},
)
)
cpu_offload: Union[bool, "torch.distributed.fsdp.CPUOffload", "torch.distributed.fsdp.CPUOffloadPolicy"] = field(
cpu_offload: Union[
bool,
"torch.distributed.fsdp.CPUOffload",
"torch.distributed.fsdp.CPUOffloadPolicy",
] = field(
default=None,
metadata={
"help": "Whether to offload parameters to CPU. Should be either a `bool` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2. Defaults to `False`"
@ -1628,7 +1681,10 @@ class FullyShardedDataParallelPlugin:
metadata={"help": "State dict config to use. Is determined based on the `state_dict_type` if not passed in."},
)
optim_state_dict_config: Optional[
Union["torch.distributed.fsdp.FullOptimStateDictConfig", "torch.distributed.fsdp.ShardedOptimStateDictConfig"]
Union[
"torch.distributed.fsdp.FullOptimStateDictConfig",
"torch.distributed.fsdp.ShardedOptimStateDictConfig",
]
] = field(
default=None,
metadata={
@@ -1701,10 +1757,7 @@ class FullyShardedDataParallelPlugin:
)
def __post_init__(self):
from torch.distributed.fsdp import (
BackwardPrefetch,
ShardingStrategy,
)
from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
_fsdp2_warnings = set()
@@ -1738,7 +1791,8 @@ class FullyShardedDataParallelPlugin:
# Fallback to `reshard_after_forward` in FSDP1 if `sharding_strategy` is not set
if self.reshard_after_forward is None and self.sharding_strategy is None:
reshard_after_forward = os.environ.get(
env_prefix + "RESHARD_AFTER_FORWARD", "true" if self.fsdp_version == 2 else "FULL_SHARD"
env_prefix + "RESHARD_AFTER_FORWARD",
"true" if self.fsdp_version == 2 else "FULL_SHARD",
)
if self.fsdp_version == 2:
self.reshard_after_forward = str_to_bool(reshard_after_forward.lower(), to_bool=True)
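In short: FSDP2 treats `reshard_after_forward` as a boolean, while FSDP1 falls back to a sharding-strategy name. A standalone sketch of that resolution (the `FSDP_` prefix and the helper name are assumptions for illustration):

```python
import os

def resolve_reshard_after_forward(fsdp_version: int, env_prefix: str = "FSDP_"):
    # FSDP2 expects a boolean; FSDP1 falls back to a sharding-strategy name (FULL_SHARD by default).
    raw = os.environ.get(env_prefix + "RESHARD_AFTER_FORWARD", "true" if fsdp_version == 2 else "FULL_SHARD")
    return raw.lower() in ("1", "true", "yes") if fsdp_version == 2 else raw

print(resolve_reshard_after_forward(2), resolve_reshard_after_forward(1))  # True FULL_SHARD
```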
@@ -1795,7 +1849,10 @@ class FullyShardedDataParallelPlugin:
raise ValueError(
f"Invalid auto wrap policy: {self.auto_wrap_policy}. Must be one of {FSDP_AUTO_WRAP_POLICY}"
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
transformer_auto_wrap_policy,
)
if self.auto_wrap_policy.upper() == "TRANSFORMER_BASED_WRAP":
self.auto_wrap_policy = transformer_auto_wrap_policy
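From the user side, the policy is usually requested by name on the plugin; a hedged usage sketch (the `transformer_cls_names_to_wrap` field and the `LlamaDecoderLayer` class name are examples, not prescribed by this diff):

```python
from accelerate import FullyShardedDataParallelPlugin

# "transformer_based_wrap" resolves to transformer_auto_wrap_policy; the layer classes to wrap
# are later looked up on the model by name via set_auto_wrap_policy(model).
fsdp_plugin = FullyShardedDataParallelPlugin(
    auto_wrap_policy="transformer_based_wrap",
    transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
)
```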
@@ -1910,7 +1967,8 @@ class FullyShardedDataParallelPlugin:
if self.state_dict_type is None:
self.state_dict_type = os.environ.get(
"FSDP_STATE_DICT_TYPE", "FULL_STATE_DICT" if self.fsdp_version == 1 else "SHARDED_STATE_DICT"
"FSDP_STATE_DICT_TYPE",
"FULL_STATE_DICT" if self.fsdp_version == 1 else "SHARDED_STATE_DICT",
)
if isinstance(self.state_dict_type, str):
if self.state_dict_type.isdigit():
@@ -1940,7 +1998,10 @@ class FullyShardedDataParallelPlugin:
Given `model`, creates an `auto_wrap_policy` based on the passed-in policy and whether we can use the
`transformer_cls_to_wrap`
"""
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
transformer_auto_wrap_policy,
)
# First base off of `_no_split_modules`
no_split_modules = getattr(model, "_no_split_modules", None)
@@ -2077,35 +2138,30 @@ class TorchTensorParallelPlugin:
metadata={"help": "tensor parallel size will be used in the device mesh preparation"},
)
# torch_device_mesh is fo type "torch.distributed.DeviceMesh"
# torch_device_mesh is of type "torch.distributed.DeviceMesh"
torch_device_mesh: Optional["torch.distributed.DeviceMesh"] = field(default=None)
@dataclass
class TorchTensorParallelConfig:
"""
Use this object in your [`Accelerator`] to customize your torch tensor parallelism.
"""
enable_async_tp: bool = False
def __post_init__(self):
if not isinstance(self.tp_size, int):
raise ValueError(f"`tp_size` set to {self.tp_size}, please set to an `int`.")
if self.tp_size <= 1:
raise ValueError("`tp_size` must be greater than 1.")
if is_torch_version("<", BETA_TP_AVAILABLE_PYTORCH_VERSION):
if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
raise ValueError(
f"Minimum PyTorch version {BETA_TP_AVAILABLE_PYTORCH_VERSION} needed to use tensor parallel."
f"Torch tensor parallelism is only available in PyTorch {BETA_TP_AVAILABLE_PYTORCH_VERSION} and later versions. "
"Please upgrade your PyTorch version."
)
from torch.distributed.device_mesh import init_device_mesh
# support for other devices has to be investigated
if is_hpu_available(init_hccl=True):
device = "hpu"
elif is_xpu_available():
device = "xpu"
else:
device = "cuda"
if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
mesh_dim_name = "tp"
# device mesh is not used for model sharding
# it is only used for preparing data loader
self.torch_device_mesh = init_device_mesh(device, (self.tp_size,), mesh_dim_names=(mesh_dim_name,))
if self.enable_async_tp:
warnings.warn("Async tensor parallelism is currently not supported, ignoring this option.")
@dataclass
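With the plugin trimmed down, tensor parallelism is now requested through the new `ParallelismConfig`; a minimal sketch, assuming the `Accelerator(parallelism_config=...)` wiring this PR adds and an `accelerate launch` run with enough processes:

```python
from accelerate import Accelerator
from accelerate.parallelism_config import ParallelismConfig

# 2 replicas x 2 shards x 2-way TP -> expects 8 processes (dp_replicate_size * dp_shard_size * tp_size).
pc = ParallelismConfig(dp_replicate_size=2, dp_shard_size=2, tp_size=2)
accelerator = Accelerator(parallelism_config=pc)
```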
@@ -2209,7 +2265,8 @@ class MegatronLMPlugin:
pp_degree: int = field(default=None, metadata={"help": "pipeline parallelism degree."})
num_micro_batches: int = field(default=None, metadata={"help": "number of micro-batches."})
gradient_clipping: float = field(
default=None, metadata={"help": "gradient clipping value based on global L2 Norm (0 to disable)"}
default=None,
metadata={"help": "gradient clipping value based on global L2 Norm (0 to disable)"},
)
sequence_parallelism: bool = field(
default=None,
@@ -2224,7 +2281,8 @@ class MegatronLMPlugin:
metadata={"help": "enable distributed optimizer"},
)
pipeline_model_parallel_split_rank: int = field(
default=None, metadata={"help": "Rank where encoder and decoder should be split."}
default=None,
metadata={"help": "Rank where encoder and decoder should be split."},
)
num_layers_per_virtual_pipeline_stage: int = field(
default=None, metadata={"help": "Number of layers per virtual pipeline stage."}
@@ -2321,10 +2379,12 @@ class MegatronLMPlugin:
metadata={"help": "Whether to set all logging options."},
)
eval_iters: int = field(
default=100, metadata={"help": "Number of iterations to run for evaluation validation/test for."}
default=100,
metadata={"help": "Number of iterations to run for evaluation validation/test for."},
)
eval_interval: int = field(
default=1000, metadata={"help": "Interval between running evaluation on validation set."}
default=1000,
metadata={"help": "Interval between running evaluation on validation set."},
)
return_logits: bool = field(
default=False,
@@ -2691,7 +2751,8 @@ class BnbQuantizationConfig:
load_in_8bit: bool = field(default=False, metadata={"help": "enable 8bit quantization."})
llm_int8_threshold: float = field(
default=6.0, metadata={"help": "value of the outlier threshold. only relevant when load_in_8bit=True"}
default=6.0,
metadata={"help": "value of the outlier threshold. only relevant when load_in_8bit=True"},
)
load_in_4bit: bool = field(default=False, metadata={"help": "enable 4bit quantization."})


@@ -548,6 +548,11 @@ def fsdp2_switch_optimizer_parameters(optimizer: torch.optim.Optimizer, mapping:
indicates a bug. If we kept the original params instead of raising, the training wouldn't be numerically
correct and weights wouldn't get updated.
"""
from torch.distributed.tensor import DTensor
accessor_mapping = {}
accessor_mapping[DTensor] = "_local_tensor"
try:
for param_group in optimizer.param_groups:
param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]]
@@ -615,12 +620,14 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
fsdp2_plugin.set_auto_wrap_policy(model)
original_sd = model.state_dict()
mesh = getattr(accelerator, "torch_device_mesh", None)
fsdp2_kwargs = {
"reshard_after_forward": fsdp2_plugin.reshard_after_forward,
"offload_policy": fsdp2_plugin.cpu_offload,
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
"mesh": mesh[tuple(accelerator.parallelism_config.fsdp_dim_names)] if mesh is not None else None,
}
model_has_params4bit = False
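The interesting part is the `mesh` entry: `fully_shard` only receives the sub-mesh spanned by the FSDP dimensions of the device mesh. A single-process sketch of slicing a sub-mesh by dim name (gloo backend, world size 1; assumes a PyTorch recent enough to support named `DeviceMesh` slicing):

```python
import os

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1, 1), mesh_dim_names=("dp_shard", "tp"))
fsdp_mesh = mesh["dp_shard"]  # FSDP shards only over the dp_shard (and cp, when enabled) dimensions
print(fsdp_mesh.mesh_dim_names, fsdp_mesh.shape)

dist.destroy_process_group()
```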


@@ -2097,7 +2097,6 @@ def get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwarg
DistributedType.MULTI_HPU,
DistributedType.FSDP,
DistributedType.XLA,
DistributedType.TP,
]:
return torch.autocast(device_type=device_type, dtype=torch.bfloat16, **autocast_kwargs)
else:

tests/test_dataclasses.py (new file, 126 lines)

@@ -0,0 +1,126 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock, patch
import pytest
from accelerate.parallelism_config import ParallelismConfig
class TestParallelismConfig:
@pytest.fixture(autouse=True)
def mock_init_device_mesh(self):
def mock_init_mesh(device_type, mesh_shape, mesh_dim_names):
mesh = Mock()
mesh.size.return_value = 1
for dim in mesh_shape:
mesh.size.return_value *= dim
mesh.shape = mesh_shape
mesh.mesh_dim_names = mesh_dim_names
# mock device_mesh._flatten
mesh.flattened_dims = []
def mock_getitem(key):
submesh = Mock()
def mock_flatten(name):
mesh.flattened_dims.append((key, name))
submesh._flatten = Mock(side_effect=mock_flatten)
return submesh
mesh.__getitem__ = Mock(side_effect=mock_getitem)
return mesh
with patch("accelerate.parallelism_config.init_device_mesh", side_effect=mock_init_mesh):
yield mock_init_mesh
@pytest.mark.parametrize(
"dp_replicate_size, dp_shard_size, tp_size, cp_size, expected_shape, expected_dim_names",
[
(8, 1, 1, 1, (8,), ("dp_replicate",)), # DDP
(1, 8, 1, 1, (8,), ("dp_shard",)), # FSDP
(2, 4, 1, 1, (2, 4), ("dp_replicate", "dp_shard")), # HSDP
(1, 4, 2, 1, (4, 2), ("dp_shard", "tp")), # FSDP + TP
(2, 2, 2, 1, (2, 2, 2), ("dp_replicate", "dp_shard", "tp")), # HSDP + TP
(1, 1, 8, 1, (8,), ("tp",)), # TP only
(1, 1, 1, 4, (4,), ("cp",)), # CP only
(1, 4, 1, 2, (4, 2), ("dp_shard", "cp")), # FSDP + CP
(1, 2, 2, 2, (2, 2, 2), ("dp_shard", "cp", "tp")), # FSDP + CP + TP
(2, 2, 2, 2, (2, 2, 2, 2), ("dp_replicate", "dp_shard", "cp", "tp")), # HSDP + CP + TP
],
)
def test_get_mesh(
self,
dp_replicate_size,
dp_shard_size,
tp_size,
cp_size,
expected_shape,
expected_dim_names,
):
config = ParallelismConfig(
dp_replicate_size=dp_replicate_size, dp_shard_size=dp_shard_size, tp_size=tp_size, cp_size=cp_size
)
mesh_dim_names, mesh_shape = config._get_mesh()
assert mesh_shape == expected_shape
assert mesh_dim_names == expected_dim_names
@pytest.mark.parametrize(
"dp_replicate_size, dp_shard_size, tp_size, cp_size, expected_shape, expected_dim_names",
[
(8, 1, 1, 1, (8,), ("dp_replicate",)),
(1, 8, 1, 1, (8,), ("dp_shard",)),
(2, 4, 1, 1, (2, 4), ("dp_replicate", "dp_shard")),
(1, 4, 2, 1, (4, 2), ("dp_shard", "tp")),
(2, 2, 2, 1, (2, 2, 2), ("dp_replicate", "dp_shard", "tp")),
(1, 1, 8, 1, (8,), ("tp",)),
(1, 1, 1, 4, (4,), ("cp",)),
(1, 4, 1, 2, (4, 2), ("dp_shard", "cp")),
(1, 2, 2, 2, (2, 2, 2), ("dp_shard", "cp", "tp")),
(2, 2, 2, 2, (2, 2, 2, 2), ("dp_replicate", "dp_shard", "cp", "tp")),
],
)
def test_build_device_mesh(
self,
dp_replicate_size,
dp_shard_size,
tp_size,
cp_size,
expected_shape,
expected_dim_names,
):
"""Test build_device_mesh creates correct mesh and applies flattening."""
config = ParallelismConfig(
dp_replicate_size=dp_replicate_size, dp_shard_size=dp_shard_size, tp_size=tp_size, cp_size=cp_size
)
device_mesh = config.build_device_mesh("cpu")
# Check mesh shape and dimension names match expected
assert device_mesh.shape == expected_shape
assert device_mesh.mesh_dim_names == expected_dim_names
# Check that correct flattening operations were called
expected_flattened = []
if config.dp_dim_names:
expected_flattened.append((config.dp_dim_names, "dp"))
if config.dp_shard_cp_dim_names:
expected_flattened.append((config.dp_shard_cp_dim_names, "dp_shard_cp"))
if config.dp_cp_dim_names:
expected_flattened.append((config.dp_cp_dim_names, "dp_cp"))
assert device_mesh.flattened_dims == expected_flattened
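Outside of the mocked fixture, the dim-name/shape mapping can also be checked directly, since `_get_mesh` appears to only compute names and sizes; a minimal sketch mirroring the HSDP + TP row from the table above:

```python
from accelerate.parallelism_config import ParallelismConfig

# 2 replicas x 2 shards x 2-way TP -> three mesh dims, outermost first.
pc = ParallelismConfig(dp_replicate_size=2, dp_shard_size=2, tp_size=2)
mesh_dim_names, mesh_shape = pc._get_mesh()
assert mesh_dim_names == ("dp_replicate", "dp_shard", "tp")
assert mesh_shape == (2, 2, 2)
```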