mirror of
https://github.com/huggingface/accelerate.git
synced 2025-10-20 18:13:46 +08:00
Parallelism config + TP + HSDP + BYODM (Bring Your Own Device Mesh) (#3682)
* Feat: init * Feat: add validation + init from kwargs * Fix: minor fixes * Feat: more cleanup * Minor refactor * remove import * adding support for pre-configured device mesh * adding device mesh to fsdp2 * moving mesh dim defn to parralismconfig * tests * WIP device mesh/accelerator validation * WIP more tests * Test Driven Development (TDD) * fixing build_device_mesh * FSDP dim names * adding example * WIP * fixing HSDP * Feat: add back old options * working example * debugging * adding parallelism config to partialstate * Feat: revert ddp changes * Revert DDP * Feat: (untested) update mesh dims and some minor tweaks * adding dp_cp dims * updating comments * WIP * wip 2 * reverting * storing state in accelerator rather than acceleratorstate * Fix: minor tweaks * wip example update * Fixes for non-fsdp2 case * Feat: ensure ddp/tp only works * updating example * updating example * updating examples, fixing state * fixed state * comments * fixing partial state check * linting * comments * removing fn * WIP: fix tp * comments * removing return * reverting upcast * add guards * guards for empty self.parallelism_config * use len on tuple to check if empty * Feat: cleanup example * Feat: some cleanup of example * Feat: add trackio * Fix: improve trackio * Feat: TP works * Feat: some fsdp2 improv * Feat: working examples * handle clipping for tensor parallel * Implicit replicate * Refactor: move to separate file + cleanup + basic comments * Fix: add unadded files, fix circular import * Feat: better readme * Feat: add blog + ultrascale links * Tmp: should_save_model now returns only true * Fix: remove implicit_replication and style * Fix: remove optional * add guard on parallelism_config.tp_enabled * fix import * fixing empty parallelism_config * fix import path for test patch * fixing patch --------- Co-authored-by: S1ro1 <matej.sirovatka@gmail.com> Co-authored-by: Salman Mohammadi <“salman.mohammadi@outlook.com”> Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@ -2,6 +2,31 @@
|
||||
|
||||
This folder contains examples of using FSDP2 with Accelerate, utilizing extra methods to improve training speed, performance or accuracy.
|
||||
|
||||
### FSDP2 + ND Parallelism
|
||||
|
||||
With `ParallelismConfig`, you can use 🤗 accelerate to train models with n-dimensional parallelism. This builds on top of 🤗 transformers, which we utilize for tensor parallelism sharding.
|
||||
Accelerate then takes care of everything else, such as data parallelism, FSDP, and more to come.
|
||||
Script `nd_parallel.py` showcases just how you can do it. We enable you to configure 3 different parallel dimensions (for now 👀):
|
||||
- dp_replicate_size: how many replicas of the model to create, each replica is trained on a different subset of the data and averaged at the end of each step, same as DDP in Torch
|
||||
- dp_shard_size: across how many devices is the model sharded, this is utilizing FSDP2 to shard the model across devices, so each device has a different part of the model
|
||||
- tp_size: how many devices to use for tensor parallelism, this is utilizing the tensor parallelism from 🤗 transformers
|
||||
|
||||
For example, with 8 nodes, you can run the script as such:
|
||||
```bash
|
||||
accelerate launch --num-processes 8 nd_parallel.py \
|
||||
--dp-replicate-size 2 \
|
||||
--dp-shard-size 2 \
|
||||
--tp-size 2 \
|
||||
```
|
||||
|
||||
<Tip>
|
||||
Only use TP intra-node - therefore max TP size you should need is 8, you can also lower this as FSDP (`--dp-shard-size`) can be faster on smaller models with
|
||||
shorter sequence lengths. If you still cannot fit into memory, utilize `--dp-shard-size` as much as you can. Then to scale up to utilize all your GPUs, fill the rest
|
||||
with `--dp-replicate-size`. This is only a general guideline, you can (and should) experiment with different parallelism configurations to find the best one for your model and hardware. You can learn more about the general strategies for parallelism in our [blog](TODO) or if you wanna dive deep, read the [Ultra-Scale Playbook](https://huggingface.co/spaces/nanotron/ultrascale-playbook).
|
||||
</Tip>
|
||||
|
||||
We plan to add more parallelisms in the future, with context parallelism coming soon and pipeline parallelism being planned.
|
||||
|
||||
### FSDP2 + ao Float8Linear
|
||||
|
||||
In file `fsdp2_fp8.py` we use `Float8Linear` from `ao` to train a model partially in FP8 precision. We utilize `AORecipeKwargs` to pass the `Float8LinearConfig` to the accelerator,
|
||||
@ -34,3 +59,4 @@ The figures above were generated on 8x H100 SXM GPUs, with 8192 sequence length
|
||||
```bash
|
||||
accelerate launch fsdp2_fp8.py --sequence-length 8192 --num-steps 1000 --log_with wandb --precision [fp8 | bf16]
|
||||
```
|
||||
|
||||
|
155
examples/fsdp2/nd_parallel.py
Normal file
155
examples/fsdp2/nd_parallel.py
Normal file
@ -0,0 +1,155 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Example of training with ND parallel using accelerate's ParallelismConfig
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from accelerate import Accelerator
|
||||
from accelerate.parallelism_config import ParallelismConfig
|
||||
from accelerate.utils import FullyShardedDataParallelPlugin, set_seed
|
||||
from utils import (
|
||||
PerformanceTracker,
|
||||
create_collate_fn,
|
||||
get_dataset,
|
||||
setup_tokenizer,
|
||||
)
|
||||
|
||||
|
||||
MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--dp-replicate-size", type=int, default=1)
|
||||
parser.add_argument("--dp-shard-size", type=int, default=1)
|
||||
parser.add_argument("--tp-size", type=int, default=1)
|
||||
parser.add_argument("--sequence-length", type=int, default=1024)
|
||||
parser.add_argument("--num-steps", type=int, default=1000)
|
||||
parser.add_argument("--save-dir", type=str, default="./outputs")
|
||||
parser.add_argument("--checkpoint-frequency", type=int, default=100)
|
||||
parser.add_argument("--model-name", type=str, default=MODEL_ID)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def forward(model, batch, optimizer, accelerator):
|
||||
# To get the proper loss value, we need to average across devices that are participating in data parallel/context parallel training
|
||||
loss_reduce_grp = (
|
||||
accelerator.torch_device_mesh["dp_cp"].get_group() if accelerator.parallelism_config.dp_cp_dim_names else None
|
||||
)
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=loss_reduce_grp)
|
||||
return loss
|
||||
|
||||
|
||||
def train(args):
|
||||
parallelism_config = ParallelismConfig(
|
||||
dp_replicate_size=args.dp_replicate_size,
|
||||
dp_shard_size=args.dp_shard_size,
|
||||
tp_size=args.tp_size,
|
||||
)
|
||||
|
||||
# FSDP needs extra configuration, so we properly shard the model
|
||||
if parallelism_config.dp_shard_enabled:
|
||||
fsdp2_plugin = FullyShardedDataParallelPlugin(
|
||||
fsdp_version=2,
|
||||
auto_wrap_policy="transformer_based_wrap",
|
||||
transformer_cls_names_to_wrap=["LlamaDecoderLayer"],
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
log_with=["wandb"],
|
||||
mixed_precision="bf16",
|
||||
parallelism_config=parallelism_config,
|
||||
fsdp_plugin=fsdp2_plugin if parallelism_config.dp_shard_enabled else None,
|
||||
)
|
||||
accelerator.init_trackers("nd_parallel_training")
|
||||
|
||||
# If TP was enabled, we need to tell transformers to prepare the model for us
|
||||
model_kwargs = (
|
||||
{"tp_size": args.tp_size, "tp_plan": "auto", "device_mesh": accelerator.torch_device_mesh}
|
||||
if args.tp_size > 1
|
||||
else {}
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
use_cache=False,
|
||||
**model_kwargs,
|
||||
)
|
||||
tokenizer = setup_tokenizer(args.model_name)
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)
|
||||
dataset = get_dataset(accelerator, tokenizer, args.sequence_length)
|
||||
dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
|
||||
total_num_steps = min(args.num_steps, len(dataloader))
|
||||
performance_tracker = PerformanceTracker(warmup_steps=5)
|
||||
|
||||
accelerator.print("Starting training...")
|
||||
for step, batch in enumerate(dataloader):
|
||||
if step >= total_num_steps:
|
||||
break
|
||||
|
||||
loss = forward(model, batch, optimizer, accelerator)
|
||||
|
||||
# We report TPS per device, so we divide by the number of devices in the non-data parallel dimension
|
||||
metrics = performance_tracker.step(batch["input_ids"].shape[1] / parallelism_config.non_data_parallel_size)
|
||||
|
||||
print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
|
||||
if "warmup_completed" in metrics:
|
||||
accelerator.print("Warm up completed! Starting performance tracking...")
|
||||
elif metrics:
|
||||
print_msg += performance_tracker.get_print_message(metrics, with_memory=True)
|
||||
|
||||
if step % 10 == 0 or step == total_num_steps - 1:
|
||||
accelerator.print(print_msg)
|
||||
|
||||
if step % args.checkpoint_frequency == 0 and step > 0 and parallelism_config.dp_shard_enabled:
|
||||
accelerator.print(f"Saving checkpoint at step {step}...")
|
||||
accelerator.save_state(args.save_dir + f"/checkpoint-{step}")
|
||||
|
||||
accelerator.log({"loss": loss.item()})
|
||||
|
||||
accelerator.print("Training completed!")
|
||||
|
||||
model.save_pretrained(args.save_dir + f"/{args.model_name}")
|
||||
accelerator.print(f"Model saved to {args.save_dir}/{args.model_name}")
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
set_seed(42)
|
||||
args = parse_args()
|
||||
if args.dp_shard_size == 1:
|
||||
# We currently don't support saving with `save_state` when using only
|
||||
# tensor parallelism, fsdp must be enabled
|
||||
warnings.warn(
|
||||
"Accelerator.save_state() is not yet supported with pure tensor parallel training. Training will work, but intermediate checkpoints will not be saved."
|
||||
)
|
||||
train(args)
|
187
examples/fsdp2/nd_parallel_prepared_device_mesh.py
Normal file
187
examples/fsdp2/nd_parallel_prepared_device_mesh.py
Normal file
@ -0,0 +1,187 @@
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Example of training with ND parallel using accelerate's ParallelismConfig
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from accelerate import Accelerator
|
||||
from accelerate.parallelism_config import ParallelismConfig
|
||||
from accelerate.state import PartialState
|
||||
from accelerate.utils import FullyShardedDataParallelPlugin, set_seed
|
||||
from accelerate.utils.fsdp_utils import save_fsdp_optimizer
|
||||
from utils import PerformanceTracker, create_collate_fn, get_dataset, gpu_memory_usage_all, setup_tokenizer
|
||||
|
||||
|
||||
MODEL_ID = "NousResearch/Llama-3.2-1B"
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", type=str)
|
||||
parser.add_argument("--fsdp2-cls-name-to-wrap", type=str, default="LlamaDecoderLayer")
|
||||
parser.add_argument("--dp-replicate-size", type=int, default=1)
|
||||
parser.add_argument("--dp-shard-size", type=int, default=1)
|
||||
parser.add_argument("--tp-size", type=int, default=1)
|
||||
parser.add_argument("--sequence-length", type=int, default=128)
|
||||
parser.add_argument("--model-save-dir", type=str, default="./outputs")
|
||||
parser.add_argument(
|
||||
"--save-model", action="store_true", default=False, help="Whether to save the model after training."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-optimizer",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Whether to save the optimizer state after training.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main function to train the model.
|
||||
"""
|
||||
args = parse_args()
|
||||
|
||||
set_seed(42)
|
||||
|
||||
if args.model:
|
||||
model_id = args.model
|
||||
else:
|
||||
model_id = MODEL_ID
|
||||
|
||||
model_kwargs = {}
|
||||
accelerator_kwargs = {}
|
||||
|
||||
parallelism_config = ParallelismConfig(
|
||||
dp_replicate_size=args.dp_replicate_size,
|
||||
dp_shard_size=args.dp_shard_size,
|
||||
tp_size=args.tp_size,
|
||||
)
|
||||
|
||||
device_mesh = parallelism_config.build_device_mesh("cuda")
|
||||
|
||||
if args.tp_size > 1:
|
||||
model_kwargs["tp_size"] = args.tp_size
|
||||
model_kwargs["tp_plan"] = "auto"
|
||||
model_kwargs["device_mesh"] = device_mesh
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype=torch.bfloat16,
|
||||
use_cache=False,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
PartialState(device_mesh=device_mesh, parallelism_config=parallelism_config)
|
||||
|
||||
if parallelism_config.dp_shard_enabled:
|
||||
fsdp2_plugin = FullyShardedDataParallelPlugin(
|
||||
fsdp_version=2,
|
||||
cpu_ram_efficient_loading=False,
|
||||
auto_wrap_policy="transformer_based_wrap",
|
||||
transformer_cls_names_to_wrap=[args.fsdp2_cls_name_to_wrap],
|
||||
reshard_after_forward=True,
|
||||
activation_checkpointing=True,
|
||||
state_dict_type="FULL_STATE_DICT",
|
||||
)
|
||||
accelerator_kwargs["fsdp_plugin"] = fsdp2_plugin
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision="no",
|
||||
**accelerator_kwargs,
|
||||
)
|
||||
|
||||
accelerator.print("Memory usage after model load")
|
||||
accelerator.print(gpu_memory_usage_all())
|
||||
accelerator.print("=" * 20)
|
||||
tokenizer = setup_tokenizer(model_id)
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)
|
||||
|
||||
model, optimizer = accelerator.prepare(model, optimizer)
|
||||
accelerator.print("Memory usage after model prepare")
|
||||
accelerator.print(gpu_memory_usage_all())
|
||||
accelerator.print("=" * 20)
|
||||
|
||||
dataset = get_dataset(accelerator, tokenizer, args.sequence_length)
|
||||
dataloader = DataLoader(dataset, batch_size=1, collate_fn=create_collate_fn())
|
||||
dataloader = accelerator.prepare(dataloader)
|
||||
|
||||
model.train()
|
||||
|
||||
total_num_steps = min(100, len(dataloader))
|
||||
performance_tracker = PerformanceTracker(warmup_steps=10)
|
||||
|
||||
accelerator.print("Starting training...")
|
||||
for step, batch in enumerate(dataloader):
|
||||
if step >= total_num_steps:
|
||||
break
|
||||
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
dist.all_reduce(loss, op=dist.ReduceOp.AVG)
|
||||
|
||||
batch_tokens = batch["input_ids"].shape[1]
|
||||
metrics = performance_tracker.step(batch_tokens)
|
||||
|
||||
print_msg = f"Step {step}/{total_num_steps}, Loss: {loss.item():.4f}"
|
||||
log_metrics = {"loss": loss.item()}
|
||||
|
||||
if "warmup_completed" in metrics:
|
||||
accelerator.print("Warm up completed! Starting performance tracking...")
|
||||
elif metrics:
|
||||
print_msg += f" | Average steps/s: {metrics['steps_per_second']:.2f} | Average tokens/s: {metrics['tokens_per_second']:.2f}\n"
|
||||
print_msg += (
|
||||
f"\tMemory (GB): active={metrics['peak_memory_active']:.1f}, "
|
||||
f"alloc={metrics['peak_memory_alloc']:.1f}, "
|
||||
f"reserved={metrics['peak_memory_reserved']:.1f}"
|
||||
)
|
||||
if step % 10 == 0 or step == total_num_steps - 1:
|
||||
accelerator.print(print_msg)
|
||||
|
||||
accelerator.log(log_metrics)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.end_training()
|
||||
accelerator.print("Training completed!")
|
||||
if parallelism_config.dp_shard_enabled and args.save_optimizer:
|
||||
accelerator.print("Saving optimizer state...")
|
||||
save_fsdp_optimizer(
|
||||
fsdp2_plugin,
|
||||
accelerator,
|
||||
optimizer,
|
||||
model,
|
||||
args.model_save_dir + "/opt",
|
||||
)
|
||||
accelerator.print("Optimizer state saved.")
|
||||
accelerator.print("Saving model state...")
|
||||
if args.save_model:
|
||||
model.save_pretrained(args.model_save_dir)
|
||||
accelerator.print(f"Model saved to {args.model_save_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
201
examples/fsdp2/utils.py
Normal file
201
examples/fsdp2/utils.py
Normal file
@ -0,0 +1,201 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Common utilities for FSDP2 examples.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from accelerate import Accelerator
|
||||
|
||||
|
||||
def get_dataset(accelerator: Accelerator, tokenizer: AutoTokenizer, seq_len: int) -> Dataset:
|
||||
"""
|
||||
Load and prepare TinyStories dataset.
|
||||
|
||||
Args:
|
||||
accelerator (Accelerator): Accelerate accelerator instance
|
||||
tokenizer (AutoTokenizer): Hugging Face tokenizer
|
||||
seq_len (int): Sequence length for the dataset
|
||||
|
||||
Returns:
|
||||
Dataset: Packed dataset
|
||||
"""
|
||||
raw_dataset = load_dataset("roneneldan/TinyStories", split="train[:50%]")
|
||||
|
||||
def tokenize_function(examples):
|
||||
tokenized_batch = tokenizer(
|
||||
examples["text"],
|
||||
padding=False,
|
||||
truncation=True,
|
||||
max_length=seq_len,
|
||||
return_tensors=None,
|
||||
)
|
||||
tokenized_batch["labels"] = tokenized_batch["input_ids"].copy()
|
||||
return tokenized_batch
|
||||
|
||||
with accelerator.main_process_first():
|
||||
tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
|
||||
|
||||
def create_packed_sequences(examples):
|
||||
all_tokens = []
|
||||
for input_ids in examples["input_ids"]:
|
||||
all_tokens.extend(input_ids)
|
||||
|
||||
num_sequences = len(all_tokens) // (seq_len + 1)
|
||||
packed_input_ids = []
|
||||
packed_labels = []
|
||||
|
||||
for i in range(num_sequences):
|
||||
start_idx = i * (seq_len + 1)
|
||||
end_idx = start_idx + (seq_len + 1)
|
||||
full_sequence = all_tokens[start_idx:end_idx]
|
||||
packed_input_ids.append(full_sequence[:-1])
|
||||
packed_labels.append(full_sequence[1:])
|
||||
|
||||
return {"input_ids": packed_input_ids, "labels": packed_labels}
|
||||
|
||||
with accelerator.main_process_first():
|
||||
packed_dataset = tokenized_dataset.map(
|
||||
create_packed_sequences,
|
||||
batched=True,
|
||||
remove_columns=tokenized_dataset.column_names,
|
||||
batch_size=1000,
|
||||
)
|
||||
|
||||
return packed_dataset.shuffle(seed=42)
|
||||
|
||||
|
||||
def get_model_flops_per_token(model: AutoModelForCausalLM, seq_len: int) -> float:
|
||||
"""
|
||||
Get the number of flops per token for the model.
|
||||
|
||||
Args:
|
||||
model (AutoModelForCausalLM): Model to get the flops for
|
||||
seq_len (int): Sequence length
|
||||
"""
|
||||
cfg = model.config
|
||||
head_dim = cfg.hidden_size // cfg.num_attention_heads
|
||||
|
||||
# MLP: 3 matmuls
|
||||
mlp_flops = 18 * cfg.hidden_size * cfg.intermediate_size
|
||||
|
||||
# Attn (w/o dotproduct)
|
||||
attn_flops = 12 * head_dim * (cfg.num_attention_heads + cfg.num_key_value_heads)
|
||||
|
||||
# attn (dotproduct) - this scales quadratically with sequence length
|
||||
attn_dotproduct_flops = 12 * cfg.num_attention_heads * head_dim * seq_len
|
||||
|
||||
# we also ignore embeddings and layernorms, etc
|
||||
return (mlp_flops + attn_flops + attn_dotproduct_flops) * cfg.num_hidden_layers
|
||||
|
||||
|
||||
def create_collate_fn():
|
||||
"""Create a collate function for batching."""
|
||||
|
||||
def collate_fn(batch):
|
||||
input_ids = torch.tensor([item["input_ids"] for item in batch], dtype=torch.long)
|
||||
labels = torch.tensor([item["labels"] for item in batch], dtype=torch.long)
|
||||
return {"input_ids": input_ids, "labels": labels}
|
||||
|
||||
return collate_fn
|
||||
|
||||
|
||||
class PerformanceTracker:
|
||||
"""Track training performance metrics."""
|
||||
|
||||
def __init__(self, warmup_steps: int = 10):
|
||||
self.warmup_steps = warmup_steps
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""Reset all tracking variables."""
|
||||
self.start_time = None
|
||||
self.num_tokens = 0
|
||||
self.is_in_warmup = True
|
||||
self.step_count = 0
|
||||
|
||||
def step(self, batch_tokens: int) -> dict:
|
||||
"""
|
||||
Update performance tracking with a new step.
|
||||
|
||||
Args:
|
||||
batch_tokens (int): Number of tokens in current batch
|
||||
|
||||
Returns:
|
||||
dict: Performance metrics if past warmup, empty dict otherwise
|
||||
"""
|
||||
self.step_count += 1
|
||||
|
||||
if self.step_count == self.warmup_steps:
|
||||
self.start_time = time.perf_counter()
|
||||
self.num_tokens = 0
|
||||
self.is_in_warmup = False
|
||||
return {"warmup_completed": True}
|
||||
|
||||
if not self.is_in_warmup and self.start_time is not None:
|
||||
self.num_tokens += batch_tokens
|
||||
total_time = time.perf_counter() - self.start_time
|
||||
steps_from_warmup = self.step_count - self.warmup_steps
|
||||
|
||||
if total_time > 0 and steps_from_warmup > 0:
|
||||
memory_stats = gpu_memory_usage_all()
|
||||
return {
|
||||
"tokens_per_second": self.num_tokens / total_time,
|
||||
"steps_per_second": steps_from_warmup / total_time,
|
||||
"total_tokens": self.num_tokens,
|
||||
"total_time": total_time,
|
||||
**memory_stats,
|
||||
}
|
||||
|
||||
return {}
|
||||
|
||||
def get_print_message(self, metrics: dict, with_memory: bool = False) -> str:
|
||||
print_msg = f" | Average steps/s: {metrics['steps_per_second']:.2f} | Average tokens/s: {metrics['tokens_per_second']:.2f}\n"
|
||||
if with_memory:
|
||||
print_msg += (
|
||||
f"\tMemory (GB): active={metrics['peak_memory_active']:.1f}, "
|
||||
f"alloc={metrics['peak_memory_alloc']:.1f}, "
|
||||
f"reserved={metrics['peak_memory_reserved']:.1f}"
|
||||
)
|
||||
return print_msg
|
||||
|
||||
|
||||
def setup_tokenizer(model_id: str) -> AutoTokenizer:
|
||||
"""Setup tokenizer with proper padding token."""
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
return tokenizer
|
||||
|
||||
|
||||
def gpu_memory_usage_all(device=0):
|
||||
device = torch.device(f"cuda:{device}")
|
||||
_BYTES_IN_GIB = 1024**3
|
||||
peak_memory_active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / _BYTES_IN_GIB
|
||||
peak_memory_alloc = torch.cuda.max_memory_allocated(device) / _BYTES_IN_GIB
|
||||
peak_memory_reserved = torch.cuda.max_memory_reserved(device) / _BYTES_IN_GIB
|
||||
memory_stats = {
|
||||
"peak_memory_active": peak_memory_active,
|
||||
"peak_memory_alloc": peak_memory_alloc,
|
||||
"peak_memory_reserved": peak_memory_reserved,
|
||||
}
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
|
||||
return memory_stats
|
@ -26,6 +26,7 @@ from .big_modeling import (
|
||||
from .data_loader import skip_first_batches
|
||||
from .inference import prepare_pippy
|
||||
from .launchers import debug_launcher, notebook_launcher
|
||||
from .parallelism_config import ParallelismConfig
|
||||
from .state import PartialState
|
||||
from .utils import (
|
||||
AutocastKwargs,
|
||||
|
@ -39,6 +39,7 @@ from .checkpointing import load_accelerator_state, load_custom_state, save_accel
|
||||
from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
|
||||
from .logging import get_logger
|
||||
from .optimizer import AcceleratedOptimizer
|
||||
from .parallelism_config import ParallelismConfig
|
||||
from .scheduler import AcceleratedScheduler
|
||||
from .state import AcceleratorState, GradientState, PartialState
|
||||
from .tracking import LOGGER_TYPE_TO_CLASS, GeneralTracker, filter_trackers
|
||||
@ -122,8 +123,6 @@ from .utils import (
|
||||
wait_for_everyone,
|
||||
)
|
||||
from .utils.constants import (
|
||||
BETA_TP_AVAILABLE_PYTORCH_VERSION,
|
||||
BETA_TP_AVAILABLE_TRANSFORMERS_VERSION,
|
||||
FSDP2_PYTORCH_VERSION,
|
||||
FSDP_PYTORCH_VERSION,
|
||||
PROFILE_PATTERN_NAME,
|
||||
@ -211,8 +210,7 @@ class Accelerator:
|
||||
Tweak your FSDP related args using this argument. This argument is optional and can be configured directly
|
||||
using *accelerate config*
|
||||
torch_tp_plugin ([`~utils.TorchTensorParallelPlugin`], *optional*):
|
||||
Tweak your torch tensor parallel. This argument is optional and can be configured directly using
|
||||
*accelerate config*
|
||||
Deprecated: use `parallelism_config` with `tp_size` instead.
|
||||
megatron_lm_plugin ([`~utils.MegatronLMPlugin`], *optional*):
|
||||
Tweak your MegatronLM related args using this argument. This argument is optional and can be configured
|
||||
directly using *accelerate config*
|
||||
@ -287,7 +285,7 @@ class Accelerator:
|
||||
dataloader_config: DataLoaderConfiguration | None = None,
|
||||
deepspeed_plugin: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
|
||||
fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
|
||||
torch_tp_plugin: TorchTensorParallelPlugin | None = None,
|
||||
torch_tp_plugin: TorchTensorParallelPlugin | None = None, # Deprecate later, warning in `post_init`
|
||||
megatron_lm_plugin: MegatronLMPlugin | None = None,
|
||||
rng_types: list[str | RNGType] | None = None,
|
||||
log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,
|
||||
@ -299,6 +297,7 @@ class Accelerator:
|
||||
dynamo_backend: DynamoBackend | str | None = None,
|
||||
dynamo_plugin: TorchDynamoPlugin | None = None,
|
||||
deepspeed_plugins: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
|
||||
parallelism_config: ParallelismConfig | None = None,
|
||||
):
|
||||
self.trackers = []
|
||||
if project_config is not None:
|
||||
@ -314,6 +313,12 @@ class Accelerator:
|
||||
raise ValueError(
|
||||
f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}"
|
||||
)
|
||||
if torch_tp_plugin is not None:
|
||||
warnings.warn(
|
||||
"`TorchTensorParallelPlugin` is deprecated and will be removed in a future version of Accelerate. "
|
||||
"Please use the `ParallelismConfig` with `tp_size` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
if dynamo_plugin is not None and dynamo_backend is not None:
|
||||
raise ValueError("You cannot pass in both `dynamo_plugin` and `dynamo_backend`, please only pass in one.")
|
||||
@ -376,13 +381,6 @@ class Accelerator:
|
||||
if not is_torch_version(">=", FSDP_PYTORCH_VERSION):
|
||||
raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")
|
||||
|
||||
if isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
|
||||
if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
|
||||
raise ValueError(f"TP requires PyTorch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}")
|
||||
|
||||
if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
|
||||
raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
|
||||
|
||||
if fsdp_plugin is None: # init from env variables
|
||||
fsdp_plugin = (
|
||||
FullyShardedDataParallelPlugin() if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" else None
|
||||
@ -396,9 +394,6 @@ class Accelerator:
|
||||
if not is_torch_version(">=", FSDP2_PYTORCH_VERSION):
|
||||
raise ImportError(f"FSDP2 requires PyTorch >= {FSDP2_PYTORCH_VERSION}")
|
||||
|
||||
if torch_tp_plugin is not None and not isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
|
||||
raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.")
|
||||
|
||||
if megatron_lm_plugin is None: # init from env variables
|
||||
megatron_lm_plugin = (
|
||||
MegatronLMPlugin() if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" else None
|
||||
@ -451,19 +446,25 @@ class Accelerator:
|
||||
if "recipe_handler" in handler_attr and not self.has_fp8_handler:
|
||||
self.has_fp8_handler = True
|
||||
|
||||
parallelism_config = self._setup_parallelism_config(parallelism_config, torch_tp_plugin)
|
||||
|
||||
kwargs = self.init_handler.to_kwargs() if self.init_handler is not None else {}
|
||||
kwargs["parallelism_config"] = parallelism_config
|
||||
self.state = AcceleratorState(
|
||||
mixed_precision=mixed_precision,
|
||||
cpu=cpu,
|
||||
dynamo_plugin=dynamo_plugin,
|
||||
deepspeed_plugin=deepspeed_plugins,
|
||||
fsdp_plugin=fsdp_plugin,
|
||||
torch_tp_plugin=torch_tp_plugin,
|
||||
megatron_lm_plugin=megatron_lm_plugin,
|
||||
_from_accelerator=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if self.parallelism_config:
|
||||
self._build_torch_device_mesh(self.parallelism_config)
|
||||
self.parallelism_config._validate_accelerator(self)
|
||||
|
||||
self.fp8_enabled = self.state.mixed_precision == "fp8" or mixed_precision == "fp8"
|
||||
|
||||
# Check for automatic FP8 recipe creation
|
||||
@ -646,6 +647,18 @@ class Accelerator:
|
||||
"""
|
||||
return self.state.use_distributed
|
||||
|
||||
@property
|
||||
def multi_device(self):
|
||||
return self.use_distributed and self.distributed_type in (
|
||||
DistributedType.MULTI_GPU,
|
||||
DistributedType.MULTI_MLU,
|
||||
DistributedType.MULTI_SDAA,
|
||||
DistributedType.MULTI_MUSA,
|
||||
DistributedType.MULTI_NPU,
|
||||
DistributedType.MULTI_XPU,
|
||||
DistributedType.MULTI_HPU,
|
||||
)
|
||||
|
||||
@property
|
||||
def distributed_type(self):
|
||||
return self.state.distributed_type
|
||||
@ -730,6 +743,55 @@ class Accelerator:
|
||||
def is_fsdp2(self):
|
||||
return self.state.is_fsdp2
|
||||
|
||||
@property
|
||||
def is_composable_parallelism_enabled(self):
|
||||
return self.is_fsdp2
|
||||
|
||||
@property
|
||||
def parallelism_config(self) -> Union[ParallelismConfig, None]:
|
||||
return self.state.parallelism_config
|
||||
|
||||
@property
|
||||
def torch_device_mesh(self):
|
||||
return self.state.device_mesh
|
||||
|
||||
@property
|
||||
def should_save_model(self):
|
||||
if (pc := self.parallelism_config) is None:
|
||||
# shouldn't even happen
|
||||
return self.state.is_local_main_process
|
||||
_non_model_shard_dims = {
|
||||
pc.dp_replicate_enabled: "dp_replicate",
|
||||
pc.cp_enabled: "cp",
|
||||
}
|
||||
|
||||
# return all(
|
||||
# self.torch_device_mesh[dim].get_local_rank() == 0 for key, dim in non_model_shard_dims.items() if key
|
||||
# )
|
||||
# TODO: S1ro - this is a temporary solution until we figure out why `save_safe_file` is slow when not all processes
|
||||
return True
|
||||
|
||||
def _setup_parallelism_config(
|
||||
self, parallelism_config: ParallelismConfig | None, torch_tp_plugin: TorchTensorParallelPlugin | None
|
||||
):
|
||||
if parallelism_config is None:
|
||||
if PartialState._shared_state != {} and PartialState().parallelism_config is not None:
|
||||
parallelism_config = PartialState().parallelism_config
|
||||
else:
|
||||
# TODO: Remove after deprecating tp_plugin
|
||||
tp_size = 1 if torch_tp_plugin is None else torch_tp_plugin.tp_size
|
||||
parallelism_config = ParallelismConfig(tp_size=tp_size)
|
||||
|
||||
return parallelism_config
|
||||
|
||||
def _build_torch_device_mesh(self, parallelism_config):
|
||||
if PartialState._shared_state != {} and getattr(PartialState(), "device_mesh", None) is not None:
|
||||
device_mesh = PartialState().device_mesh
|
||||
else:
|
||||
device_mesh = parallelism_config.build_device_mesh(self.device.type)
|
||||
self.state.device_mesh = device_mesh
|
||||
PartialState().device_mesh = device_mesh
|
||||
|
||||
@contextmanager
|
||||
def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
|
||||
"""
|
||||
@ -1235,15 +1297,7 @@ class Accelerator:
|
||||
... optimizer.zero_grad()
|
||||
```
|
||||
"""
|
||||
if self.distributed_type in (
|
||||
DistributedType.MULTI_GPU,
|
||||
DistributedType.MULTI_NPU,
|
||||
DistributedType.MULTI_MLU,
|
||||
DistributedType.MULTI_SDAA,
|
||||
DistributedType.MULTI_MUSA,
|
||||
DistributedType.MULTI_XPU,
|
||||
DistributedType.MULTI_HPU,
|
||||
):
|
||||
if self.multi_device:
|
||||
dl_even_batches_values = []
|
||||
|
||||
if even_batches is not None:
|
||||
@ -1440,6 +1494,9 @@ class Accelerator:
|
||||
"You are using lower version of PyTorch(< 2.7.0) with ipex acceleration on Intel CPU or XPU, Intel has upstreamed most of the optimizations into stock PyTorch from 2.7.0, we enourage you to install the latest stock PyTorch and enjoy the out-of-experience on Intel CPU/XPU."
|
||||
)
|
||||
args = self._prepare_ipex(*args)
|
||||
if self.parallelism_config and self.parallelism_config.tp_enabled:
|
||||
args = self._prepare_tp(*args)
|
||||
|
||||
if self.fp8_backend == FP8BackendType.TE:
|
||||
args = self._prepare_te(*args)
|
||||
elif self.fp8_backend == FP8BackendType.AO:
|
||||
@ -1476,6 +1533,34 @@ class Accelerator:
|
||||
|
||||
return result if len(result) > 1 else result[0]
|
||||
|
||||
def _prepare_tp(self, *args):
|
||||
device_mesh = self.torch_device_mesh
|
||||
|
||||
for arg in args:
|
||||
if not isinstance(arg, torch.nn.Module):
|
||||
continue
|
||||
|
||||
from torch.distributed.tensor import DTensor, Replicate
|
||||
from transformers.integrations.tensor_parallel import ReplicateParallel
|
||||
|
||||
model: torch.nn.Module = arg
|
||||
tp_plan = ReplicateParallel
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if isinstance(param, DTensor):
|
||||
continue
|
||||
|
||||
dp = DTensor.from_local(param, device_mesh=device_mesh["tp"], placements=[Replicate()])
|
||||
param_name, param_type = name.rsplit(".", 1)
|
||||
module_to_tp = model.get_submodule(param_name)
|
||||
|
||||
tp_plan().prepare_module_tp(module_to_tp, device_mesh["tp"])
|
||||
if not isinstance(dp, torch.nn.Parameter):
|
||||
dp = torch.nn.Parameter(dp, requires_grad=param.requires_grad)
|
||||
setattr(module_to_tp, param_type, dp)
|
||||
|
||||
return args
|
||||
|
||||
def _prepare_fsdp2(self, *args):
|
||||
# First pass: prepare everything except schedulers (and model, which is prepared separately below)
|
||||
result = [
|
||||
@ -1514,14 +1599,18 @@ class Accelerator:
|
||||
old_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*tuple(result), drop_refs=True))
|
||||
|
||||
# Swap the optimizer parameters with empty, so `fully_shard` after will not allocate too much memory
|
||||
from torch.distributed.tensor import DTensor
|
||||
|
||||
for obj in result:
|
||||
if isinstance(obj, torch.optim.Optimizer):
|
||||
for param_group in obj.param_groups:
|
||||
for i, p in enumerate(param_group["params"]):
|
||||
# We drop a reference to the original param here, so that _move_states_to_device triggers a reallocation
|
||||
# We reassign the data_ptr to the original param, so that we preserve the mapping to the new ones
|
||||
param_group["params"][i] = torch.empty_like(p)
|
||||
param_group["params"][i].data_ptr = p.data_ptr()
|
||||
param_group["params"][i] = torch.empty(1, dtype=p.dtype, device=p.device)
|
||||
param_group["params"][i].data_ptr = (
|
||||
p._local_tensor.data_ptr() if isinstance(p, DTensor) else p.data_ptr()
|
||||
)
|
||||
|
||||
self._models.append(model)
|
||||
|
||||
@ -1645,15 +1734,7 @@ class Accelerator:
|
||||
elif device_placement and not self.verify_device_map(model):
|
||||
model = model.to(self.device)
|
||||
if not evaluation_mode:
|
||||
if self.distributed_type in (
|
||||
DistributedType.MULTI_GPU,
|
||||
DistributedType.MULTI_MLU,
|
||||
DistributedType.MULTI_SDAA,
|
||||
DistributedType.MULTI_MUSA,
|
||||
DistributedType.MULTI_NPU,
|
||||
DistributedType.MULTI_XPU,
|
||||
DistributedType.MULTI_HPU,
|
||||
):
|
||||
if self.multi_device and not (self.parallelism_config and self.parallelism_config.tp_enabled):
|
||||
if model_has_dtensor(model):
|
||||
raise ValueError(
|
||||
"Your model contains `DTensor` parameters, which is incompatible with DDP. Maybe you loaded your model with `device_map='auto'`? Specify `device_map='cuda'` or 'cpu' instead."
|
||||
@ -1668,23 +1749,20 @@ class Accelerator:
|
||||
device_ids, output_device = [self.local_process_index], self.local_process_index
|
||||
else:
|
||||
device_ids, output_device = None, None
|
||||
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=device_ids, output_device=output_device, **kwargs
|
||||
)
|
||||
if self.ddp_handler is not None:
|
||||
self.ddp_handler.register_comm_hook(model)
|
||||
elif self.distributed_type == DistributedType.TP:
|
||||
if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
|
||||
raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
|
||||
elif self.parallelism_config and self.parallelism_config.tp_enabled:
|
||||
if not hasattr(model, "tp_size"):
|
||||
raise NotImplementedError(
|
||||
"Model should undergo tensor parallel before passing it to accelerate."
|
||||
"You can use .from_pretrained(..., tp_plan='auto') if the model supports"
|
||||
)
|
||||
if model.tp_size != self.state.torch_tp_plugin.tp_size:
|
||||
if model.tp_size != self.parallelism_config.tp_size:
|
||||
raise ValueError(
|
||||
f"tp_size in the plugin {self.state.torch_tp_plugin.tp_size} should be same as model's tp size {model.tp_size}"
|
||||
f"tp_size in the plugin {self.parallelism_config.tp_size} should be same as model's tp size {model.tp_size}"
|
||||
)
|
||||
elif self.is_fsdp2:
|
||||
raise ValueError(
|
||||
@ -1820,7 +1898,7 @@ class Accelerator:
|
||||
del self._models[-2]
|
||||
self._models[-1] = model
|
||||
elif self.distributed_type == DistributedType.MULTI_CPU:
|
||||
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
|
||||
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler else {}
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
|
||||
if self.ddp_handler is not None:
|
||||
self.ddp_handler.register_comm_hook(model)
|
||||
@ -2347,11 +2425,10 @@ class Accelerator:
|
||||
Prepare the device mesh for distributed training. The dataloader will determine how to load data based on the
|
||||
device mesh.
|
||||
"""
|
||||
if self.state.torch_tp_plugin:
|
||||
return self.state.torch_tp_plugin.torch_device_mesh
|
||||
elif self.distributed_type == DistributedType.DEEPSPEED and hasattr(self.state, "ds_device_mesh"):
|
||||
if self.distributed_type == DistributedType.DEEPSPEED and hasattr(self.state, "ds_device_mesh"):
|
||||
return self.state.ds_device_mesh
|
||||
return None
|
||||
else:
|
||||
return self.torch_device_mesh
|
||||
|
||||
def _prepare_msamp(self, *args, device_placement):
|
||||
if not is_msamp_available():
|
||||
@ -3587,14 +3664,7 @@ class Accelerator:
|
||||
|
||||
map_location = load_model_func_kwargs.pop("map_location", None)
|
||||
if map_location is None:
|
||||
if self.num_processes > 1 and self.distributed_type in (
|
||||
DistributedType.MULTI_GPU,
|
||||
DistributedType.MULTI_MLU,
|
||||
DistributedType.MULTI_SDAA,
|
||||
DistributedType.MULTI_MUSA,
|
||||
DistributedType.MULTI_NPU,
|
||||
DistributedType.MULTI_HPU,
|
||||
):
|
||||
if self.num_processes > 1 and self.multi_device and self.distributed_type != DistributedType.MULTI_XPU:
|
||||
map_location = "on_device"
|
||||
else:
|
||||
map_location = "cpu"
|
||||
@ -3693,6 +3763,11 @@ class Accelerator:
|
||||
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
|
||||
|
||||
accessor_mapping[WeightWithDynamicFloat8CastTensor] = "_tensor"
|
||||
# we know we're in FSDP2 so DTensor is available
|
||||
if self.is_fsdp2:
|
||||
from torch.distributed.tensor import DTensor
|
||||
|
||||
accessor_mapping[DTensor] = "_local_tensor"
|
||||
|
||||
named_parameters.update(
|
||||
{
|
||||
|
@ -1128,16 +1128,20 @@ def prepare_data_loader(
|
||||
# ranks would range from 0...11
|
||||
# from data angle ranks should look like 0 0 0 1 1 1 2 2 2 3 3 3
|
||||
# processes with same ranks/ids would receive the same batch
|
||||
# for CP the same as TP applies
|
||||
submesh_fsdp_size = 1
|
||||
submesh_dp_size = 1
|
||||
submesh_tp_size = 1
|
||||
submesh_cp_size = 1
|
||||
if "tp" in torch_device_mesh.mesh_dim_names:
|
||||
submesh_tp_size = torch_device_mesh["tp"].size()
|
||||
if "dp" in torch_device_mesh.mesh_dim_names:
|
||||
submesh_dp_size = torch_device_mesh["dp"].size()
|
||||
if "fsdp" in torch_device_mesh.mesh_dim_names:
|
||||
submesh_fsdp_size = torch_device_mesh["fsdp"].size()
|
||||
process_index = process_index // submesh_tp_size
|
||||
if "cp" in torch_device_mesh.mesh_dim_names:
|
||||
submesh_cp_size = torch_device_mesh["cp"].size()
|
||||
if "dp_replicate" in torch_device_mesh.mesh_dim_names:
|
||||
submesh_dp_size = torch_device_mesh["dp_replicate"].size()
|
||||
if "dp_shard" in torch_device_mesh.mesh_dim_names:
|
||||
submesh_fsdp_size = torch_device_mesh["dp_shard"].size()
|
||||
process_index = process_index // (submesh_tp_size * submesh_cp_size)
|
||||
num_processes = submesh_fsdp_size * submesh_dp_size
|
||||
|
||||
# Sanity check
|
||||
|
268
src/accelerate/parallelism_config.py
Normal file
268
src/accelerate/parallelism_config.py
Normal file
@ -0,0 +1,268 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
|
||||
from accelerate.utils.dataclasses import TorchTensorParallelConfig
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from accelerate import Accelerator
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParallelismConfig:
|
||||
"""
|
||||
A dataclass to configure parallelisms applied to the model. Inspired by torchtitan's `ParallelDims`
|
||||
https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/parallel_dims.py
|
||||
|
||||
Args:
|
||||
dp_replicate_size (`int`, defaults to `1`):
|
||||
The size of the data parallel group. If `dp_replicate_size` is set to 1, the data parallel replication
|
||||
group will not be used.
|
||||
dp_shard_size (`int`, defaults to `1`):
|
||||
The size of the model shard group. If `dp_replicate_size > 1` and `tp_size > 1`, `dp_shard_size` must also
|
||||
be greater than 1, as composing DDP + TP is currently not supported.
|
||||
tp_size (`int`, defaults to `1`):
|
||||
The size of the tensor parallel group. If `tp_size` is set to `1`, the tensor parallel group will not be
|
||||
used.
|
||||
cp_size (`int`, defaults to `1`):
|
||||
The size of the context parallel group. Currently not supported, but reserved for future use and enabled
|
||||
for downstream libraries.
|
||||
tp_handler (`~utils.TorchTensorParallelConfig`, defaults to `None`):
|
||||
The handler for the tensor parallel group.
|
||||
|
||||
You may obtain different distributed data parallel paradigms by configuring `dp_replicate_size` and `dp_shard_size`
|
||||
together:
|
||||
- `dp_replicate_size == 1` and `dp_shard_size > 1`, we obtain Fully Sharded Data Parallel (FSDP).
|
||||
- `dp_replicate_size > 1` and `dp_shard_size > 1`, we obtain Hybrid Sharded Data Parallel (HSDP).
|
||||
- `dp_replicate_size > 1` and `dp_shard_size == 1` is an invalid configuration, to use pure DP, use
|
||||
`DistributedDataParallelKwargs` instead.
|
||||
|
||||
"""
|
||||
|
||||
dp_replicate_size: int = 1
|
||||
dp_shard_size: int = 1
|
||||
tp_size: int = 1
|
||||
cp_size: int = 1
|
||||
|
||||
# we use Union because we might support other x parallel plugins (i.e. deepspeed, etc)
|
||||
tp_handler: Union[None, TorchTensorParallelConfig] = None
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"ParallelismConfig(\n "
|
||||
f"\tdp_replicate_size={self.dp_replicate_size},\n"
|
||||
f"\tdp_shard_size={self.dp_shard_size},\n"
|
||||
f"\ttp_size={self.tp_size},\n"
|
||||
f"\tcp_size={self.cp_size},\n"
|
||||
f"\ttotal_size={self.total_size}\n)"
|
||||
)
|
||||
|
||||
@property
|
||||
def dp_dim_names(self):
|
||||
"""Names of enabled dimensions across which data parallelism is applied."""
|
||||
dims = []
|
||||
if self.dp_replicate_enabled:
|
||||
dims += ["dp_replicate"]
|
||||
if self.dp_shard_enabled:
|
||||
dims += ["dp_shard"]
|
||||
return dims
|
||||
|
||||
@property
|
||||
def non_dp_dim_names(self):
|
||||
"""Names of enabled dimensions which will receive the same batch (non-data parallel dimensions)."""
|
||||
dims = []
|
||||
if self.tp_enabled:
|
||||
dims += ["tp"]
|
||||
if self.cp_enabled:
|
||||
dims += ["cp"]
|
||||
return dims
|
||||
|
||||
@property
|
||||
def dp_shard_cp_dim_names(self):
|
||||
"""Names of enabled dimensions which will be flattened into a joint mesh across which is model sharded in FSDP."""
|
||||
dims = []
|
||||
if self.dp_shard_enabled:
|
||||
dims += ["dp_shard"]
|
||||
if self.cp_enabled:
|
||||
dims += ["cp"]
|
||||
return dims
|
||||
|
||||
@property
|
||||
def dp_cp_dim_names(self):
|
||||
"""Names of enabled dimensions across which loss should be averaged"""
|
||||
dims = []
|
||||
if self.dp_replicate_enabled:
|
||||
dims += ["dp_replicate"]
|
||||
if self.dp_shard_enabled:
|
||||
dims += ["dp_shard"]
|
||||
if self.cp_enabled:
|
||||
dims += ["cp"]
|
||||
return dims
|
||||
|
||||
@property
|
||||
def fsdp_dim_names(self):
|
||||
"""Names of enabled dimensions across which FSDP is applied, including data parallel replication."""
|
||||
dims = []
|
||||
if self.dp_replicate_enabled:
|
||||
dims += ["dp_replicate"]
|
||||
dims += ["dp_shard_cp"]
|
||||
return dims
|
||||
|
||||
@property
|
||||
def total_size(self):
|
||||
"""The total size of the parallelism configuration, which is the product of all sizes."""
|
||||
return self.dp_replicate_size * self.dp_shard_size * self.tp_size * self.cp_size
|
||||
|
||||
@property
|
||||
def non_data_parallel_size(self):
|
||||
"""The size of the non-data parallel dimensions, which is the product of tensor and context parallel sizes."""
|
||||
return self.tp_size * self.cp_size
|
||||
|
||||
@property
|
||||
def data_parallel_size(self):
|
||||
"""The size of the data parallel dimensions, which is the product of data parallel replication and"""
|
||||
return self.dp_replicate_size * self.dp_shard_size
|
||||
|
||||
@property
|
||||
def dp_replicate_enabled(self):
|
||||
"""True if data parallel replication is enabled, i.e. `dp_replicate_size > 1`."""
|
||||
return self.dp_replicate_size > 1
|
||||
|
||||
@property
|
||||
def dp_shard_enabled(self):
|
||||
"""True if data parallel sharding is enabled, i.e. `dp_shard_size > 1`."""
|
||||
return self.dp_shard_size > 1
|
||||
|
||||
@property
|
||||
def tp_enabled(self):
|
||||
"""True if tensor parallelism is enabled, i.e. `tp_size > 1`."""
|
||||
return self.tp_size > 1
|
||||
|
||||
@property
|
||||
def cp_enabled(self):
|
||||
"""True if context parallelism is enabled, i.e. `cp_size > 1`."""
|
||||
return self.cp_size > 1
|
||||
|
||||
@property
|
||||
def active_mesh_dims(self):
|
||||
"""Names of all active mesh dimensions."""
|
||||
return self.dp_dim_names + self.non_dp_dim_names
|
||||
|
||||
def build_device_mesh(self, device_type: str):
|
||||
"""Builds a device mesh for the given device type based on the parallelism configuration.
|
||||
This method will also create required joint meshes (e.g. `dp_shard_cp`, `dp_cp`, `dp`).
|
||||
|
||||
Args:
|
||||
device_type (`str`): The type of device for which to build the mesh, e
|
||||
"""
|
||||
mesh = self._get_mesh()
|
||||
if len(mesh) == 0:
|
||||
return
|
||||
mesh_dim_names, mesh_shape = mesh
|
||||
device_mesh = init_device_mesh(
|
||||
device_type,
|
||||
mesh_shape,
|
||||
mesh_dim_names=mesh_dim_names,
|
||||
)
|
||||
if self.dp_dim_names:
|
||||
device_mesh[self.dp_dim_names]._flatten("dp")
|
||||
if self.dp_shard_cp_dim_names:
|
||||
device_mesh[self.dp_shard_cp_dim_names]._flatten("dp_shard_cp")
|
||||
if self.dp_cp_dim_names:
|
||||
device_mesh[self.dp_cp_dim_names]._flatten("dp_cp")
|
||||
|
||||
return device_mesh
|
||||
|
||||
def _get_mesh(self) -> tuple[tuple[int, ...], tuple[str, ...]]:
|
||||
"""Generate mesh shape and dimension names for torch.distributed.init_device_mesh()."""
|
||||
|
||||
# Build mesh dimensions dictionary
|
||||
mesh_dims = {parallelism: self._sizes[parallelism] for parallelism in self.active_mesh_dims}
|
||||
|
||||
# Apply canonical ordering
|
||||
mesh_order = ["dp_replicate", "dp_shard", "cp", "tp"]
|
||||
sorted_items = sorted(
|
||||
mesh_dims.items(),
|
||||
key=lambda x: (mesh_order.index(x[0])),
|
||||
)
|
||||
return tuple(zip(*sorted_items))
|
||||
|
||||
def __post_init__(self):
|
||||
# Basic size validation
|
||||
if self.dp_replicate_size < 1:
|
||||
raise ValueError(f"dp_replicate_size must be at least 1, but got {self.dp_replicate_size}")
|
||||
if self.dp_shard_size < 1:
|
||||
raise ValueError(f"dp_shard_size must be at least 1, but got {self.dp_shard_size}")
|
||||
if self.tp_size < 1:
|
||||
raise ValueError(f"tp_size must be at least 1, but got {self.tp_size}")
|
||||
if self.cp_size < 1:
|
||||
raise ValueError(f"cp_size must be at least 1, but got {self.cp_size}")
|
||||
|
||||
if (self.tp_size > 1 or self.cp_size > 1) and self.dp_replicate_size > 1 and self.dp_shard_size == 1:
|
||||
raise ValueError(
|
||||
"Tensor/Context parallelism (tp/cp_size > 1) cannot be used with pure data parallelism (dp_replicate_size > 1 and dp_shard_size == 1). "
|
||||
"Please set dp_shard_size > 1 and dp_replicate_size == 1 to compose FSDP + TP/CP for 2D parallel, "
|
||||
"or set dp_replicate_size == 1 and dp_shard_size > 1 to compose HSDP + TP/CP for 3D parallel."
|
||||
)
|
||||
self._sizes = {
|
||||
"dp_replicate": self.dp_replicate_size,
|
||||
"dp_shard": self.dp_shard_size,
|
||||
"tp": self.tp_size,
|
||||
"cp": self.cp_size,
|
||||
}
|
||||
|
||||
def _set_size(self, parallelism: str, size: int):
|
||||
assert parallelism in self._sizes.keys(), f"Parallelism must be one of {self._sizes.keys()}"
|
||||
self._sizes[parallelism] = size
|
||||
setattr(self, f"{parallelism}_size", size)
|
||||
|
||||
def _validate_accelerator(self, accelerator: "Accelerator"):
|
||||
_warnings = set()
|
||||
if not accelerator.multi_device and self.total_size == 1:
|
||||
# No distributed setup, valid parallelism config
|
||||
return
|
||||
|
||||
# We need this to ensure DDP works
|
||||
if self.total_size == 1:
|
||||
self._set_size("dp_replicate", accelerator.num_processes)
|
||||
|
||||
if self.total_size != accelerator.num_processes:
|
||||
raise ValueError(
|
||||
f"ParallelismConfig total_size ({self.total_size}) does not match "
|
||||
f"num_processes ({accelerator.num_processes}). Please adjust dp_replicate_size/ "
|
||||
f"dp_shard_size/tp_size/cp_size."
|
||||
)
|
||||
|
||||
if self.total_size > 1 and not (accelerator.is_fsdp2 or accelerator.multi_device):
|
||||
raise ValueError(
|
||||
f"ParallelismConfig is only compatible DistributedType.FSDP (version 2) or DistributedType.Multi{{Device}}, but got {accelerator.distributed_type}."
|
||||
)
|
||||
|
||||
for parallelism, size in self._sizes.items():
|
||||
if size == 1 and getattr(self, f"{parallelism}_handler", None) is not None:
|
||||
_warnings.add(
|
||||
f"ParallelismConfig.{parallelism}_handler is set, but {parallelism}_size is set to 1. This handler will be ignored."
|
||||
)
|
||||
|
||||
if _warnings and accelerator.is_main_process:
|
||||
warnings.warn(
|
||||
"ParallelismConfig has the following warnings:\n" + "\n".join(_warnings),
|
||||
UserWarning,
|
||||
)
|
@ -180,6 +180,8 @@ class PartialState:
|
||||
if not self.initialized:
|
||||
self._cpu = cpu
|
||||
self.backend = None
|
||||
self.parallelism_config = kwargs.pop("parallelism_config", None)
|
||||
self.device_mesh = kwargs.pop("device_mesh", None)
|
||||
env_device = os.environ.get("ACCELERATE_TORCH_DEVICE", None)
|
||||
self.device = torch.device(env_device) if env_device is not None else None
|
||||
self.debug = parse_flag_from_env("ACCELERATE_DEBUG_MODE")
|
||||
@ -869,6 +871,8 @@ class AcceleratorState:
|
||||
- **device** (`torch.device`) -- The device to use.
|
||||
- **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently
|
||||
in use.
|
||||
- **parallelism_config** ([`~accelerate.utils.ParallelismConfig`]) -- The parallelism configuration for the
|
||||
current training environment. This is used to configure the distributed training environment.
|
||||
- **initialized** (`bool`) -- Whether or not the `AcceleratorState` has been initialized from `Accelerator`.
|
||||
- **local_process_index** (`int`) -- The index of the current process on the current server.
|
||||
- **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type
|
||||
@ -988,8 +992,6 @@ class AcceleratorState:
|
||||
self.distributed_type = DistributedType.MEGATRON_LM
|
||||
megatron_lm_plugin.set_mixed_precision(self._mixed_precision)
|
||||
self.megatron_lm_plugin = megatron_lm_plugin
|
||||
if self.torch_tp_plugin is not None:
|
||||
self.distributed_type = DistributedType.TP
|
||||
elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
|
||||
if is_ipex_available():
|
||||
# check if user disables it explicitly
|
||||
|
@ -25,7 +25,8 @@ from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup
|
||||
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from accelerate.utils import SAFE_WEIGHTS_NAME, TorchTensorParallelPlugin, set_seed
|
||||
from accelerate.parallelism_config import ParallelismConfig
|
||||
from accelerate.utils import SAFE_WEIGHTS_NAME, set_seed
|
||||
from accelerate.utils.deepspeed import DummyOptim, DummyScheduler
|
||||
|
||||
|
||||
@ -83,7 +84,7 @@ def training_function(config, args):
|
||||
accelerator_kwargs = {}
|
||||
# need this for DeepSpeed tests as `args.tp_size` would be None and `torch.distributed.init_device_mesh` would fail
|
||||
if args.tp_size is not None:
|
||||
accelerator_kwargs["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=args.tp_size)
|
||||
accelerator_kwargs["parallelism_config"] = ParallelismConfig(tp_size=args.tp_size)
|
||||
|
||||
# Initialize accelerator
|
||||
accelerator = Accelerator(**accelerator_kwargs)
|
||||
|
@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ..parallelism_config import ParallelismConfig
|
||||
from .ao import convert_model_to_fp8_ao, filter_first_and_last_linear_layers, has_ao_layers
|
||||
from .constants import (
|
||||
MITA_PROFILING_AVAILABLE_PYTORCH_VERSION,
|
||||
@ -61,6 +62,7 @@ from .dataclasses import (
|
||||
TensorInformation,
|
||||
TERecipeKwargs,
|
||||
TorchDynamoPlugin,
|
||||
TorchTensorParallelConfig,
|
||||
TorchTensorParallelPlugin,
|
||||
add_model_config_to_megatron_parser,
|
||||
)
|
||||
|
@ -33,6 +33,7 @@ import torch
|
||||
|
||||
from .constants import (
|
||||
BETA_TP_AVAILABLE_PYTORCH_VERSION,
|
||||
BETA_TP_AVAILABLE_TRANSFORMERS_VERSION,
|
||||
FSDP2_PYTORCH_VERSION,
|
||||
FSDP_AUTO_WRAP_POLICY,
|
||||
FSDP_BACKWARD_PREFETCH,
|
||||
@ -58,6 +59,7 @@ if TYPE_CHECKING:
|
||||
# Mock imports for type checking
|
||||
from torchao.float8 import Float8LinearConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -184,7 +186,9 @@ class DistributedDataParallelKwargs(KwargsHandler):
|
||||
|
||||
comm_hook: DDPCommunicationHookType = DDPCommunicationHookType.NO
|
||||
comm_wrapper: Literal[
|
||||
DDPCommunicationHookType.NO, DDPCommunicationHookType.FP16, DDPCommunicationHookType.BF16
|
||||
DDPCommunicationHookType.NO,
|
||||
DDPCommunicationHookType.FP16,
|
||||
DDPCommunicationHookType.BF16,
|
||||
] = DDPCommunicationHookType.NO
|
||||
comm_state_option: dict = field(default_factory=dict)
|
||||
|
||||
@ -192,7 +196,10 @@ class DistributedDataParallelKwargs(KwargsHandler):
|
||||
return {k: v for k, v in super().to_dict().items() if k not in ignore_keys}
|
||||
|
||||
def register_comm_hook(self, model):
|
||||
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks, powerSGD_hook
|
||||
from torch.distributed.algorithms.ddp_comm_hooks import (
|
||||
default_hooks,
|
||||
powerSGD_hook,
|
||||
)
|
||||
|
||||
hook_map: dict[DDPCommunicationHookType, Callable] = {
|
||||
DDPCommunicationHookType.FP16: default_hooks.fp16_compress_hook,
|
||||
@ -215,7 +222,11 @@ class DistributedDataParallelKwargs(KwargsHandler):
|
||||
if hook:
|
||||
state = (
|
||||
powerSGD_hook.PowerSGDState(None, **self.comm_state_option)
|
||||
if self.comm_hook in (DDPCommunicationHookType.POWER_SGD, DDPCommunicationHookType.BATCHED_POWER_SGD)
|
||||
if self.comm_hook
|
||||
in (
|
||||
DDPCommunicationHookType.POWER_SGD,
|
||||
DDPCommunicationHookType.BATCHED_POWER_SGD,
|
||||
)
|
||||
else None
|
||||
)
|
||||
model.register_comm_hook(
|
||||
@ -582,7 +593,6 @@ class DistributedType(str, enum.Enum):
|
||||
MULTI_XPU = "MULTI_XPU"
|
||||
DEEPSPEED = "DEEPSPEED"
|
||||
FSDP = "FSDP"
|
||||
TP = "TP"
|
||||
XLA = "XLA"
|
||||
MEGATRON_LM = "MEGATRON_LM"
|
||||
MULTI_HPU = "MULTI_HPU"
|
||||
@ -955,7 +965,10 @@ class GradientAccumulationPlugin(KwargsHandler):
|
||||
```
|
||||
"""
|
||||
|
||||
num_steps: int = field(default=None, metadata={"help": "The number of steps to accumulate gradients for."})
|
||||
num_steps: int = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of steps to accumulate gradients for."},
|
||||
)
|
||||
adjust_scheduler: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
@ -1006,12 +1019,22 @@ class TorchDynamoPlugin(KwargsHandler):
|
||||
metadata={"help": f"Possible options are {[b.value.lower() for b in DynamoBackend]}"},
|
||||
)
|
||||
mode: str = field(
|
||||
default=None, metadata={"help": "Possible options are 'default', 'reduce-overhead' or 'max-autotune'"}
|
||||
default=None,
|
||||
metadata={"help": "Possible options are 'default', 'reduce-overhead' or 'max-autotune'"},
|
||||
)
|
||||
fullgraph: bool = field(
|
||||
default=None,
|
||||
metadata={"help": "Whether it is ok to break model into several subgraphs"},
|
||||
)
|
||||
fullgraph: bool = field(default=None, metadata={"help": "Whether it is ok to break model into several subgraphs"})
|
||||
dynamic: bool = field(default=None, metadata={"help": "Whether to use dynamic shape for tracing"})
|
||||
options: Any = field(default=None, metadata={"help": "A dictionary of options to pass to the backend."})
|
||||
disable: bool = field(default=False, metadata={"help": "Turn torch.compile() into a no-op for testing"})
|
||||
options: Any = field(
|
||||
default=None,
|
||||
metadata={"help": "A dictionary of options to pass to the backend."},
|
||||
)
|
||||
disable: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Turn torch.compile() into a no-op for testing"},
|
||||
)
|
||||
|
||||
use_regional_compilation: bool = field(
|
||||
default=None,
|
||||
@ -1243,13 +1266,13 @@ class DeepSpeedPlugin:
|
||||
"stage": self.zero_stage,
|
||||
"offload_optimizer": {
|
||||
"device": self.offload_optimizer_device,
|
||||
"nvme_path": self.offload_optimizer_nvme_path
|
||||
if self.offload_optimizer_device == "nvme"
|
||||
else None,
|
||||
"nvme_path": (
|
||||
self.offload_optimizer_nvme_path if self.offload_optimizer_device == "nvme" else None
|
||||
),
|
||||
},
|
||||
"offload_param": {
|
||||
"device": self.offload_param_device,
|
||||
"nvme_path": self.offload_param_nvme_path if self.offload_param_device == "nvme" else None,
|
||||
"nvme_path": (self.offload_param_nvme_path if self.offload_param_device == "nvme" else None),
|
||||
},
|
||||
"stage3_gather_16bit_weights_on_model_save": self.zero3_save_16bit_model,
|
||||
},
|
||||
@ -1262,7 +1285,13 @@ class DeepSpeedPlugin:
|
||||
self.deepspeed_config["steps_per_print"] = float("inf") # this will stop deepspeed from logging @ stdout
|
||||
if self.zero3_init_flag is None:
|
||||
self.zero3_init_flag = (
|
||||
str_to_bool(os.environ.get("ACCELERATE_DEEPSPEED_ZERO3_INIT", str(self.hf_ds_config.is_zero3()))) == 1
|
||||
str_to_bool(
|
||||
os.environ.get(
|
||||
"ACCELERATE_DEEPSPEED_ZERO3_INIT",
|
||||
str(self.hf_ds_config.is_zero3()),
|
||||
)
|
||||
)
|
||||
== 1
|
||||
)
|
||||
if self.zero3_init_flag and not self.hf_ds_config.is_zero3():
|
||||
warnings.warn("DeepSpeed Zero3 Init flag is only applicable for ZeRO Stage 3. Setting it to False.")
|
||||
@ -1279,7 +1308,10 @@ class DeepSpeedPlugin:
|
||||
)
|
||||
if self.msamp_opt_level not in ["O1", "O2"]:
|
||||
raise ValueError("Invalid optimization level for MS-AMP. Please use one of ['O1' or'O2'].")
|
||||
self.deepspeed_config["msamp"] = {"enabled": True, "opt_level": self.msamp_opt_level}
|
||||
self.deepspeed_config["msamp"] = {
|
||||
"enabled": True,
|
||||
"opt_level": self.msamp_opt_level,
|
||||
}
|
||||
|
||||
def fill_match(self, ds_key_long, mismatches=None, must_match=True, **kwargs):
|
||||
mismatches = [] if mismatches is None else mismatches
|
||||
@ -1324,7 +1356,11 @@ class DeepSpeedPlugin:
|
||||
for key, value in config.items():
|
||||
if isinstance(value, dict):
|
||||
self.deepspeed_config_process(
|
||||
prefix=prefix + key + ".", mismatches=mismatches, config=value, must_match=must_match, **kwargs
|
||||
prefix=prefix + key + ".",
|
||||
mismatches=mismatches,
|
||||
config=value,
|
||||
must_match=must_match,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
self.fill_match(prefix + key, mismatches, must_match=must_match, **kwargs)
|
||||
@ -1351,7 +1387,10 @@ class DeepSpeedPlugin:
|
||||
|
||||
if mixed_precision == "fp8" and self.enable_msamp:
|
||||
if "msamp" not in ds_config:
|
||||
ds_config["msamp"] = {"enabled": True, "opt_level": self.msamp_opt_level}
|
||||
ds_config["msamp"] = {
|
||||
"enabled": True,
|
||||
"opt_level": self.msamp_opt_level,
|
||||
}
|
||||
|
||||
if mixed_precision != "no":
|
||||
diff_dtype = "bf16" if mixed_precision == "fp16" else "fp16"
|
||||
@ -1383,9 +1422,15 @@ class DeepSpeedPlugin:
|
||||
del ds_config["train_batch_size"]
|
||||
|
||||
if compare_versions("transformers", "<", "4.46"):
|
||||
from transformers.deepspeed import HfDeepSpeedConfig, unset_hf_deepspeed_config
|
||||
from transformers.deepspeed import (
|
||||
HfDeepSpeedConfig,
|
||||
unset_hf_deepspeed_config,
|
||||
)
|
||||
else:
|
||||
from transformers.integrations import HfDeepSpeedConfig, unset_hf_deepspeed_config
|
||||
from transformers.integrations import (
|
||||
HfDeepSpeedConfig,
|
||||
unset_hf_deepspeed_config,
|
||||
)
|
||||
|
||||
unset_hf_deepspeed_config()
|
||||
self.dschf = HfDeepSpeedConfig(ds_config) # keep this object alive # noqa
|
||||
@ -1583,7 +1628,11 @@ class FullyShardedDataParallelPlugin:
|
||||
},
|
||||
)
|
||||
mixed_precision_policy: Optional[
|
||||
Union[dict, "torch.distributed.fsdp.MixedPrecision", "torch.distributed.fsdp.MixedPrecisionPolicy"]
|
||||
Union[
|
||||
dict,
|
||||
"torch.distributed.fsdp.MixedPrecision",
|
||||
"torch.distributed.fsdp.MixedPrecisionPolicy",
|
||||
]
|
||||
] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
@ -1601,7 +1650,11 @@ class FullyShardedDataParallelPlugin:
|
||||
},
|
||||
)
|
||||
)
|
||||
cpu_offload: Union[bool, "torch.distributed.fsdp.CPUOffload", "torch.distributed.fsdp.CPUOffloadPolicy"] = field(
|
||||
cpu_offload: Union[
|
||||
bool,
|
||||
"torch.distributed.fsdp.CPUOffload",
|
||||
"torch.distributed.fsdp.CPUOffloadPolicy",
|
||||
] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Whether to offload parameters to CPU. Should be either a `bool` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2. Defaults to `False`"
|
||||
@ -1628,7 +1681,10 @@ class FullyShardedDataParallelPlugin:
|
||||
metadata={"help": "State dict config to use. Is determined based on the `state_dict_type` if not passed in."},
|
||||
)
|
||||
optim_state_dict_config: Optional[
|
||||
Union["torch.distributed.fsdp.FullOptimStateDictConfig", "torch.distributed.fsdp.ShardedOptimStateDictConfig"]
|
||||
Union[
|
||||
"torch.distributed.fsdp.FullOptimStateDictConfig",
|
||||
"torch.distributed.fsdp.ShardedOptimStateDictConfig",
|
||||
]
|
||||
] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
@ -1701,10 +1757,7 @@ class FullyShardedDataParallelPlugin:
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
from torch.distributed.fsdp import (
|
||||
BackwardPrefetch,
|
||||
ShardingStrategy,
|
||||
)
|
||||
from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
|
||||
|
||||
_fsdp2_warnings = set()
|
||||
|
||||
@ -1738,7 +1791,8 @@ class FullyShardedDataParallelPlugin:
|
||||
# Fallback to `reshard_after_forward` in FSDP1 if `sharding_strategy` is not set
|
||||
if self.reshard_after_forward is None and self.sharding_strategy is None:
|
||||
reshard_after_forward = os.environ.get(
|
||||
env_prefix + "RESHARD_AFTER_FORWARD", "true" if self.fsdp_version == 2 else "FULL_SHARD"
|
||||
env_prefix + "RESHARD_AFTER_FORWARD",
|
||||
"true" if self.fsdp_version == 2 else "FULL_SHARD",
|
||||
)
|
||||
if self.fsdp_version == 2:
|
||||
self.reshard_after_forward = str_to_bool(reshard_after_forward.lower(), to_bool=True)
|
||||
@ -1795,7 +1849,10 @@ class FullyShardedDataParallelPlugin:
|
||||
raise ValueError(
|
||||
f"Invalid auto wrap policy: {self.auto_wrap_policy}. Must be one of {FSDP_AUTO_WRAP_POLICY}"
|
||||
)
|
||||
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
|
||||
from torch.distributed.fsdp.wrap import (
|
||||
size_based_auto_wrap_policy,
|
||||
transformer_auto_wrap_policy,
|
||||
)
|
||||
|
||||
if self.auto_wrap_policy.upper() == "TRANSFORMER_BASED_WRAP":
|
||||
self.auto_wrap_policy = transformer_auto_wrap_policy
|
||||
@ -1910,7 +1967,8 @@ class FullyShardedDataParallelPlugin:
|
||||
|
||||
if self.state_dict_type is None:
|
||||
self.state_dict_type = os.environ.get(
|
||||
"FSDP_STATE_DICT_TYPE", "FULL_STATE_DICT" if self.fsdp_version == 1 else "SHARDED_STATE_DICT"
|
||||
"FSDP_STATE_DICT_TYPE",
|
||||
"FULL_STATE_DICT" if self.fsdp_version == 1 else "SHARDED_STATE_DICT",
|
||||
)
|
||||
if isinstance(self.state_dict_type, str):
|
||||
if self.state_dict_type.isdigit():
|
||||
@ -1940,7 +1998,10 @@ class FullyShardedDataParallelPlugin:
|
||||
Given `model`, creates an `auto_wrap_policy` baesd on the passed in policy and if we can use the
|
||||
`transformer_cls_to_wrap`
|
||||
"""
|
||||
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
|
||||
from torch.distributed.fsdp.wrap import (
|
||||
size_based_auto_wrap_policy,
|
||||
transformer_auto_wrap_policy,
|
||||
)
|
||||
|
||||
# First base off of `_no_split_modules`
|
||||
no_split_modules = getattr(model, "_no_split_modules", None)
|
||||
@ -2077,35 +2138,30 @@ class TorchTensorParallelPlugin:
|
||||
metadata={"help": "tensor parallel size will be used in the device mesh preparation"},
|
||||
)
|
||||
|
||||
# torch_device_mesh is fo type "torch.distributed.DeviceMesh"
|
||||
# torch_device_mesh is of type "torch.distributed.DeviceMesh"
|
||||
torch_device_mesh: Optional["torch.distributed.DeviceMesh"] = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TorchTensorParallelConfig:
|
||||
"""
|
||||
Use this object in your [`Accelerator`] to customize your torch tensor parallelism.
|
||||
"""
|
||||
|
||||
enable_async_tp: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if not isinstance(self.tp_size, int):
|
||||
raise ValueError(f"`tp_size` set to {self.tp_size}, please set to an `int`.")
|
||||
|
||||
if self.tp_size <= 1:
|
||||
raise ValueError("`tp_size` must be greater than 1.")
|
||||
|
||||
if is_torch_version("<", BETA_TP_AVAILABLE_PYTORCH_VERSION):
|
||||
if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
|
||||
raise ValueError(
|
||||
f"Minimum PyTorch version {BETA_TP_AVAILABLE_PYTORCH_VERSION} needed to use tensor parallel."
|
||||
f"Torch tensor parallelism is only available in PyTorch {BETA_TP_AVAILABLE_PYTORCH_VERSION} and later versions. "
|
||||
"Please upgrade your PyTorch version."
|
||||
)
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
|
||||
# support for other devices has to be investigated
|
||||
if is_hpu_available(init_hccl=True):
|
||||
device = "hpu"
|
||||
elif is_xpu_available():
|
||||
device = "xpu"
|
||||
else:
|
||||
device = "cuda"
|
||||
if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
|
||||
raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
|
||||
|
||||
mesh_dim_name = "tp"
|
||||
|
||||
# device mesh is not used for model sharding
|
||||
# it is only used for preparing data loader
|
||||
self.torch_device_mesh = init_device_mesh(device, (self.tp_size,), mesh_dim_names=(mesh_dim_name,))
|
||||
if self.enable_async_tp:
|
||||
warnings.warn("Async tensor parallelism is currently not supported, ignoring this option.")
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -2209,7 +2265,8 @@ class MegatronLMPlugin:
|
||||
pp_degree: int = field(default=None, metadata={"help": "pipeline parallelism degree."})
|
||||
num_micro_batches: int = field(default=None, metadata={"help": "number of micro-batches."})
|
||||
gradient_clipping: float = field(
|
||||
default=None, metadata={"help": "gradient clipping value based on global L2 Norm (0 to disable)"}
|
||||
default=None,
|
||||
metadata={"help": "gradient clipping value based on global L2 Norm (0 to disable)"},
|
||||
)
|
||||
sequence_parallelism: bool = field(
|
||||
default=None,
|
||||
@ -2224,7 +2281,8 @@ class MegatronLMPlugin:
|
||||
metadata={"help": "enable distributed optimizer"},
|
||||
)
|
||||
pipeline_model_parallel_split_rank: int = field(
|
||||
default=None, metadata={"help": "Rank where encoder and decoder should be split."}
|
||||
default=None,
|
||||
metadata={"help": "Rank where encoder and decoder should be split."},
|
||||
)
|
||||
num_layers_per_virtual_pipeline_stage: int = field(
|
||||
default=None, metadata={"help": "Number of layers per virtual pipeline stage."}
|
||||
@ -2321,10 +2379,12 @@ class MegatronLMPlugin:
|
||||
metadata={"help": "Whether to set all logging options."},
|
||||
)
|
||||
eval_iters: int = field(
|
||||
default=100, metadata={"help": "Number of iterations to run for evaluation validation/test for."}
|
||||
default=100,
|
||||
metadata={"help": "Number of iterations to run for evaluation validation/test for."},
|
||||
)
|
||||
eval_interval: int = field(
|
||||
default=1000, metadata={"help": "Interval between running evaluation on validation set."}
|
||||
default=1000,
|
||||
metadata={"help": "Interval between running evaluation on validation set."},
|
||||
)
|
||||
return_logits: bool = field(
|
||||
default=False,
|
||||
@ -2691,7 +2751,8 @@ class BnbQuantizationConfig:
|
||||
load_in_8bit: bool = field(default=False, metadata={"help": "enable 8bit quantization."})
|
||||
|
||||
llm_int8_threshold: float = field(
|
||||
default=6.0, metadata={"help": "value of the outliner threshold. only relevant when load_in_8bit=True"}
|
||||
default=6.0,
|
||||
metadata={"help": "value of the outliner threshold. only relevant when load_in_8bit=True"},
|
||||
)
|
||||
|
||||
load_in_4bit: bool = field(default=False, metadata={"help": "enable 4bit quantization."})
|
||||
|
@ -548,6 +548,11 @@ def fsdp2_switch_optimizer_parameters(optimizer: torch.optim.Optimizer, mapping:
|
||||
indicates a bug. If we kept the original params instead of raising, the training wouldn't be numerically
|
||||
correct and weights wouldn't get updated.
|
||||
"""
|
||||
from torch.distributed.tensor import DTensor
|
||||
|
||||
accessor_mapping = {}
|
||||
|
||||
accessor_mapping[DTensor] = "_local_tensor"
|
||||
try:
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group["params"] = [mapping[p.data_ptr] for p in param_group["params"]]
|
||||
@ -615,12 +620,14 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
|
||||
fsdp2_plugin.set_auto_wrap_policy(model)
|
||||
|
||||
original_sd = model.state_dict()
|
||||
mesh = getattr(accelerator, "torch_device_mesh", None)
|
||||
|
||||
fsdp2_kwargs = {
|
||||
"reshard_after_forward": fsdp2_plugin.reshard_after_forward,
|
||||
"offload_policy": fsdp2_plugin.cpu_offload,
|
||||
# `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy`
|
||||
"mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
|
||||
"mesh": mesh[tuple(accelerator.parallelism_config.fsdp_dim_names)] if mesh is not None else None,
|
||||
}
|
||||
|
||||
model_has_params4bit = False
|
||||
|
@ -2097,7 +2097,6 @@ def get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwarg
|
||||
DistributedType.MULTI_HPU,
|
||||
DistributedType.FSDP,
|
||||
DistributedType.XLA,
|
||||
DistributedType.TP,
|
||||
]:
|
||||
return torch.autocast(device_type=device_type, dtype=torch.bfloat16, **autocast_kwargs)
|
||||
else:
|
||||
|
126
tests/test_dataclasses.py
Normal file
126
tests/test_dataclasses.py
Normal file
@ -0,0 +1,126 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from accelerate.parallelism_config import ParallelismConfig
|
||||
|
||||
|
||||
class TestParallelismConfig:
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_init_device_mesh(self):
|
||||
def mock_init_mesh(device_type, mesh_shape, mesh_dim_names):
|
||||
mesh = Mock()
|
||||
mesh.size.return_value = 1
|
||||
for dim in mesh_shape:
|
||||
mesh.size.return_value *= dim
|
||||
mesh.shape = mesh_shape
|
||||
mesh.mesh_dim_names = mesh_dim_names
|
||||
|
||||
# mock device_mesh._flatten
|
||||
mesh.flattened_dims = []
|
||||
|
||||
def mock_getitem(key):
|
||||
submesh = Mock()
|
||||
|
||||
def mock_flatten(name):
|
||||
mesh.flattened_dims.append((key, name))
|
||||
|
||||
submesh._flatten = Mock(side_effect=mock_flatten)
|
||||
return submesh
|
||||
|
||||
mesh.__getitem__ = Mock(side_effect=mock_getitem)
|
||||
|
||||
return mesh
|
||||
|
||||
with patch("accelerate.parallelism_config.init_device_mesh", side_effect=mock_init_mesh):
|
||||
yield mock_init_mesh
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dp_replicate_size, dp_shard_size, tp_size, cp_size, expected_shape, expected_dim_names",
|
||||
[
|
||||
(8, 1, 1, 1, (8,), ("dp_replicate",)), # DDP
|
||||
(1, 8, 1, 1, (8,), ("dp_shard",)), # FSDP
|
||||
(2, 4, 1, 1, (2, 4), ("dp_replicate", "dp_shard")), # HSDP
|
||||
(1, 4, 2, 1, (4, 2), ("dp_shard", "tp")), # FSDP + TP
|
||||
(2, 2, 2, 1, (2, 2, 2), ("dp_replicate", "dp_shard", "tp")), # HSDP + TP
|
||||
(1, 1, 8, 1, (8,), ("tp",)), # TP only
|
||||
(1, 1, 1, 4, (4,), ("cp",)), # CP only
|
||||
(1, 4, 1, 2, (4, 2), ("dp_shard", "cp")), # FSDP + CP
|
||||
(1, 2, 2, 2, (2, 2, 2), ("dp_shard", "cp", "tp")), # FSDP + CP + TP
|
||||
(2, 2, 2, 2, (2, 2, 2, 2), ("dp_replicate", "dp_shard", "cp", "tp")), # HSDP + CP + TP
|
||||
],
|
||||
)
|
||||
def test_get_mesh(
|
||||
self,
|
||||
dp_replicate_size,
|
||||
dp_shard_size,
|
||||
tp_size,
|
||||
cp_size,
|
||||
expected_shape,
|
||||
expected_dim_names,
|
||||
):
|
||||
config = ParallelismConfig(
|
||||
dp_replicate_size=dp_replicate_size, dp_shard_size=dp_shard_size, tp_size=tp_size, cp_size=cp_size
|
||||
)
|
||||
mesh_dim_names, mesh_shape = config._get_mesh()
|
||||
assert mesh_shape == expected_shape
|
||||
assert mesh_dim_names == expected_dim_names
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"dp_replicate_size, dp_shard_size, tp_size, cp_size, expected_shape, expected_dim_names",
|
||||
[
|
||||
(8, 1, 1, 1, (8,), ("dp_replicate",)),
|
||||
(1, 8, 1, 1, (8,), ("dp_shard",)),
|
||||
(2, 4, 1, 1, (2, 4), ("dp_replicate", "dp_shard")),
|
||||
(1, 4, 2, 1, (4, 2), ("dp_shard", "tp")),
|
||||
(2, 2, 2, 1, (2, 2, 2), ("dp_replicate", "dp_shard", "tp")),
|
||||
(1, 1, 8, 1, (8,), ("tp",)),
|
||||
(1, 1, 1, 4, (4,), ("cp",)),
|
||||
(1, 4, 1, 2, (4, 2), ("dp_shard", "cp")),
|
||||
(1, 2, 2, 2, (2, 2, 2), ("dp_shard", "cp", "tp")),
|
||||
(2, 2, 2, 2, (2, 2, 2, 2), ("dp_replicate", "dp_shard", "cp", "tp")),
|
||||
],
|
||||
)
|
||||
def test_build_device_mesh(
|
||||
self,
|
||||
dp_replicate_size,
|
||||
dp_shard_size,
|
||||
tp_size,
|
||||
cp_size,
|
||||
expected_shape,
|
||||
expected_dim_names,
|
||||
):
|
||||
"""Test build_device_mesh creates correct mesh and applies flattening."""
|
||||
config = ParallelismConfig(
|
||||
dp_replicate_size=dp_replicate_size, dp_shard_size=dp_shard_size, tp_size=tp_size, cp_size=cp_size
|
||||
)
|
||||
device_mesh = config.build_device_mesh("cpu")
|
||||
|
||||
# Check mesh shape and dimension names match expected
|
||||
assert device_mesh.shape == expected_shape
|
||||
assert device_mesh.mesh_dim_names == expected_dim_names
|
||||
|
||||
# Check that correct flattening operations were called
|
||||
expected_flattened = []
|
||||
if config.dp_dim_names:
|
||||
expected_flattened.append((config.dp_dim_names, "dp"))
|
||||
if config.dp_shard_cp_dim_names:
|
||||
expected_flattened.append((config.dp_shard_cp_dim_names, "dp_shard_cp"))
|
||||
if config.dp_cp_dim_names:
|
||||
expected_flattened.append((config.dp_cp_dim_names, "dp_cp"))
|
||||
|
||||
assert device_mesh.flattened_dims == expected_flattened
|
Reference in New Issue
Block a user