mirror of
https://github.com/huggingface/accelerate.git
synced 2025-11-17 08:01:15 +08:00
Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2943172b8f | |||
| f56f4441b3 | |||
| 45359a73ff | |||
| b5b68fbb4d | |||
| d190ed7e41 | |||
| b923e134e7 | |||
| b2956acbe9 | |||
| be0f7ce44f | |||
| 603a53f056 | |||
| 02e2ed567b | |||
| 8abd274a7f | |||
| b05d483944 | |||
| a74c7c9538 | |||
| a60640d7e2 | |||
| 611546f12d | |||
| 7d2a259e3d | |||
| e5c17f36a8 | |||
| 20de3fc959 | |||
| f84cb0c1fa | |||
| 136437e3e8 |
@ -7,6 +7,8 @@
|
||||
title: Installation
|
||||
title: Get started
|
||||
- sections:
|
||||
- local: big_modeling
|
||||
title: Handling big models
|
||||
- local: sagemaker
|
||||
title: Amazon SageMaker
|
||||
title: Guides
|
||||
|
||||
232
docs/source/big_modeling.mdx
Normal file
232
docs/source/big_modeling.mdx
Normal file
@ -0,0 +1,232 @@
|
||||
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Handling big models
|
||||
|
||||
When loading a pretrained model in PyTorch, the usual workflow looks like this:
|
||||
|
||||
```py
|
||||
import torch
|
||||
|
||||
my_model = ModelClass(...)
|
||||
state_dict = torch.load(checkpoint_file)
|
||||
my_model.load_state_dict(state_dict)
|
||||
```
|
||||
|
||||
In plain English, those steps are:
|
||||
1. Create the model with randomly initialized weights
|
||||
2. Load the model weights (in a dictionary usually called a state dict) from the disk
|
||||
3. Load those weights inside the model
|
||||
|
||||
While this works very well for regularly sized models, this workflow has some clear limitation when we deal with a huge model: in step 1, we load a full version of the model in RAM, and spend some time randomly initializing the weights (which will be discarded in step 3). In step 2, we load another full version of the model in RAM, with the pretrained weights. If you're loading a model with 6 billions parameters, this needs you will need 24GB of RAM for each copy of the model, so 48GB in total (half of it to load the model in FP16).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is quite new and still in its experimental stage. While we strive to provide a stable API, it's possible some small parts of the public API will change in the future.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Instantiating an empty model
|
||||
|
||||
The first tool Accelerate introduces to help with big models is a context manager [`init_empty_weights`] that helps you initialize a model without using any RAM, so that step 1 can be done on models of any size. Here is how it works:
|
||||
|
||||
```py
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
with init_empty_weights():
|
||||
my_model = ModelClass(...)
|
||||
```
|
||||
|
||||
For instance:
|
||||
|
||||
```py
|
||||
with init_empty_weights():
|
||||
model = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
|
||||
```
|
||||
|
||||
initializes an empty model with a bit more than 100B parameters. Behind the scenes, this relies on the meta device introduced in PyTorch 1.9. During the initialization under the context manager, each time a parameter is created, it is instantly moved on that device.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
You can't move a model initialized like this on CPU or another device directly, since it doesn't have any data. It's also very likely that a forward pass with that empty model will fail, as not all operations are supported on the meta device.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Sharded checkpoints
|
||||
|
||||
It's possible your model is so big that even a single copy won't fit in RAM. That doesn't mean it can't be loaded: if you have one or several GPUs, this is more memory available to store your model. In this case, it's better if your checkpoint is split in several smaller files that we call checkpoint shards.
|
||||
|
||||
Accelerate will handle sharded checkpoints as long as you follow the following format: your checkpoint should be in a folder, with several files containing the partial state dicts, and there should be an index in the JSON format that contains a dictionary mapping parameter names to the file containing their weights. For instance we could have a folder containing:
|
||||
|
||||
```bash
|
||||
first_state_dict.bin
|
||||
index.json
|
||||
second_state_dict.bin
|
||||
```
|
||||
|
||||
with index.json being the following file:
|
||||
|
||||
```
|
||||
{
|
||||
"linear1.weight": "first_state_dict.bin",
|
||||
"linear1.bias": "first_state_dict.bin",
|
||||
"linear2.weight": "second_state_dict.bin",
|
||||
"linear2.bias": "second_state_dict.bin"
|
||||
}
|
||||
```
|
||||
|
||||
and `first_state_dict.bin` containing the weights for `"linear1.weight"` and `"linear1.bias"`, `second_state_dict.bin` the ones for `"linear2.weight"` and `"linear2.bias"`
|
||||
|
||||
## Loading weights
|
||||
|
||||
The second tool Accelerate introduces is a function [`load_checkpoint_and_dispatch`], that will allow you to load a checkpoint inside your empty model. This supports full checkpoints (a single file containing the whole state dict) as well as sharded checkpoints. It will also automatically dispatch those weights across the devices you have available (GPUs, CPU RAM), so if you are loading a sharded checkpoint, the maximum RAM usage will be the size of the biggest shard.
|
||||
|
||||
Here is how we can use this to load the [GPT-J-6B](https://huggingface.co/EleutherAI/gpt-j-6B) model. You clone the sharded version of this model with:
|
||||
|
||||
```bash
|
||||
git clone https://huggingface.co/sgugger/sharded-gpt-j-6B
|
||||
cd sharded-gpt-j-6B
|
||||
git-lfs install
|
||||
git pull
|
||||
```
|
||||
|
||||
then we can initialize the model with
|
||||
|
||||
```py
|
||||
from accelerate import init_empty_weights
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
checkpoint = "EleutherAI/gpt-j-6B"
|
||||
config = AutoConfig.from_pretrained(checkpoint)
|
||||
|
||||
with init_empty_weights():
|
||||
model = AutoModelForCausalLM.from_config(config)
|
||||
```
|
||||
|
||||
and load the checkpoint we just downloaded with:
|
||||
|
||||
```py
|
||||
from accelerate import load_checkpoint_and_dispatch
|
||||
|
||||
model = load_checkpoint_and_dispatch(
|
||||
model, "sharded-gpt-j-6B", device_map="auto", no_split_module_classes=["GPTJBlock"]
|
||||
)
|
||||
```
|
||||
|
||||
By passing `device_map="auto"`, we tell Accelerate to determine automatically where to put each layer of the model depending on the available resources:
|
||||
- first we use the maximum space available on the GPU(s)
|
||||
- if we still need space, we store the remaining weights on the CPU
|
||||
- if there is not enough RAM, we store the remaining weights on the hard drive as memory-mapped tensors
|
||||
|
||||
`no_split_module_classes=["GPTJBlock"]` indicates that the modules that are `GPTJBlock` should not be split on different devices. You should set here all blocks that include a residual connection of some kind.
|
||||
|
||||
You can see the `device_map` that Accelerate picked by accessing the `hf_device_map` attribute of your model:
|
||||
|
||||
```py
|
||||
model.hf_device_map
|
||||
```
|
||||
|
||||
```python out
|
||||
{'transformer.wte': 0,
|
||||
'transformer.drop': 0,
|
||||
'transformer.h.0': 0,
|
||||
'transformer.h.1': 0,
|
||||
'transformer.h.2': 0,
|
||||
'transformer.h.3': 0,
|
||||
'transformer.h.4': 0,
|
||||
'transformer.h.5': 0,
|
||||
'transformer.h.6': 0,
|
||||
'transformer.h.7': 0,
|
||||
'transformer.h.8': 0,
|
||||
'transformer.h.9': 0,
|
||||
'transformer.h.10': 0,
|
||||
'transformer.h.11': 0,
|
||||
'transformer.h.12': 0,
|
||||
'transformer.h.13': 0,
|
||||
'transformer.h.14': 0,
|
||||
'transformer.h.15': 0,
|
||||
'transformer.h.16': 0,
|
||||
'transformer.h.17': 0,
|
||||
'transformer.h.18': 0,
|
||||
'transformer.h.19': 0,
|
||||
'transformer.h.20': 0,
|
||||
'transformer.h.21': 0,
|
||||
'transformer.h.22': 0,
|
||||
'transformer.h.23': 0,
|
||||
'transformer.h.24': 1,
|
||||
'transformer.h.25': 1,
|
||||
'transformer.h.26': 1,
|
||||
'transformer.h.27': 1,
|
||||
'transformer.ln_f': 1,
|
||||
'lm_head': 1}
|
||||
```
|
||||
|
||||
You can also design your `device_map` yourself, if you prefer to explicitly decide where each layer should be. In this case, the command above becomes:
|
||||
|
||||
```py
|
||||
model = load_checkpoint_and_dispatch(model, "sharded-gpt-j-6B", device_map=my_device_map)
|
||||
```
|
||||
|
||||
## Run the model
|
||||
|
||||
Now that we have done this, our model lies across several devices, and maybe the hard drive. But it can still be used as a regular PyTorch model:
|
||||
|
||||
```py
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
||||
inputs = tokenizer("Hello, my name is", return_tensors="pt")
|
||||
inputs = inputs.to(0)
|
||||
output = model.generate(inputs["input_ids"])
|
||||
tokenizer.decode(output[0].tolist())
|
||||
```
|
||||
|
||||
Behind the scenes, Accelerate added hooks to the model, so that:
|
||||
- at each layer, the inputs are put on the right device (so even if your model is spread across several GPUs, it works)
|
||||
- for the weights offloaded on the CPU, they are put on a GPU just before the forward pass, and cleaned up just after
|
||||
- for the weights offloaded on the hard drive, they are loaded in RAM then put on a GPU just before the forward pass, and cleaned up just after
|
||||
|
||||
This way, you model can run for inference even if it doesn't fit on one of the GPUs or the CPU RAM!
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This only supports inference of your model, not training. Most of the computation happens behind `torch.no_grad()` context managers to avoid spending some GPU memory with intermediate activations.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Limits and further development
|
||||
|
||||
We are aware of the current limitations in the API:
|
||||
|
||||
- While this could theoretically work just one CPU with potential disk offload, you need at least one GPU to run this API. This will be fixed in further development.
|
||||
- [`infer_auto_device_map`] (or `device_map="auto"` in [`load_checkpoint_and_dispatch`]) tries to maximize GPU and CPU RAM it sees available when you execute it. While PyTorch is very good at managing GPU RAM efficiently (and giving it back when not needed), it's not entirely true with Python and CPU RAM. Therefore, an automatically computed device map might be too intense on the CPU. Move a few modules to the disk device if you get crashes due to lack of RAM.
|
||||
- [`infer_auto_device_map`] (or `device_map="auto"` in [`load_checkpoint_and_dispatch`]) attributes devices sequentially (to avoid moving things back and forth) so if your first layer is bigger than the size of the GPU you have, it will end up with everything on the CPU/Disk.
|
||||
- [`load_checkpoint_and_dispatch`] and [`load_checkpoint_in_model`] do not perform any check on the correctness of your state dict compared to your model at the moment (this will be fixed in a future version), so you may get some weird errors if trying to load a checkpoint with mismatched or missing keys.
|
||||
- The model parallelism used when your model is split on several GPUs is naive and not optimized, meaning that only one GPU works at a given time and the other sits idle.
|
||||
- When weights are offloaded on the CPU/hard drive, there is no pre-fetching (yet, we will work on this for future versions) which means the weights are put on the GPU when they are needed and not before.
|
||||
- Hard-drive offloading might be very slow if the hardware you run on does not have fast communication between disk and CPU (like NVMes).
|
||||
|
||||
## API doc
|
||||
|
||||
[[autodoc]] cpu_offload
|
||||
|
||||
[[autodoc]] disk_offload
|
||||
|
||||
[[autodoc]] dispatch_model
|
||||
|
||||
[[autodoc]] infer_auto_device_map
|
||||
|
||||
[[autodoc]] init_empty_weights
|
||||
|
||||
[[autodoc]] load_checkpoint_and_dispatch
|
||||
|
||||
[[autodoc]] load_checkpoint_in_model
|
||||
@ -52,7 +52,7 @@ Changing it to work with accelerate is really easy and only adds a few lines of
|
||||
+ device = accelerator.device
|
||||
my_model.to(device)
|
||||
# Pass every important object (model, optimizer, dataloader) to *accelerator.prepare*
|
||||
+ my_model, my_optimizer, my_training_dataloader = accelerate.prepare(
|
||||
+ my_model, my_optimizer, my_training_dataloader = accelerator.prepare(
|
||||
+ my_model, my_optimizer, my_training_dataloader
|
||||
+ )
|
||||
|
||||
|
||||
@ -48,4 +48,4 @@ def training_function(args):
|
||||
+ inner_training_loop()
|
||||
```
|
||||
|
||||
[[autodoc]] memory_utils.find_executable_batch_size
|
||||
[[autodoc]] utils.find_executable_batch_size
|
||||
380
examples/by_feature/fsdp_with_peak_mem_tracking.py
Normal file
380
examples/by_feature/fsdp_with_peak_mem_tracking.py
Normal file
@ -0,0 +1,380 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import gc
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from datasets import load_dataset, load_metric
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
|
||||
|
||||
|
||||
########################################################################
|
||||
# This is a fully working simple example to use Accelerate
|
||||
#
|
||||
# This example trains a Bert base model on GLUE MRPC
|
||||
# in any of the following settings (with the same script):
|
||||
# - single CPU or single GPU
|
||||
# - multi GPUS (using PyTorch distributed mode)
|
||||
# - (multi) TPUs
|
||||
# - fp16 (mixed-precision) or fp32 (normal precision)
|
||||
# - FSDP
|
||||
#
|
||||
# This example also demonstrates the checkpointing and sharding capabilities
|
||||
#
|
||||
# To run it in each of these various modes, follow the instructions
|
||||
# in the readme for examples:
|
||||
# https://github.com/huggingface/accelerate/tree/main/examples
|
||||
#
|
||||
########################################################################
|
||||
|
||||
|
||||
MAX_GPU_BATCH_SIZE = 16
|
||||
EVAL_BATCH_SIZE = 32
|
||||
|
||||
|
||||
# New Code #
|
||||
# Converting Bytes to Megabytes
|
||||
def b2mb(x):
|
||||
return int(x / 2**20)
|
||||
|
||||
|
||||
# New Code #
|
||||
# This context manager is used to track the peak memory usage of the process
|
||||
class TorchTracemalloc:
|
||||
def __enter__(self):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero
|
||||
self.begin = torch.cuda.memory_allocated()
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
self.end = torch.cuda.memory_allocated()
|
||||
self.peak = torch.cuda.max_memory_allocated()
|
||||
self.used = b2mb(self.end - self.begin)
|
||||
self.peaked = b2mb(self.peak - self.begin)
|
||||
# print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")
|
||||
|
||||
|
||||
def training_function(config, args):
|
||||
# Initialize accelerator
|
||||
if args.with_tracking:
|
||||
accelerator = Accelerator(
|
||||
cpu=args.cpu, mixed_precision=args.mixed_precision, log_with="wandb", logging_dir=args.logging_dir
|
||||
)
|
||||
else:
|
||||
accelerator = Accelerator()
|
||||
accelerator.print(accelerator.distributed_type)
|
||||
|
||||
if hasattr(args.checkpointing_steps, "isdigit"):
|
||||
if args.checkpointing_steps == "epoch":
|
||||
checkpointing_steps = args.checkpointing_steps
|
||||
elif args.checkpointing_steps.isdigit():
|
||||
checkpointing_steps = int(args.checkpointing_steps)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Argument `checkpointing_steps` must be either a number or `epoch`. `{args.checkpointing_steps}` passed."
|
||||
)
|
||||
else:
|
||||
checkpointing_steps = None
|
||||
# Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
|
||||
lr = config["lr"]
|
||||
num_epochs = int(config["num_epochs"])
|
||||
seed = int(config["seed"])
|
||||
batch_size = int(config["batch_size"])
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration
|
||||
if args.with_tracking:
|
||||
run = os.path.split(__file__)[-1].split(".")[0]
|
||||
if args.logging_dir:
|
||||
run = os.path.join(args.logging_dir, run)
|
||||
accelerator.print(run)
|
||||
accelerator.init_trackers(run, config)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
|
||||
datasets = load_dataset("glue", "mrpc")
|
||||
metric = load_metric("glue", "mrpc")
|
||||
|
||||
def tokenize_function(examples):
|
||||
# max_length=None => use the model max length (it's actually the default)
|
||||
outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
|
||||
return outputs
|
||||
|
||||
# Apply the method we just defined to all the examples in all the splits of the dataset
|
||||
tokenized_datasets = datasets.map(
|
||||
tokenize_function,
|
||||
batched=True,
|
||||
remove_columns=["idx", "sentence1", "sentence2"],
|
||||
)
|
||||
|
||||
# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the
|
||||
# transformers library
|
||||
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
|
||||
|
||||
# If the batch size is too big we use gradient accumulation
|
||||
gradient_accumulation_steps = 1
|
||||
if batch_size > MAX_GPU_BATCH_SIZE:
|
||||
gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE
|
||||
batch_size = MAX_GPU_BATCH_SIZE
|
||||
|
||||
def collate_fn(examples):
|
||||
# On TPU it's best to pad everything to the same length or training will be very slow.
|
||||
if accelerator.distributed_type == DistributedType.TPU:
|
||||
return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt")
|
||||
return tokenizer.pad(examples, padding="longest", return_tensors="pt")
|
||||
|
||||
# Instantiate dataloaders.
|
||||
train_dataloader = DataLoader(
|
||||
tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size
|
||||
)
|
||||
eval_dataloader = DataLoader(
|
||||
tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE
|
||||
)
|
||||
|
||||
set_seed(seed)
|
||||
|
||||
# Instantiate the model (we build the model here so that the seed also control new weights initialization)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, return_dict=True)
|
||||
# New Code #
|
||||
# For FSDP feature, it is highly recommended and efficient to prepare the model before creating optimizer
|
||||
model = accelerator.prepare(model)
|
||||
|
||||
# Instantiate optimizer
|
||||
# New Code #
|
||||
# For FSDP feature, at present it doesn't support multiple parameter groups,
|
||||
# so we need to create a single parameter group for the whole model
|
||||
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr, weight_decay=2e-4)
|
||||
|
||||
# Instantiate scheduler
|
||||
lr_scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=10,
|
||||
num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
|
||||
)
|
||||
|
||||
# New Code #
|
||||
# For FSDP feature, prepare everything except the model as we have already prepared the model
|
||||
# before creating the optimizer
|
||||
# There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the
|
||||
# prepare method.
|
||||
optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
|
||||
optimizer, train_dataloader, eval_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
overall_step = 0
|
||||
|
||||
# Potentially load in the weights and states from a previous save
|
||||
if args.resume_from_checkpoint:
|
||||
if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
|
||||
accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
|
||||
accelerator.load_state(args.resume_from_checkpoint)
|
||||
path = os.path.basename(args.resume_from_checkpoint)
|
||||
else:
|
||||
# Get the most recent checkpoint
|
||||
dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
|
||||
dirs.sort(key=os.path.getctime)
|
||||
path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
|
||||
# Extract `epoch_{i}` or `step_{i}`
|
||||
training_difference = os.path.splitext(path)[0]
|
||||
|
||||
if "epoch" in training_difference:
|
||||
num_epochs -= int(training_difference.replace("epoch_", ""))
|
||||
resume_step = None
|
||||
else:
|
||||
resume_step = int(training_difference.replace("step_", ""))
|
||||
num_epochs -= resume_step // len(train_dataloader)
|
||||
# If resuming by step, we also need to know exactly how far into the DataLoader we went
|
||||
resume_step = (num_epochs * len(train_dataloader)) - resume_step
|
||||
|
||||
# Now we train the model
|
||||
for epoch in range(num_epochs):
|
||||
# New Code #
|
||||
# context manager to track the peak memory usage during the training epoch
|
||||
with TorchTracemalloc() as tracemalloc:
|
||||
model.train()
|
||||
if args.with_tracking:
|
||||
total_loss = 0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# We need to skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == 0:
|
||||
if resume_step is not None and step < resume_step:
|
||||
pass
|
||||
# We could avoid this line since we set the accelerator with `device_placement=True`.
|
||||
batch.to(accelerator.device)
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
loss = loss / gradient_accumulation_steps
|
||||
# We keep track of the loss at each epoch
|
||||
if args.with_tracking:
|
||||
total_loss += loss.detach().float()
|
||||
accelerator.backward(loss)
|
||||
if step % gradient_accumulation_steps == 0:
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
# accelerator.print(lr_scheduler.get_lr())
|
||||
|
||||
overall_step += 1
|
||||
|
||||
if isinstance(checkpointing_steps, int):
|
||||
output_dir = f"step_{overall_step}"
|
||||
if overall_step % checkpointing_steps == 0:
|
||||
if args.output_dir is not None:
|
||||
output_dir = os.path.join(args.output_dir, output_dir)
|
||||
accelerator.save_state(output_dir)
|
||||
# New Code #
|
||||
# Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage
|
||||
accelerator.print("Memory before entering the train : {}".format(b2mb(tracemalloc.begin)))
|
||||
accelerator.print("Memory consumed at the end of the train (end-begin): {}".format(tracemalloc.used))
|
||||
accelerator.print("Peak Memory consumed during the train (max-begin): {}".format(tracemalloc.peaked))
|
||||
accelerator.print(
|
||||
"Total Peak Memory consumed during the train (max): {}".format(
|
||||
tracemalloc.peaked + b2mb(tracemalloc.begin)
|
||||
)
|
||||
)
|
||||
# Logging the peak memory usage of the GPU to the tracker
|
||||
if args.with_tracking:
|
||||
accelerator.log(
|
||||
{
|
||||
"train_total_peak_memory": tracemalloc.peaked + b2mb(tracemalloc.begin),
|
||||
},
|
||||
step=epoch,
|
||||
)
|
||||
|
||||
# New Code #
|
||||
# context manager to track the peak memory usage during the evaluation
|
||||
with TorchTracemalloc() as tracemalloc:
|
||||
model.eval()
|
||||
samples_seen = 0
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
# We could avoid this line since we set the accelerator with `device_placement=True`.
|
||||
batch.to(accelerator.device)
|
||||
with torch.no_grad():
|
||||
outputs = model(**batch)
|
||||
predictions = outputs.logits.argmax(dim=-1)
|
||||
# It is slightly faster to call this once, than multiple times
|
||||
predictions, references = accelerator.gather(
|
||||
(predictions, batch["labels"])
|
||||
) # If we are in a multiprocess environment, the last batch has duplicates
|
||||
if accelerator.num_processes > 1:
|
||||
if step == len(eval_dataloader) - 1:
|
||||
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
|
||||
references = references[: len(eval_dataloader.dataset) - samples_seen]
|
||||
else:
|
||||
samples_seen += references.shape[0]
|
||||
metric.add_batch(
|
||||
predictions=predictions,
|
||||
references=references,
|
||||
)
|
||||
|
||||
eval_metric = metric.compute()
|
||||
# Use accelerator.print to print only on the main process.
|
||||
accelerator.print(f"epoch {epoch}:", eval_metric)
|
||||
if args.with_tracking:
|
||||
accelerator.log(
|
||||
{
|
||||
"accuracy": eval_metric["accuracy"],
|
||||
"f1": eval_metric["f1"],
|
||||
"train_loss": total_loss,
|
||||
},
|
||||
step=epoch,
|
||||
)
|
||||
|
||||
if checkpointing_steps == "epoch":
|
||||
output_dir = f"epoch_{epoch}"
|
||||
if args.output_dir is not None:
|
||||
output_dir = os.path.join(args.output_dir, output_dir)
|
||||
accelerator.save_state(output_dir)
|
||||
# New Code #
|
||||
# Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage
|
||||
accelerator.print("Memory before entering the eval : {}".format(b2mb(tracemalloc.begin)))
|
||||
accelerator.print("Memory consumed at the end of the eval (end-begin): {}".format(tracemalloc.used))
|
||||
accelerator.print("Peak Memory consumed during the eval (max-begin): {}".format(tracemalloc.peaked))
|
||||
accelerator.print(
|
||||
"Total Peak Memory consumed during the eval (max): {}".format(tracemalloc.peaked + b2mb(tracemalloc.begin))
|
||||
)
|
||||
# Logging the peak memory usage of the GPU to the tracker
|
||||
if args.with_tracking:
|
||||
accelerator.log(
|
||||
{
|
||||
"eval_total_peak_memory": tracemalloc.peaked + b2mb(tracemalloc.begin),
|
||||
},
|
||||
step=epoch,
|
||||
)
|
||||
|
||||
if args.with_tracking:
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Simple example of training script.")
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default="no",
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help="Whether to use mixed precision. Choose"
|
||||
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
||||
"and an Nvidia Ampere GPU.",
|
||||
)
|
||||
parser.add_argument("--cpu", action="store_true", help="If passed, will train on the CPU.")
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help="If the training should continue from a checkpoint folder.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--with_tracking",
|
||||
action="store_true",
|
||||
help="Whether to load in all available experiment trackers from the environment and use them for logging.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default=".",
|
||||
help="Optional save directory where all checkpoint folders will be stored. Default is the current working directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help="Location on where to store experiment tracking logs`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
type=str,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
required=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
config = {"lr": 2e-5, "num_epochs": 3, "seed": 1, "batch_size": 16}
|
||||
training_function(config, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -19,7 +19,7 @@ from torch.utils.data import DataLoader
|
||||
from accelerate import Accelerator, DistributedType
|
||||
|
||||
# New Code #
|
||||
from accelerate.memory_utils import find_executable_batch_size
|
||||
from accelerate.utils import find_executable_batch_size
|
||||
from datasets import load_dataset, load_metric
|
||||
from transformers import (
|
||||
AdamW,
|
||||
|
||||
@ -177,7 +177,7 @@ def training_function(config, args):
|
||||
# First we check if it's a distributed system
|
||||
if accelerator.num_processes > 1:
|
||||
# Then see if we're on the last batch of our eval dataloader
|
||||
if step == len(eval_dataloader):
|
||||
if step == len(eval_dataloader) - 1:
|
||||
# Last batch needs to be truncated on distributed systems as it contains additional samples
|
||||
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
|
||||
references = references[: len(eval_dataloader.dataset) - samples_seen]
|
||||
|
||||
@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
@ -30,9 +29,6 @@ from transformers import (
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
########################################################################
|
||||
# This is a fully working simple example to use Accelerate,
|
||||
# specifically showcasing the experiment tracking capability,
|
||||
|
||||
@ -232,7 +232,7 @@ def training_function(config, args):
|
||||
accelerator.save_state(output_dir)
|
||||
model.eval()
|
||||
accurate = 0
|
||||
num_elems = 0
|
||||
samples_seen = 0
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
# We could avoid this line since we set the accelerator with `device_placement=True`.
|
||||
batch = {k: v.to(accelerator.device) for k, v in batch.items()}
|
||||
@ -240,11 +240,19 @@ def training_function(config, args):
|
||||
with torch.no_grad():
|
||||
outputs = model(inputs)
|
||||
predictions = outputs.argmax(dim=-1)
|
||||
accurate_preds = accelerator.gather(predictions) == accelerator.gather(batch["label"])
|
||||
num_elems += accurate_preds.shape[0]
|
||||
predictions, references = accelerator.gather((predictions, batch["label"]))
|
||||
if accelerator.num_processes > 1:
|
||||
if step == len(eval_dataloader) - 1:
|
||||
predictions = predictions[: len(eval_dataloader) - samples_seen]
|
||||
references = references[: len(eval_dataloader) - samples_seen]
|
||||
else:
|
||||
samples_seen += references.shape[0]
|
||||
else:
|
||||
samples_seen += references.shape[0]
|
||||
accurate_preds = predictions == references
|
||||
accurate += accurate_preds.long().sum()
|
||||
|
||||
eval_metric = accurate.item() / num_elems
|
||||
eval_metric = accurate.item() / samples_seen
|
||||
# Use accelerator.print to print only on the main process.
|
||||
accelerator.print(f"epoch {epoch}: {100 * eval_metric:.2f}")
|
||||
if args.with_tracking:
|
||||
|
||||
@ -215,6 +215,7 @@ def training_function(config, args):
|
||||
accelerator.save_state(output_dir)
|
||||
|
||||
model.eval()
|
||||
samples_seen = 0
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
# We could avoid this line since we set the accelerator with `device_placement=True`.
|
||||
batch.to(accelerator.device)
|
||||
@ -222,7 +223,15 @@ def training_function(config, args):
|
||||
outputs = model(**batch)
|
||||
predictions = outputs.logits.argmax(dim=-1)
|
||||
# It is slightly faster to call this once, than multiple times
|
||||
predictions, references = accelerator.gather((predictions, batch["labels"]))
|
||||
predictions, references = accelerator.gather(
|
||||
(predictions, batch["labels"])
|
||||
) # If we are in a multiprocess environment, the last batch has duplicates
|
||||
if accelerator.num_processes > 1:
|
||||
if step == len(eval_dataloader) - 1:
|
||||
predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
|
||||
references = references[: len(eval_dataloader.dataset) - samples_seen]
|
||||
else:
|
||||
samples_seen += references.shape[0]
|
||||
metric.add_batch(
|
||||
predictions=predictions,
|
||||
references=references,
|
||||
|
||||
2
setup.py
2
setup.py
@ -36,7 +36,7 @@ extras["sagemaker"] = [
|
||||
|
||||
setup(
|
||||
name="accelerate",
|
||||
version="0.7.0.dev0",
|
||||
version="0.8.0",
|
||||
description="Accelerate",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@ -2,10 +2,20 @@
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
__version__ = "0.7.0.dev0"
|
||||
__version__ = "0.8.0"
|
||||
|
||||
from .accelerator import Accelerator
|
||||
from .kwargs_handlers import DistributedDataParallelKwargs, GradScalerKwargs, InitProcessGroupKwargs
|
||||
from .big_modeling import cpu_offload, disk_offload, dispatch_model, init_empty_weights, load_checkpoint_and_dispatch
|
||||
from .launchers import debug_launcher, notebook_launcher
|
||||
from .state import DistributedType
|
||||
from .utils import DeepSpeedPlugin, synchronize_rng_states
|
||||
from .utils import (
|
||||
DeepSpeedPlugin,
|
||||
DistributedDataParallelKwargs,
|
||||
DistributedType,
|
||||
FullyShardedDataParallelPlugin,
|
||||
GradScalerKwargs,
|
||||
InitProcessGroupKwargs,
|
||||
find_executable_batch_size,
|
||||
infer_auto_device_map,
|
||||
load_checkpoint_in_model,
|
||||
synchronize_rng_states,
|
||||
)
|
||||
|
||||
@ -25,14 +25,19 @@ from packaging import version
|
||||
|
||||
from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
|
||||
from .data_loader import prepare_data_loader
|
||||
from .kwargs_handlers import DistributedDataParallelKwargs, GradScalerKwargs, InitProcessGroupKwargs, KwargsHandler
|
||||
from .logging import get_logger
|
||||
from .optimizer import AcceleratedOptimizer
|
||||
from .scheduler import AcceleratedScheduler
|
||||
from .state import AcceleratorState, DistributedType, is_deepspeed_available
|
||||
from .state import AcceleratorState
|
||||
from .tracking import LOGGER_TYPE_TO_CLASS, GeneralTracker, filter_trackers
|
||||
from .utils import (
|
||||
DeepSpeedPlugin,
|
||||
DistributedDataParallelKwargs,
|
||||
DistributedType,
|
||||
FullyShardedDataParallelPlugin,
|
||||
GradScalerKwargs,
|
||||
InitProcessGroupKwargs,
|
||||
KwargsHandler,
|
||||
LoggerType,
|
||||
PrecisionType,
|
||||
RNGType,
|
||||
@ -40,6 +45,7 @@ from .utils import (
|
||||
extract_model_from_parallel,
|
||||
gather,
|
||||
get_pretty_name,
|
||||
is_deepspeed_available,
|
||||
pad_across_processes,
|
||||
reduce,
|
||||
save,
|
||||
@ -50,12 +56,9 @@ from .utils import (
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
|
||||
from .deepspeed_utils import DeepSpeedEngineWrapper, DeepSpeedOptimizerWrapper
|
||||
from .utils import DeepSpeedEngineWrapper, DeepSpeedOptimizerWrapper
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Accelerator:
|
||||
@ -160,15 +163,18 @@ class Accelerator:
|
||||
assert isinstance(
|
||||
deepspeed_plugin, DeepSpeedPlugin
|
||||
), "`deepspeed_plugin` must be a DeepSpeedPlugin object."
|
||||
os.environ["USE_DEEPSPEED"] = "true" # use DeepSpeed if plugin is provided
|
||||
|
||||
if fsdp_plugin is None: # init from env variables
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin() if os.environ.get("USE_FSDP", "false") == "true" else None
|
||||
else:
|
||||
if not isinstance(fsdp_plugin, FullyShardedDataParallelPlugin):
|
||||
raise TypeError("`fsdp_plugin` must be a FullyShardedDataParallelPlugin object.")
|
||||
os.environ["USE_FSDP"] = "true" # use FSDP if plugin is provided
|
||||
|
||||
if os.environ.get("USE_FSDP", "false") == "true":
|
||||
if version.parse(torch.__version__) < version.parse("1.12.0.dev20220418+cu113"):
|
||||
raise ValueError("FSDP requires PyTorch >= 1.12.0.dev20220418+cu113")
|
||||
if fsdp_plugin is None: # init from env variables
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin()
|
||||
else:
|
||||
if not isinstance(fsdp_plugin, FullyShardedDataParallelPlugin):
|
||||
raise TypeError("`fsdp_plugin` must be a FullyShardedDataParallelPlugin object.")
|
||||
|
||||
# Kwargs handlers
|
||||
self.ddp_handler = None
|
||||
@ -462,17 +468,20 @@ class Accelerator:
|
||||
elif self.distributed_type == DistributedType.FSDP:
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
|
||||
|
||||
fsdp_plugin = self.state.fsdp_plugin
|
||||
model = FSDP(
|
||||
model,
|
||||
sharding_strategy=fsdp_plugin.sharding_strategy,
|
||||
cpu_offload=fsdp_plugin.cpu_offload,
|
||||
auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
|
||||
backward_prefetch=fsdp_plugin.backward_prefetch,
|
||||
ignored_modules=fsdp_plugin.ignored_modules,
|
||||
)
|
||||
if not fsdp_plugin.cpu_offload.offload_params:
|
||||
model.to(self.device)
|
||||
# Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
|
||||
# don't wrap it again
|
||||
if type(model) != FSDP:
|
||||
fsdp_plugin = self.state.fsdp_plugin
|
||||
model = FSDP(
|
||||
model,
|
||||
sharding_strategy=fsdp_plugin.sharding_strategy,
|
||||
cpu_offload=fsdp_plugin.cpu_offload,
|
||||
auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
|
||||
backward_prefetch=fsdp_plugin.backward_prefetch,
|
||||
ignored_modules=fsdp_plugin.ignored_modules,
|
||||
)
|
||||
if not fsdp_plugin.cpu_offload.offload_params:
|
||||
model.to(self.device)
|
||||
elif self.distributed_type == DistributedType.MULTI_CPU:
|
||||
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
|
||||
|
||||
285
src/accelerate/big_modeling.py
Normal file
285
src/accelerate/big_modeling.py
Normal file
@ -0,0 +1,285 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .hooks import AlignDevicesHook, add_hook_to_module, attach_align_device_hook, attach_align_device_hook_on_blocks
|
||||
from .utils import (
|
||||
OffloadedWeightsLoader,
|
||||
check_device_map,
|
||||
extract_submodules_state_dict,
|
||||
infer_auto_device_map,
|
||||
load_checkpoint_in_model,
|
||||
offload_state_dict,
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def init_empty_weights(include_buffers: bool = False):
|
||||
"""
|
||||
A context manager under which models are initialized with all parameters on the meta device, therefore creating an
|
||||
empty model. Useful when just initializing the model would blow the available RAM.
|
||||
|
||||
Args:
|
||||
include_buffers (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also put all buffers on the meta device while initializing.
|
||||
|
||||
Example:
|
||||
|
||||
```pyton
|
||||
import torch.nn as nn
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
# Initialize a model with 100 billions parameters in no time and without using any RAM.
|
||||
with init_empty_weights():
|
||||
tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Any model created under this context manager has no weights. As such you can't do something like
|
||||
`model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
old_register_parameter = nn.Module.register_parameter
|
||||
if include_buffers:
|
||||
old_register_buffer = nn.Module.register_buffer
|
||||
|
||||
def register_empty_parameter(module, name, param):
|
||||
old_register_parameter(module, name, param)
|
||||
if param is not None:
|
||||
module._parameters[name] = nn.Parameter(module._parameters[name].to(torch.device("meta")))
|
||||
|
||||
def register_empty_buffer(module, name, buffer):
|
||||
old_register_buffer(module, name, buffer)
|
||||
if buffer is not None:
|
||||
module._buffers[name] = module._buffers[name].to(torch.device("meta"))
|
||||
|
||||
try:
|
||||
nn.Module.register_parameter = register_empty_parameter
|
||||
if include_buffers:
|
||||
nn.Module.register_buffer = register_empty_buffer
|
||||
yield
|
||||
finally:
|
||||
nn.Module.register_parameter = old_register_parameter
|
||||
if include_buffers:
|
||||
nn.Module.register_buffer = old_register_buffer
|
||||
|
||||
|
||||
def cpu_offload(
|
||||
model: nn.Module,
|
||||
execution_device: Optional[torch.device] = None,
|
||||
offload_buffers: bool = False,
|
||||
state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
||||
):
|
||||
"""
|
||||
Activates full CPU offload for a model. As a result, all parameters of the model will be offloaded and only one
|
||||
copy of the state dict of the model will be kept. During the forward pass, parameters will be extracted from that
|
||||
state dict and put on the execution device passed as they are needed, then offloaded again.
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Module`):
|
||||
The model to offload.
|
||||
execution_device (`torch.device`, *optional*):
|
||||
The device on which the forward pass of the model will be executed (should be a GPU). Will default to the
|
||||
model first parameter device.
|
||||
offload_buffers (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to offload the buffers with the model parameters.
|
||||
state_dict (`Dict[str, torch.Tensor]`, *optional*):
|
||||
The state dict of the model that will be kept on CPU.
|
||||
"""
|
||||
if execution_device is None:
|
||||
execution_device = next(iter(model.parameters())).device
|
||||
if state_dict is None:
|
||||
state_dict = {n: p.to("cpu") for n, p in model.state_dict().items()}
|
||||
attach_align_device_hook(
|
||||
model, execution_device=execution_device, offload=True, offload_buffers=offload_buffers, weights_map=state_dict
|
||||
)
|
||||
add_hook_to_module(model, AlignDevicesHook(io_same_device=True))
|
||||
return model
|
||||
|
||||
|
||||
def disk_offload(
|
||||
model: nn.Module,
|
||||
offload_dir: Union[str, os.PathLike],
|
||||
execution_device: Optional[torch.device] = None,
|
||||
offload_buffers: bool = False,
|
||||
):
|
||||
"""
|
||||
Activates full disk offload for a model. As a result, all parameters of the model will be offloaded as
|
||||
memory-mapped array in a given folder. During the forward pass, parameters will be accessed from that folder and
|
||||
put on the execution device passed as they are needed, then offloaded again.
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Module`): The model to offload.
|
||||
offload_dir (`str` or `os.PathLike`):
|
||||
The folder in which to offload the model weights (or where the model weights are already offloaded).
|
||||
execution_device (`torch.device`, *optional*):
|
||||
The device on which the forward pass of the model will be executed (should be a GPU). Will default to the
|
||||
model's first parameter device.
|
||||
offload_buffers (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to offload the buffers with the model parameters.
|
||||
"""
|
||||
if not os.path.isdir(offload_dir) or not os.path.isfile(os.path.join(offload_dir, "index.json")):
|
||||
offload_state_dict(offload_dir, model.state_dict())
|
||||
if execution_device is None:
|
||||
execution_device = next(iter(model.parameters())).device
|
||||
weights_map = OffloadedWeightsLoader(save_folder=offload_dir)
|
||||
attach_align_device_hook(
|
||||
model,
|
||||
execution_device=execution_device,
|
||||
offload=True,
|
||||
offload_buffers=offload_buffers,
|
||||
weights_map=weights_map,
|
||||
)
|
||||
add_hook_to_module(model, AlignDevicesHook(io_same_device=True))
|
||||
return model
|
||||
|
||||
|
||||
def dispatch_model(
|
||||
model: nn.Module,
|
||||
device_map: Dict[str, Union[str, int, torch.device]],
|
||||
main_device: Optional[torch.device] = None,
|
||||
state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
||||
offload_dir: Union[str, os.PathLike] = None,
|
||||
offload_buffers: bool = False,
|
||||
):
|
||||
"""
|
||||
Dispatches a model according to a given device map. Layers of the model might be spread across GPUs, offloaded on
|
||||
the CPU or even the disk.
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Module`):
|
||||
The model to dispatch.
|
||||
device_map (`Dict[str, Union[str, int, torch.device]]`):
|
||||
A dictionary mapping module names in the models `state_dict` to the device they should go to. Note that
|
||||
`"disk"` is accepted even if it's not a proper value for `torch.device`.
|
||||
main_device (`str`, `int` or `torch.device`, *optional*):
|
||||
The main execution device. Will default to the first device in the `device_map` different from `"cpu"` or
|
||||
`"disk"`.
|
||||
state_dict (`Dict[str, torch.Tensor]`, *optional*):
|
||||
The state dict of the part of the model that will be kept on CPU.
|
||||
offload_dir (`str` or `os.PathLike`):
|
||||
The folder in which to offload the model weights (or where the model weights are already offloaded).
|
||||
offload_buffers (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to offload the buffers with the model parameters.
|
||||
"""
|
||||
# Error early if the device map is incomplete.
|
||||
check_device_map(model, device_map)
|
||||
|
||||
if main_device is None:
|
||||
main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
|
||||
|
||||
cpu_modules = [name for name, device in device_map.items() if device == "cpu"]
|
||||
if state_dict is None and len(cpu_modules) > 0:
|
||||
state_dict = extract_submodules_state_dict(model.state_dict(), cpu_modules)
|
||||
|
||||
disk_modules = [name for name, device in device_map.items() if device == "disk"]
|
||||
if offload_dir is None and len(disk_modules) > 0:
|
||||
raise ValueError(
|
||||
"We need an `offload_dir` to dispatch this model according to this `device_map`, the following submodules "
|
||||
f"need to be offloaded: {', '.join(disk_modules)}."
|
||||
)
|
||||
if len(disk_modules) > 0 and (
|
||||
not os.path.isdir(offload_dir) or not os.path.isfile(os.path.join(offload_dir, "index.json"))
|
||||
):
|
||||
disk_state_dict = extract_submodules_state_dict(model.state_dict(), disk_modules)
|
||||
offload_state_dict(offload_dir, disk_state_dict)
|
||||
|
||||
execution_device = {
|
||||
name: main_device if device in ["cpu", "disk"] else device for name, device in device_map.items()
|
||||
}
|
||||
offload = {name: device in ["cpu", "disk"] for name, device in device_map.items()}
|
||||
save_folder = offload_dir if len(disk_modules) > 0 else None
|
||||
if state_dict is not None or save_folder is not None:
|
||||
weights_map = OffloadedWeightsLoader(state_dict=state_dict, save_folder=save_folder)
|
||||
else:
|
||||
weights_map = None
|
||||
|
||||
attach_align_device_hook_on_blocks(
|
||||
model,
|
||||
execution_device=execution_device,
|
||||
offload=offload,
|
||||
offload_buffers=offload_buffers,
|
||||
weights_map=weights_map,
|
||||
)
|
||||
model.hf_device_map = device_map
|
||||
return model
|
||||
|
||||
|
||||
def load_checkpoint_and_dispatch(
|
||||
model: nn.Module,
|
||||
checkpoint: Union[str, os.PathLike],
|
||||
device_map: Optional[Union[str, Dict[str, Union[int, str, torch.device]]]] = None,
|
||||
max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
|
||||
no_split_module_classes: Optional[List[str]] = None,
|
||||
offload_folder: Optional[Union[str, os.PathLike]] = None,
|
||||
offload_buffers: bool = False,
|
||||
dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
offload_state_dict: bool = False,
|
||||
):
|
||||
"""
|
||||
Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are
|
||||
loaded and adds the various hooks that will make this model run properly (even if split across devices).
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Module`): The model in which we want to load a checkpoint.
|
||||
checkpoint (`str` or `os.PathLike`):
|
||||
The folder checkpoint to load. It can be:
|
||||
- a path to a file containing a whole model state dict
|
||||
- a path to a `.json` file containing the index to a sharded checkpoint
|
||||
- a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.
|
||||
device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
|
||||
name, once a given module name is inside, every submodule of it will be sent to the same device.
|
||||
|
||||
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`.
|
||||
max_memory (`Dict`, *optional*):
|
||||
A dictionary device identifier to maximum memory. Will default to the maximum memory available for each GPU
|
||||
and the available CPU RAM if unset.
|
||||
no_split_module_classes (`List[str]`, *optional*):
|
||||
A list of layer class names that should never be split across device (for instance any layer that has a
|
||||
residual connection).
|
||||
offload_folder (`str` or `os.PathLike`, *optional*):
|
||||
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
|
||||
offload_buffers (`bool`, *optional*, defaults to `False`):
|
||||
In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
|
||||
well as the parameters.
|
||||
dtype (`str` or `torch.dtype`, *optional*):
|
||||
If provided, the weights will be converted to that type when loaded.
|
||||
offload_state_dict (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, will temporarily offload the CPU state dict on the hard drive to avoig getting out of CPU RAM if
|
||||
the weight of the CPU state dict + the biggest shard does not fit.
|
||||
"""
|
||||
if device_map == "auto":
|
||||
device_map = infer_auto_device_map(
|
||||
model, max_memory=max_memory, no_split_module_classes=no_split_module_classes, dtype=dtype
|
||||
)
|
||||
load_checkpoint_in_model(
|
||||
model,
|
||||
checkpoint,
|
||||
device_map=device_map,
|
||||
offload_folder=offload_folder,
|
||||
dtype=dtype,
|
||||
offload_state_dict=offload_state_dict,
|
||||
)
|
||||
if device_map is None:
|
||||
return model
|
||||
return dispatch_model(model, device_map=device_map, offload_dir=offload_folder, offload_buffers=offload_buffers)
|
||||
@ -21,17 +21,25 @@ import numpy as np
|
||||
import torch
|
||||
from torch.cuda.amp import GradScaler
|
||||
|
||||
from .state import is_tpu_available
|
||||
from .utils import MODEL_NAME, OPTIMIZER_NAME, RNG_STATE_NAME, SCALER_NAME, SCHEDULER_NAME, get_pretty_name, save
|
||||
from .utils import (
|
||||
MODEL_NAME,
|
||||
OPTIMIZER_NAME,
|
||||
RNG_STATE_NAME,
|
||||
SCALER_NAME,
|
||||
SCHEDULER_NAME,
|
||||
get_pretty_name,
|
||||
is_tpu_available,
|
||||
save,
|
||||
)
|
||||
|
||||
|
||||
if is_tpu_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
import logging
|
||||
from .logging import get_logger
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def save_accelerator_state(
|
||||
|
||||
@ -17,7 +17,7 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from accelerate.state import ComputeEnvironment
|
||||
from accelerate.utils import ComputeEnvironment
|
||||
|
||||
from .cluster import get_cluster_input
|
||||
from .config_args import cache_dir, default_config_file, default_yaml_config_file, load_config_from_file # noqa: F401
|
||||
|
||||
@ -14,9 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from accelerate.state import ComputeEnvironment, DistributedType
|
||||
|
||||
from ...utils import is_deepspeed_available
|
||||
from ...utils import ComputeEnvironment, DistributedType, is_deepspeed_available
|
||||
from .config_args import ClusterConfig
|
||||
from .config_utils import _ask_field, _convert_distributed_mode, _convert_yes_no_to_bool
|
||||
|
||||
@ -64,7 +62,7 @@ def get_cluster_input():
|
||||
else:
|
||||
use_cpu = False
|
||||
|
||||
deepspeed_config = None
|
||||
deepspeed_config = {}
|
||||
if distributed_type in [DistributedType.MULTI_GPU, DistributedType.NO]:
|
||||
use_deepspeed = _ask_field(
|
||||
"Do you want to use DeepSpeed? [yes/NO]: ",
|
||||
@ -78,7 +76,6 @@ def get_cluster_input():
|
||||
is_deepspeed_available()
|
||||
), "DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source"
|
||||
|
||||
deepspeed_config = {}
|
||||
if distributed_type == DistributedType.DEEPSPEED:
|
||||
deepspeed_config["zero_stage"] = _ask_field(
|
||||
"What should be your DeepSpeed's ZeRO optimization stage (0, 1, 2, 3)? [2]: ",
|
||||
@ -99,6 +96,7 @@ def get_cluster_input():
|
||||
default=1,
|
||||
)
|
||||
|
||||
fsdp_config = {}
|
||||
if distributed_type in [DistributedType.MULTI_GPU]:
|
||||
use_fsdp = _ask_field(
|
||||
"Do you want to use FullyShardedDataParallel? [yes/NO]: ",
|
||||
@ -108,7 +106,6 @@ def get_cluster_input():
|
||||
)
|
||||
if use_fsdp:
|
||||
distributed_type = DistributedType.FSDP
|
||||
fsdp_config = {}
|
||||
if distributed_type == DistributedType.FSDP:
|
||||
fsdp_config["sharding_strategy"] = _ask_field(
|
||||
"What should be your sharding strategy ([1] FULL_SHARD, [2] SHARD_GRAD_OP)? [1]: ",
|
||||
@ -135,12 +132,27 @@ def get_cluster_input():
|
||||
else:
|
||||
main_training_function = "main"
|
||||
|
||||
num_processes = _ask_field(
|
||||
"How many processes in total will you use? [1]: ",
|
||||
lambda x: int(x),
|
||||
default=1,
|
||||
error_message="Please enter an integer.",
|
||||
)
|
||||
if distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_GPU, DistributedType.TPU]:
|
||||
machine_type = str(distributed_type).split(".")[1].replace("MULTI_", "")
|
||||
if machine_type == "TPU":
|
||||
machine_type += " cores"
|
||||
else:
|
||||
machine_type += "(s)"
|
||||
num_processes = _ask_field(
|
||||
f"How many {machine_type} should be used for distributed training? [1]:",
|
||||
lambda x: int(x),
|
||||
default=1,
|
||||
error_message="Please enter an integer.",
|
||||
)
|
||||
elif distributed_type in [DistributedType.FSDP, DistributedType.DEEPSPEED]:
|
||||
num_processes = _ask_field(
|
||||
"How many GPU(s) should be used for distributed training? [1]:",
|
||||
lambda x: int(x),
|
||||
default=1,
|
||||
error_message="Please enter an integer.",
|
||||
)
|
||||
else:
|
||||
num_processes = 1
|
||||
|
||||
if distributed_type != DistributedType.TPU:
|
||||
mixed_precision = _ask_field(
|
||||
|
||||
@ -21,7 +21,8 @@ from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
import yaml
|
||||
from accelerate.state import ComputeEnvironment, DistributedType, SageMakerDistributedType
|
||||
|
||||
from ...utils import ComputeEnvironment, DistributedType, SageMakerDistributedType
|
||||
|
||||
|
||||
hf_cache_home = os.path.expanduser(
|
||||
@ -139,6 +140,13 @@ class ClusterConfig(BaseConfig):
|
||||
# args for fsdp
|
||||
fsdp_config: dict = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.deepspeed_config is None:
|
||||
self.deepspeed_config = {}
|
||||
if self.fsdp_config is None:
|
||||
self.fsdp_config = {}
|
||||
return super().__post_init__()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SageMakerConfig(BaseConfig):
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from accelerate.state import ComputeEnvironment, DistributedType, SageMakerDistributedType
|
||||
from ...utils.dataclasses import ComputeEnvironment, DistributedType, SageMakerDistributedType
|
||||
|
||||
|
||||
def _ask_field(input_text, convert_value=None, default=None, error_message=None):
|
||||
|
||||
@ -16,9 +16,8 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from accelerate.state import ComputeEnvironment, SageMakerDistributedType
|
||||
from accelerate.utils import is_boto3_available
|
||||
|
||||
from ...utils.dataclasses import ComputeEnvironment, SageMakerDistributedType
|
||||
from ...utils.imports import is_boto3_available
|
||||
from .config_args import SageMakerConfig
|
||||
from .config_utils import _ask_field, _convert_sagemaker_distributed_mode
|
||||
|
||||
|
||||
@ -26,8 +26,13 @@ from typing import Dict, List
|
||||
|
||||
from accelerate.commands.config import default_config_file, load_config_from_file
|
||||
from accelerate.commands.config.config_args import SageMakerConfig
|
||||
from accelerate.state import ComputeEnvironment, DistributedType
|
||||
from accelerate.utils import PrecisionType, PrepareForLaunch, is_sagemaker_available
|
||||
from accelerate.utils import (
|
||||
ComputeEnvironment,
|
||||
DistributedType,
|
||||
PrecisionType,
|
||||
PrepareForLaunch,
|
||||
is_sagemaker_available,
|
||||
)
|
||||
|
||||
|
||||
def launch_command_parser(subparsers=None):
|
||||
|
||||
411
src/accelerate/hooks.py
Normal file
411
src/accelerate/hooks.py
Normal file
@ -0,0 +1,411 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from typing import Dict, Mapping, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .utils import PrefixedDataset, find_device, named_module_tensors, send_to_device, set_module_tensor_to_device
|
||||
|
||||
|
||||
class ModelHook:
|
||||
"""
|
||||
A hook that contains callbacks to be executed just before and after the forward method of a model. The difference
|
||||
with PyTorch existing hooks is that they get passed along the kwargs.
|
||||
|
||||
Class attribute:
|
||||
- **no_grad** (`bool`, *optional*, defaults to `False`) -- Whether or not to execute the actual forward pass under
|
||||
the `torch.no_grad()` context manager.
|
||||
"""
|
||||
|
||||
no_grad = False
|
||||
|
||||
def init_hook(self, module):
|
||||
"""
|
||||
To be executed when the hook is attached to the module.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`): The module attached to this hook.
|
||||
"""
|
||||
return module
|
||||
|
||||
def pre_forward(self, module, *args, **kwargs):
|
||||
"""
|
||||
To be executed just before the forward method of the model.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`): The module whose forward pass will be executed just after this event.
|
||||
args (`Tuple[Any]`): The positional arguments passed to the module.
|
||||
kwargs (`Dict[Str, Any]`): The keyword arguments passed to the module.
|
||||
|
||||
Returns:
|
||||
`Tuple[Tuple[Any], Dict[Str, Any]]`: A tuple with the treated `args` and `kwargs`.
|
||||
"""
|
||||
return args, kwargs
|
||||
|
||||
def post_forward(self, module, output):
|
||||
"""
|
||||
To be executed just after the forward method of the model.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`): The module whose forward pass been executed just before this event.
|
||||
output (`Any`): The output of the module.
|
||||
|
||||
Returns:
|
||||
`Any`: The processed `output`.
|
||||
"""
|
||||
return output
|
||||
|
||||
def detach_hook(self, module):
|
||||
"""
|
||||
To be executed when the hook is deached from a module.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`): The module detached from this hook.
|
||||
"""
|
||||
return module
|
||||
|
||||
|
||||
class SequentialHook(ModelHook):
|
||||
"""
|
||||
A hook that can contain several hooks and iterates through them at each event.
|
||||
"""
|
||||
|
||||
def __init__(self, *hooks):
|
||||
self.hooks = hooks
|
||||
|
||||
def init_hook(self, module):
|
||||
for hook in self.hooks:
|
||||
module = hook.init_hook(module)
|
||||
return module
|
||||
|
||||
def pre_forward(self, module, *args, **kwargs):
|
||||
for hook in self.hooks:
|
||||
args, kwargs = hook.pre_forward(module, *args, **kwargs)
|
||||
return args, kwargs
|
||||
|
||||
def post_forward(self, module, output):
|
||||
for hook in self.hooks:
|
||||
output = hook.post_forward(module, output)
|
||||
return output
|
||||
|
||||
def detach_hook(self, module):
|
||||
for hook in self.hooks:
|
||||
module = hook.detach_hook(module)
|
||||
return module
|
||||
|
||||
|
||||
def add_hook_to_module(module: nn.Module, hook: ModelHook):
|
||||
"""
|
||||
Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove
|
||||
this behavior and restore the original `forward` method, use `remove_hook_from_module`.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
If the module already contains a hook, this will replace it with the new hook passed. To chain two hooks together,
|
||||
use the `SequentialHook` class.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`): The module to attach a hook to.
|
||||
hook (`ModelHook`): The hook to attach.
|
||||
|
||||
Returns:
|
||||
`torch.nn.Module`: The same module, with the hook attached (the module is modified in place, so the result can
|
||||
be discarded).
|
||||
"""
|
||||
if hasattr(module, "_hf_hook") and hasattr(module, "_old_forward"):
|
||||
# If we already put some hook on this module, we replace it with the new one.
|
||||
old_forward = module._old_forward
|
||||
else:
|
||||
old_forward = module.forward
|
||||
module._old_forward = old_forward
|
||||
|
||||
module = hook.init_hook(module)
|
||||
module._hf_hook = hook
|
||||
|
||||
@functools.wraps(old_forward)
|
||||
def new_forward(*args, **kwargs):
|
||||
args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
|
||||
if module._hf_hook.no_grad:
|
||||
with torch.no_grad():
|
||||
output = old_forward(*args, **kwargs)
|
||||
else:
|
||||
output = old_forward(*args, **kwargs)
|
||||
return module._hf_hook.post_forward(module, output)
|
||||
|
||||
module.forward = new_forward
|
||||
return module
|
||||
|
||||
|
||||
def remove_hook_from_module(module: nn.Module):
|
||||
"""
|
||||
Removes any hook attached to a module via `add_hook_to_module`.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`): The module to attach a hook to.
|
||||
|
||||
Returns:
|
||||
`torch.nn.Module`: The same module, with the hook detached (the module is modified in place, so the result can
|
||||
be discarded).
|
||||
"""
|
||||
if hasattr(module, "_hf_hook"):
|
||||
module._hf_hook.detach_hook(module)
|
||||
delattr(module, "_hf_hook")
|
||||
|
||||
if hasattr(module, "_old_forward"):
|
||||
module.forward = module._old_forward
|
||||
delattr(module, "_old_forward")
|
||||
|
||||
return module
|
||||
|
||||
|
||||
class AlignDevicesHook(ModelHook):
|
||||
"""
|
||||
A generic `ModelHook` that ensures inputs and model weights are on the same device for the forward pass of the
|
||||
associated module, potentially offloading the weights after the forward pass.
|
||||
|
||||
Args:
|
||||
execution_device (`torch.device`, *optional*):
|
||||
The device on which inputs and model weights should be placed before the forward pass.
|
||||
offload (`bool`, *optional*, defauts to `False`):
|
||||
Whether or not the weights should be offloaded after the forward pass.
|
||||
io_same_device (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the output should be placed on the same device as the input was.
|
||||
weights_map (`Mapping[str, torch.Tensor]`, *optional*):
|
||||
When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
|
||||
offload_buffers (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to include the associated module's buffers when offloading.
|
||||
place_submodules (`bool`, *optional*, defaults to `False`):
|
||||
Whether to place the submodules on `execution_device` during the `init_hook` event.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
execution_device: Optional[Union[int, str, torch.device]] = None,
|
||||
offload: bool = False,
|
||||
io_same_device: bool = False,
|
||||
weights_map: Optional[Mapping] = None,
|
||||
offload_buffers: bool = False,
|
||||
place_submodules: bool = False,
|
||||
):
|
||||
self.execution_device = execution_device
|
||||
self.offload = offload
|
||||
self.io_same_device = io_same_device
|
||||
self.weights_map = weights_map
|
||||
self.offload_buffers = offload_buffers
|
||||
self.place_submodules = place_submodules
|
||||
|
||||
# Will contain the input device when `io_same_device=True`.
|
||||
self.input_device = None
|
||||
self.param_original_devices = {}
|
||||
self.buffer_original_devices = {}
|
||||
|
||||
def init_hook(self, module):
|
||||
if not self.offload and self.execution_device is not None:
|
||||
for name, _ in named_module_tensors(module, recurse=self.place_submodules):
|
||||
set_module_tensor_to_device(module, name, self.execution_device)
|
||||
elif self.offload:
|
||||
self.original_devices = {name: param.device for name, param in named_module_tensors(module)}
|
||||
if self.weights_map is None:
|
||||
self.weights_map = {
|
||||
name: param.to("cpu")
|
||||
for name, param in named_module_tensors(module, include_buffers=self.offload_buffers)
|
||||
}
|
||||
|
||||
for name, _ in named_module_tensors(module, include_buffers=self.offload_buffers):
|
||||
set_module_tensor_to_device(module, name, "meta")
|
||||
if not self.offload_buffers and self.execution_device is not None:
|
||||
for name, _ in module.named_buffers(recurse=False):
|
||||
set_module_tensor_to_device(module, name, self.execution_device)
|
||||
return module
|
||||
|
||||
def pre_forward(self, module, *args, **kwargs):
|
||||
if self.io_same_device:
|
||||
self.input_device = find_device([args, kwargs])
|
||||
if self.offload:
|
||||
for name, _ in named_module_tensors(module, include_buffers=self.offload_buffers):
|
||||
set_module_tensor_to_device(module, name, self.execution_device, value=self.weights_map[name])
|
||||
|
||||
return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)
|
||||
|
||||
def post_forward(self, module, output):
|
||||
if self.offload:
|
||||
for name, _ in named_module_tensors(module, include_buffers=self.offload_buffers):
|
||||
set_module_tensor_to_device(module, name, "meta")
|
||||
|
||||
if self.io_same_device and self.input_device is not None:
|
||||
output = send_to_device(output, self.input_device)
|
||||
|
||||
return output
|
||||
|
||||
def detach_hook(self, module):
|
||||
if self.offload:
|
||||
for name, device in self.original_devices.items():
|
||||
if device != torch.device("meta"):
|
||||
set_module_tensor_to_device(module, name, device, value=self.weights_map.get(name, None))
|
||||
|
||||
|
||||
def attach_align_device_hook(
|
||||
module: torch.nn.Module,
|
||||
execution_device: Optional[torch.device] = None,
|
||||
offload: bool = False,
|
||||
weights_map: Optional[Mapping] = None,
|
||||
offload_buffers: bool = False,
|
||||
module_name: str = "",
|
||||
):
|
||||
"""
|
||||
Recursively attaches `AlignDevicesHook` to all submodules of a given model that have direct parameters and/or
|
||||
buffers.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module where we want to attach the hooks.
|
||||
execution_device (`torch.device`, *optional*):
|
||||
The device on which inputs and model weights should be placed before the forward pass.
|
||||
offload (`bool`, *optional*, defauts to `False`):
|
||||
Whether or not the weights should be offloaded after the forward pass.
|
||||
weights_map (`Mapping[str, torch.Tensor]`, *optional*):
|
||||
When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
|
||||
offload_buffers (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to include the associated module's buffers when offloading.
|
||||
module_name (`str`, *optional*, defaults to `""`):
|
||||
The name of the module.
|
||||
"""
|
||||
# Attach the hook on this module if it has any direct tensor.
|
||||
directs = named_module_tensors(module)
|
||||
if len(list(directs)) > 0:
|
||||
if weights_map is not None:
|
||||
prefix = f"{module_name}." if len(module_name) > 0 else ""
|
||||
prefixed_weights_map = PrefixedDataset(weights_map, prefix)
|
||||
else:
|
||||
prefixed_weights_map = None
|
||||
hook = AlignDevicesHook(
|
||||
execution_device=execution_device,
|
||||
offload=offload,
|
||||
weights_map=prefixed_weights_map,
|
||||
offload_buffers=offload_buffers,
|
||||
)
|
||||
add_hook_to_module(module, hook)
|
||||
|
||||
# Recurse on all children of the module.
|
||||
for child_name, child in module.named_children():
|
||||
child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
|
||||
attach_align_device_hook(
|
||||
child,
|
||||
execution_device=execution_device,
|
||||
offload=offload,
|
||||
weights_map=weights_map,
|
||||
offload_buffers=offload_buffers,
|
||||
module_name=child_name,
|
||||
)
|
||||
|
||||
|
||||
def remove_hook_from_submodules(module: nn.Module):
|
||||
"""
|
||||
Recursively removes all hooks attached on the submodules of a given model.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`): The module on which to remove all hooks.
|
||||
"""
|
||||
remove_hook_from_module(module)
|
||||
for child in module.children():
|
||||
remove_hook_from_submodules(child)
|
||||
|
||||
|
||||
def attach_align_device_hook_on_blocks(
|
||||
module: nn.Module,
|
||||
execution_device: Optional[Union[torch.device, Dict[str, torch.device]]] = None,
|
||||
offload: Union[bool, Dict[str, bool]] = False,
|
||||
weights_map: Mapping = None,
|
||||
offload_buffers: bool = False,
|
||||
module_name: str = "",
|
||||
):
|
||||
"""
|
||||
Attaches `AlignDevicesHook` to all blocks of a given model as needed.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module where we want to attach the hooks.
|
||||
execution_device (`torch.device` or `Dict[str, torch.device]`, *optional*):
|
||||
The device on which inputs and model weights should be placed before the forward pass. It can be one device
|
||||
for the whole module, or a dictionary mapping module name to device.
|
||||
offload (`bool`, *optional*, defauts to `False`):
|
||||
Whether or not the weights should be offloaded after the forward pass. It can be one boolean for the whole
|
||||
module, or a dictionary mapping module name to boolean.
|
||||
weights_map (`Mapping[str, torch.Tensor]`, *optional*):
|
||||
When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
|
||||
offload_buffers (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to include the associated module's buffers when offloading.
|
||||
module_name (`str`, *optional*, defaults to `""`):
|
||||
The name of the module.
|
||||
"""
|
||||
# If one device and one offload, we've got one hook.
|
||||
if not isinstance(execution_device, Mapping) and not isinstance(offload, dict):
|
||||
if not offload:
|
||||
hook = AlignDevicesHook(execution_device=execution_device, io_same_device=True, place_submodules=True)
|
||||
add_hook_to_module(module, hook)
|
||||
else:
|
||||
attach_align_device_hook(
|
||||
module,
|
||||
execution_device=execution_device,
|
||||
offload=True,
|
||||
weights_map=weights_map,
|
||||
offload_buffers=offload_buffers,
|
||||
module_name=module_name,
|
||||
)
|
||||
return
|
||||
|
||||
if not isinstance(execution_device, Mapping):
|
||||
execution_device = {key: offload for key in offload.keys()}
|
||||
if not isinstance(offload, Mapping):
|
||||
offload = {key: offload for key in execution_device.keys()}
|
||||
|
||||
if module_name in execution_device and not offload[module_name]:
|
||||
hook = AlignDevicesHook(
|
||||
execution_device=execution_device[module_name],
|
||||
offload_buffers=offload_buffers,
|
||||
io_same_device=(module_name == ""),
|
||||
place_submodules=True,
|
||||
)
|
||||
add_hook_to_module(module, hook)
|
||||
elif module_name in execution_device:
|
||||
attach_align_device_hook(
|
||||
module,
|
||||
execution_device=execution_device[module_name],
|
||||
offload=True,
|
||||
weights_map=weights_map,
|
||||
offload_buffers=offload_buffers,
|
||||
module_name=module_name,
|
||||
)
|
||||
if not hasattr(module, "_hf_hook"):
|
||||
hook = AlignDevicesHook(execution_device=execution_device[module_name], io_same_device=(module_name == ""))
|
||||
add_hook_to_module(module, hook)
|
||||
elif module_name == "":
|
||||
hook = AlignDevicesHook(io_same_device=True)
|
||||
add_hook_to_module(module, hook)
|
||||
|
||||
for child_name, child in module.named_children():
|
||||
child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
|
||||
attach_align_device_hook_on_blocks(
|
||||
child,
|
||||
execution_device=execution_device,
|
||||
offload=offload,
|
||||
weights_map=weights_map,
|
||||
offload_buffers=offload_buffers,
|
||||
module_name=child_name,
|
||||
)
|
||||
@ -1,90 +0,0 @@
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class KwargsHandler:
|
||||
"""
|
||||
Internal mixin that implements a `to_kwargs()` method for a dataclass.
|
||||
"""
|
||||
|
||||
def to_dict(self):
|
||||
return copy.deepcopy(self.__dict__)
|
||||
|
||||
def to_kwargs(self):
|
||||
"""
|
||||
Returns a dictionary containing the attributes with values different from the default of this class.
|
||||
"""
|
||||
default_dict = self.__class__().to_dict()
|
||||
this_dict = self.to_dict()
|
||||
return {k: v for k, v in this_dict.items() if default_dict[k] != v}
|
||||
|
||||
|
||||
@dataclass
|
||||
class DistributedDataParallelKwargs(KwargsHandler):
|
||||
"""
|
||||
Use this object in your [`Accelerator`] to customize how your model is wrapped in a
|
||||
`torch.nn.parallel.DistributedDataParallel`. Please refer to the documentation of this
|
||||
[wrapper](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) for more
|
||||
information on each argument.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
`gradient_as_bucket_view` is only available in PyTorch 1.7.0 and later versions.
|
||||
|
||||
</Tip>"""
|
||||
|
||||
dim: int = 0
|
||||
broadcast_buffers: bool = True
|
||||
bucket_cap_mb: int = 25
|
||||
find_unused_parameters: bool = False
|
||||
check_reduction: bool = False
|
||||
gradient_as_bucket_view: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class GradScalerKwargs(KwargsHandler):
|
||||
"""
|
||||
Use this object in your [`Accelerator`] to customize the behavior of mixed precision, specifically how the
|
||||
`torch.cuda.amp.GradScaler` used is created. Please refer to the documentation of this
|
||||
[scaler](https://pytorch.org/docs/stable/amp.html?highlight=gradscaler) for more information on each argument.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
`GradScaler` is only available in PyTorch 1.5.0 and later versions.
|
||||
|
||||
</Tip>"""
|
||||
|
||||
init_scale: float = 65536.0
|
||||
growth_factor: float = 2.0
|
||||
backoff_factor: float = 0.5
|
||||
growth_interval: int = 2000
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class InitProcessGroupKwargs(KwargsHandler):
|
||||
"""
|
||||
Use this object in your [`Accelerator`] to customize the initialization of the distributed processes. Please refer
|
||||
to the documentation of this
|
||||
[method](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for more
|
||||
information on each argument.
|
||||
"""
|
||||
|
||||
init_method: Optional[str] = None
|
||||
timeout: timedelta = timedelta(seconds=1800)
|
||||
63
src/accelerate/logging.py
Normal file
63
src/accelerate/logging.py
Normal file
@ -0,0 +1,63 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from .state import AcceleratorState
|
||||
|
||||
|
||||
class MultiProcessAdapter(logging.LoggerAdapter):
|
||||
"""
|
||||
An adapter to assist with logging in multiprocess.
|
||||
|
||||
`log` takes in an additional `main_process_only` kwarg, which dictates whether it should be called on all processes
|
||||
or only the main executed one. Default is `main_process_only=True`.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _should_log(main_process_only):
|
||||
"Check if log should be performed"
|
||||
return not main_process_only or (main_process_only and AcceleratorState().local_process_index == 0)
|
||||
|
||||
def log(self, level, msg, *args, **kwargs):
|
||||
"""
|
||||
Delegates logger call after checking if we should log.
|
||||
|
||||
Accepts a new kwarg of `main_process_only`, which will dictate whether it will be logged across all processes
|
||||
or only the main executed one. Default is `True` if not passed
|
||||
"""
|
||||
main_process_only = kwargs.pop("main_process_only", True)
|
||||
if self.isEnabledFor(level) and self._should_log(main_process_only):
|
||||
msg, kwargs = self.process(msg, kwargs)
|
||||
self.logger.log(level, msg, *args, **kwargs)
|
||||
|
||||
|
||||
def get_logger(name: str):
|
||||
"""
|
||||
Returns a `logging.Logger` for `name` that can handle multiprocessing.
|
||||
|
||||
If a log should be called on all processes, pass `main_process_only=False`
|
||||
|
||||
E.g.
|
||||
```python
|
||||
logger.info("My log", main_process_only=False)
|
||||
logger.debug("My log", main_process_only=False)
|
||||
```
|
||||
|
||||
Args:
|
||||
name (`str`):
|
||||
The name for the logger, such as `__file__`
|
||||
"""
|
||||
logger = logging.getLogger(name)
|
||||
return MultiProcessAdapter(logger, {})
|
||||
@ -12,75 +12,18 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
A collection of utilities for ensuring that training can always occur. Heavily influenced by the
|
||||
[toma](https://github.com/BlackHC/toma) library.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import gc
|
||||
import inspect
|
||||
|
||||
import torch
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all
|
||||
|
||||
|
||||
def should_reduce_batch_size(exception: Exception) -> bool:
|
||||
"""
|
||||
Checks if `exception` relates to CUDA out-of-memory, CUDNN not supported, or CPU out-of-memory
|
||||
|
||||
Args:
|
||||
exception (`Exception`):
|
||||
An exception
|
||||
"""
|
||||
_statements = [
|
||||
"CUDA out of memory.", # CUDA OOM
|
||||
"cuDNN error: CUDNN_STATUS_NOT_SUPPORTED.", # CUDNN SNAFU
|
||||
"DefaultCPUAllocator: can't allocate memory", # CPU OOM
|
||||
]
|
||||
if isinstance(exception, RuntimeError) and len(exception.args) == 1:
|
||||
return any(err in exception.args[0] for err in _statements)
|
||||
return False
|
||||
import warnings
|
||||
|
||||
|
||||
def find_executable_batch_size(function: callable = None, starting_batch_size: int = 128):
|
||||
"""
|
||||
A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
|
||||
CUDNN, the batch size is cut in half and passed to `function`
|
||||
warnings.warn(
|
||||
"memory_utils has been reorganized to utils.memory. Import `find_executable_batchsize` from the main `__init__`: "
|
||||
"`from accelerate import find_executable_batch_size` to avoid this warning.",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
`function` must take in a `batch_size` parameter as its first argument.
|
||||
|
||||
Args:
|
||||
function (`callable`, *optional*):
|
||||
A function to wrap
|
||||
starting_batch_size (`int`, *optional*):
|
||||
The batch size to try and fit into memory
|
||||
"""
|
||||
if function is None:
|
||||
return functools.partial(find_executable_batch_size, starting_batch_size=starting_batch_size)
|
||||
|
||||
batch_size = starting_batch_size
|
||||
|
||||
def decorator(*args, **kwargs):
|
||||
nonlocal batch_size
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
params = list(inspect.signature(function).parameters.keys())
|
||||
# Guard against user error
|
||||
if len(params) < (len(args) + 1):
|
||||
arg_str = ", ".join([f"{arg}={value}" for arg, value in zip(params[1:], args[1:])])
|
||||
raise TypeError(
|
||||
f"Batch size was passed into `{function.__name__}` as the first argument when called."
|
||||
f"Remove this as the decorator already does so: `{function.__name__}({arg_str})`"
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
return function(batch_size, *args, **kwargs)
|
||||
except Exception as e:
|
||||
if should_reduce_batch_size(e):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
batch_size //= 2
|
||||
else:
|
||||
raise
|
||||
|
||||
return decorator
|
||||
from .utils.memory import find_executable_batch_size
|
||||
|
||||
@ -19,8 +19,8 @@ import torch
|
||||
|
||||
from packaging import version
|
||||
|
||||
from .state import AcceleratorState, DistributedType, is_tpu_available
|
||||
from .utils import honor_type
|
||||
from .state import AcceleratorState
|
||||
from .utils import DistributedType, honor_type, is_tpu_available
|
||||
|
||||
|
||||
if is_tpu_available():
|
||||
|
||||
@ -12,29 +12,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from distutils.util import strtobool
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
try:
|
||||
import torch_ccl # noqa: F401
|
||||
|
||||
_ccl_available = True
|
||||
except ImportError:
|
||||
_ccl_available = False
|
||||
from .utils import DistributedType, is_ccl_available, is_deepspeed_available, is_tpu_available
|
||||
|
||||
|
||||
try:
|
||||
if is_tpu_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
_tpu_available = True
|
||||
except ImportError:
|
||||
_tpu_available = False
|
||||
|
||||
|
||||
def get_int_from_env(env_keys, default):
|
||||
"""Returns the first positive env value found in the `env_keys` list or the default."""
|
||||
@ -45,22 +33,6 @@ def get_int_from_env(env_keys, default):
|
||||
return default
|
||||
|
||||
|
||||
def is_ccl_available():
|
||||
return _ccl_available
|
||||
|
||||
|
||||
def is_apex_available():
|
||||
return importlib.util.find_spec("apex") is not None
|
||||
|
||||
|
||||
def is_tpu_available():
|
||||
return _tpu_available
|
||||
|
||||
|
||||
def is_deepspeed_available():
|
||||
return importlib.util.find_spec("deepspeed") is not None
|
||||
|
||||
|
||||
def parse_flag_from_env(key, default=False):
|
||||
value = os.environ.get(key, str(default))
|
||||
return strtobool(value) == 1 # As its name indicates `strtobool` actually returns an int...
|
||||
@ -71,60 +43,6 @@ def parse_choice_from_env(key, default="no"):
|
||||
return value
|
||||
|
||||
|
||||
class DistributedType(str, Enum):
|
||||
"""
|
||||
Represents a type of distributed environment.
|
||||
|
||||
Values:
|
||||
|
||||
- **NO** -- Not a distributed environment, just a single process.
|
||||
- **MULTI_CPU** -- Distributed on multiple CPU nodes.
|
||||
- **MULTI_GPU** -- Distributed on multiple GPUs.
|
||||
- **DEEPSPEED** -- Using DeepSpeed.
|
||||
- **TPU** -- Distributed on TPUs.
|
||||
"""
|
||||
|
||||
# Subclassing str as well as Enum allows the `DistributedType` to be JSON-serializable out of the box.
|
||||
NO = "NO"
|
||||
MULTI_CPU = "MULTI_CPU"
|
||||
MULTI_GPU = "MULTI_GPU"
|
||||
DEEPSPEED = "DEEPSPEED"
|
||||
FSDP = "FSDP"
|
||||
TPU = "TPU"
|
||||
|
||||
|
||||
class SageMakerDistributedType(str, Enum):
|
||||
"""
|
||||
Represents a type of distributed environment.
|
||||
|
||||
Values:
|
||||
|
||||
- **NO** -- Not a distributed environment, just a single process.
|
||||
- **DATA_PARALLEL** -- using sagemaker distributed data parallelism.
|
||||
- **MODEL_PARALLEL** -- using sagemaker distributed model parallelism.
|
||||
"""
|
||||
|
||||
# Subclassing str as well as Enum allows the `SageMakerDistributedType` to be JSON-serializable out of the box.
|
||||
NO = "NO"
|
||||
DATA_PARALLEL = "DATA_PARALLEL"
|
||||
MODEL_PARALLEL = "MODEL_PARALLEL"
|
||||
|
||||
|
||||
class ComputeEnvironment(str, Enum):
|
||||
"""
|
||||
Represents a type of the compute environment.
|
||||
|
||||
Values:
|
||||
|
||||
- **LOCAL_MACHINE** -- private/custom cluster hardware.
|
||||
- **AMAZON_SAGEMAKER** -- Amazon SageMaker as compute environment.
|
||||
"""
|
||||
|
||||
# Subclassing str as well as Enum allows the `ComputeEnvironment` to be JSON-serializable out of the box.
|
||||
LOCAL_MACHINE = "LOCAL_MACHINE"
|
||||
AMAZON_SAGEMAKER = "AMAZON_SAGEMAKER"
|
||||
|
||||
|
||||
# Inspired by Alex Martelli's 'Borg'.
|
||||
class AcceleratorState:
|
||||
"""
|
||||
|
||||
@ -2,5 +2,5 @@
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
from .testing import are_the_same_tensors, execute_subprocess_async, require_cuda, require_multi_gpu, require_tpu
|
||||
from .testing import are_the_same_tensors, execute_subprocess_async, require_cuda, require_multi_gpu, require_tpu, slow
|
||||
from .training import RegressionDataset, RegressionModel
|
||||
|
||||
@ -19,9 +19,9 @@ from torch.utils.data import DataLoader
|
||||
|
||||
from accelerate import Accelerator
|
||||
from accelerate.data_loader import prepare_data_loader
|
||||
from accelerate.state import AcceleratorState, DistributedType
|
||||
from accelerate.state import AcceleratorState
|
||||
from accelerate.test_utils import RegressionDataset, RegressionModel, are_the_same_tensors
|
||||
from accelerate.utils import gather, set_seed, synchronize_rng_states
|
||||
from accelerate.utils import DistributedType, gather, set_seed, synchronize_rng_states
|
||||
from packaging import version
|
||||
|
||||
|
||||
|
||||
@ -25,8 +25,8 @@ from unittest import mock
|
||||
|
||||
import torch
|
||||
|
||||
from ..state import AcceleratorState, is_tpu_available
|
||||
from ..utils import gather, is_tensorflow_available
|
||||
from ..state import AcceleratorState
|
||||
from ..utils import gather, is_comet_ml_available, is_tensorflow_available, is_tpu_available, is_wandb_available
|
||||
|
||||
|
||||
def parse_flag_from_env(key, default=False):
|
||||
@ -53,10 +53,51 @@ def slow(test_case):
|
||||
Decorator marking a test as slow. Slow tests are skipped by default. Set the RUN_SLOW environment variable to a
|
||||
truthy value to run them.
|
||||
"""
|
||||
if not _run_slow_tests:
|
||||
return unittest.skip("test is slow")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
|
||||
|
||||
|
||||
def require_cuda(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires CUDA. These tests are skipped when there are no GPU available.
|
||||
"""
|
||||
return unittest.skipUnless(torch.cuda.is_available(), "test requires a GPU")(test_case)
|
||||
|
||||
|
||||
def require_tpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires TPUs. These tests are skipped when there are no TPUs available.
|
||||
"""
|
||||
return unittest.skipUnless(is_tpu_available(), "test requires TPU")(test_case)
|
||||
|
||||
|
||||
def require_multi_gpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a multi-GPU setup. These tests are skipped on a machine without multiple
|
||||
GPUs.
|
||||
"""
|
||||
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
|
||||
|
||||
|
||||
def require_tensorflow(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires TensorFlow installed. These tests are skipped when TensorFlow isn't
|
||||
installed
|
||||
"""
|
||||
return unittest.skipUnless(is_tensorflow_available(), "test requires TensorFlow")(test_case)
|
||||
|
||||
|
||||
def require_wandb(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires wandb installed. These tests are skipped when wandb isn't installed
|
||||
"""
|
||||
return unittest.skipUnless(is_wandb_available(), "test requires wandb")(test_case)
|
||||
|
||||
|
||||
def require_comet_ml(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires comet_ml installed. These tests are skipped when comet_ml isn't installed
|
||||
"""
|
||||
return unittest.skipUnless(is_comet_ml_available(), "test requires comet_ml")(test_case)
|
||||
|
||||
|
||||
class TempDirTestCase(unittest.TestCase):
|
||||
@ -136,48 +177,6 @@ def are_the_same_tensors(tensor):
|
||||
return True
|
||||
|
||||
|
||||
def require_cuda(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires CUDA. These tests are skipped when there are no GPU available.
|
||||
"""
|
||||
if not torch.cuda.is_available():
|
||||
return unittest.skip("test requires a GPU")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def require_tpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires TPUs. These tests are skipped when there are no TPUs available.
|
||||
"""
|
||||
if not is_tpu_available():
|
||||
return unittest.skip("test requires TPU")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def require_multi_gpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a multi-GPU setup. These tests are skipped on a machine without multiple
|
||||
GPUs.
|
||||
"""
|
||||
if torch.cuda.device_count() < 2:
|
||||
return unittest.skip("test requires multiple GPUs")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def require_tensorflow(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires TensorFlow installed. These tests are skipped when TensorFlow isn't
|
||||
installed
|
||||
"""
|
||||
if not is_tensorflow_available():
|
||||
return unittest.skip("test requires TensorFlow")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
class _RunOutput:
|
||||
def __init__(self, returncode, stdout, stderr):
|
||||
self.returncode = returncode
|
||||
|
||||
@ -15,11 +15,11 @@
|
||||
# Expectation:
|
||||
# Provide a project dir name, then each type of logger gets stored in project/{`logging_dir`}
|
||||
|
||||
import logging
|
||||
import os
|
||||
from abc import ABCMeta, abstractmethod, abstractproperty
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from .logging import get_logger
|
||||
from .utils import LoggerType, is_comet_ml_available, is_tensorboard_available, is_wandb_available
|
||||
|
||||
|
||||
@ -41,7 +41,7 @@ if is_comet_ml_available():
|
||||
_available_trackers.append(LoggerType.COMETML)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_available_trackers():
|
||||
|
||||
76
src/accelerate/utils/__init__.py
Normal file
76
src/accelerate/utils/__init__.py
Normal file
@ -0,0 +1,76 @@
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all
|
||||
|
||||
from .constants import MODEL_NAME, OPTIMIZER_NAME, RNG_STATE_NAME, SCALER_NAME, SCHEDULER_NAME
|
||||
from .dataclasses import (
|
||||
ComputeEnvironment,
|
||||
DeepSpeedPlugin,
|
||||
DistributedDataParallelKwargs,
|
||||
DistributedType,
|
||||
FullyShardedDataParallelPlugin,
|
||||
GradScalerKwargs,
|
||||
InitProcessGroupKwargs,
|
||||
KwargsHandler,
|
||||
LoggerType,
|
||||
PrecisionType,
|
||||
RNGType,
|
||||
SageMakerDistributedType,
|
||||
TensorInformation,
|
||||
)
|
||||
from .imports import (
|
||||
is_apex_available,
|
||||
is_boto3_available,
|
||||
is_ccl_available,
|
||||
is_comet_ml_available,
|
||||
is_deepspeed_available,
|
||||
is_sagemaker_available,
|
||||
is_tensorboard_available,
|
||||
is_tensorflow_available,
|
||||
is_tpu_available,
|
||||
is_wandb_available,
|
||||
)
|
||||
from .modeling import (
|
||||
check_device_map,
|
||||
compute_module_sizes,
|
||||
convert_file_size_to_int,
|
||||
dtype_byte_size,
|
||||
find_tied_parameters,
|
||||
get_max_layer_size,
|
||||
get_max_memory,
|
||||
infer_auto_device_map,
|
||||
load_checkpoint_in_model,
|
||||
named_module_tensors,
|
||||
set_module_tensor_to_device,
|
||||
)
|
||||
from .offload import OffloadedWeightsLoader, PrefixedDataset, extract_submodules_state_dict, offload_state_dict
|
||||
from .operations import (
|
||||
broadcast,
|
||||
broadcast_object_list,
|
||||
concatenate,
|
||||
convert_outputs_to_fp32,
|
||||
convert_to_fp32,
|
||||
find_batch_size,
|
||||
find_device,
|
||||
gather,
|
||||
gather_object,
|
||||
get_data_structure,
|
||||
honor_type,
|
||||
initialize_tensors,
|
||||
is_tensor_information,
|
||||
is_torch_tensor,
|
||||
pad_across_processes,
|
||||
recursively_apply,
|
||||
reduce,
|
||||
send_to_device,
|
||||
slice_tensors,
|
||||
)
|
||||
|
||||
|
||||
if is_deepspeed_available():
|
||||
from .deepspeed import DeepSpeedEngineWrapper, DeepSpeedOptimizerWrapper
|
||||
|
||||
from .launch import PrepareForLaunch
|
||||
from .memory import find_executable_batch_size
|
||||
from .other import extract_model_from_parallel, get_pretty_name, patch_environment, save, wait_for_everyone
|
||||
from .random import set_seed, synchronize_rng_state, synchronize_rng_states
|
||||
19
src/accelerate/utils/constants.py
Normal file
19
src/accelerate/utils/constants.py
Normal file
@ -0,0 +1,19 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
SCALER_NAME = "scaler.pt"
|
||||
MODEL_NAME = "pytorch_model"
|
||||
RNG_STATE_NAME = "random_states"
|
||||
OPTIMIZER_NAME = "optimizer"
|
||||
SCHEDULER_NAME = "scheduler"
|
||||
304
src/accelerate/utils/dataclasses.py
Normal file
304
src/accelerate/utils/dataclasses.py
Normal file
@ -0,0 +1,304 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
General namespace and dataclass related classes
|
||||
"""
|
||||
|
||||
import copy
|
||||
import enum
|
||||
import functools
|
||||
import os
|
||||
import typing
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import timedelta
|
||||
from typing import Callable, Iterable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class KwargsHandler:
|
||||
"""
|
||||
Internal mixin that implements a `to_kwargs()` method for a dataclass.
|
||||
"""
|
||||
|
||||
def to_dict(self):
|
||||
return copy.deepcopy(self.__dict__)
|
||||
|
||||
def to_kwargs(self):
|
||||
"""
|
||||
Returns a dictionary containing the attributes with values different from the default of this class.
|
||||
"""
|
||||
default_dict = self.__class__().to_dict()
|
||||
this_dict = self.to_dict()
|
||||
return {k: v for k, v in this_dict.items() if default_dict[k] != v}
|
||||
|
||||
|
||||
@dataclass
|
||||
class DistributedDataParallelKwargs(KwargsHandler):
|
||||
"""
|
||||
Use this object in your [`Accelerator`] to customize how your model is wrapped in a
|
||||
`torch.nn.parallel.DistributedDataParallel`. Please refer to the documentation of this
|
||||
[wrapper](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) for more
|
||||
information on each argument.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
`gradient_as_bucket_view` is only available in PyTorch 1.7.0 and later versions.
|
||||
|
||||
</Tip>"""
|
||||
|
||||
dim: int = 0
|
||||
broadcast_buffers: bool = True
|
||||
bucket_cap_mb: int = 25
|
||||
find_unused_parameters: bool = False
|
||||
check_reduction: bool = False
|
||||
gradient_as_bucket_view: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class GradScalerKwargs(KwargsHandler):
|
||||
"""
|
||||
Use this object in your [`Accelerator`] to customize the behavior of mixed precision, specifically how the
|
||||
`torch.cuda.amp.GradScaler` used is created. Please refer to the documentation of this
|
||||
[scaler](https://pytorch.org/docs/stable/amp.html?highlight=gradscaler) for more information on each argument.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
`GradScaler` is only available in PyTorch 1.5.0 and later versions.
|
||||
|
||||
</Tip>"""
|
||||
|
||||
init_scale: float = 65536.0
|
||||
growth_factor: float = 2.0
|
||||
backoff_factor: float = 0.5
|
||||
growth_interval: int = 2000
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class InitProcessGroupKwargs(KwargsHandler):
|
||||
"""
|
||||
Use this object in your [`Accelerator`] to customize the initialization of the distributed processes. Please refer
|
||||
to the documentation of this
|
||||
[method](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for more
|
||||
information on each argument.
|
||||
"""
|
||||
|
||||
init_method: Optional[str] = None
|
||||
timeout: timedelta = timedelta(seconds=1800)
|
||||
|
||||
|
||||
class DistributedType(str, enum.Enum):
|
||||
"""
|
||||
Represents a type of distributed environment.
|
||||
|
||||
Values:
|
||||
|
||||
- **NO** -- Not a distributed environment, just a single process.
|
||||
- **MULTI_CPU** -- Distributed on multiple CPU nodes.
|
||||
- **MULTI_GPU** -- Distributed on multiple GPUs.
|
||||
- **DEEPSPEED** -- Using DeepSpeed.
|
||||
- **TPU** -- Distributed on TPUs.
|
||||
"""
|
||||
|
||||
# Subclassing str as well as Enum allows the `DistributedType` to be JSON-serializable out of the box.
|
||||
NO = "NO"
|
||||
MULTI_CPU = "MULTI_CPU"
|
||||
MULTI_GPU = "MULTI_GPU"
|
||||
DEEPSPEED = "DEEPSPEED"
|
||||
FSDP = "FSDP"
|
||||
TPU = "TPU"
|
||||
|
||||
|
||||
class SageMakerDistributedType(str, enum.Enum):
|
||||
"""
|
||||
Represents a type of distributed environment.
|
||||
|
||||
Values:
|
||||
|
||||
- **NO** -- Not a distributed environment, just a single process.
|
||||
- **DATA_PARALLEL** -- using sagemaker distributed data parallelism.
|
||||
- **MODEL_PARALLEL** -- using sagemaker distributed model parallelism.
|
||||
"""
|
||||
|
||||
# Subclassing str as well as Enum allows the `SageMakerDistributedType` to be JSON-serializable out of the box.
|
||||
NO = "NO"
|
||||
DATA_PARALLEL = "DATA_PARALLEL"
|
||||
MODEL_PARALLEL = "MODEL_PARALLEL"
|
||||
|
||||
|
||||
class ComputeEnvironment(str, enum.Enum):
|
||||
"""
|
||||
Represents a type of the compute environment.
|
||||
|
||||
Values:
|
||||
|
||||
- **LOCAL_MACHINE** -- private/custom cluster hardware.
|
||||
- **AMAZON_SAGEMAKER** -- Amazon SageMaker as compute environment.
|
||||
"""
|
||||
|
||||
# Subclassing str as well as Enum allows the `ComputeEnvironment` to be JSON-serializable out of the box.
|
||||
LOCAL_MACHINE = "LOCAL_MACHINE"
|
||||
AMAZON_SAGEMAKER = "AMAZON_SAGEMAKER"
|
||||
|
||||
|
||||
class EnumWithContains(enum.EnumMeta):
|
||||
"A metaclass that adds the ability to check if `self` contains an item with the `in` operator"
|
||||
|
||||
def __contains__(cls, item):
|
||||
try:
|
||||
cls(item)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class BaseEnum(enum.Enum, metaclass=EnumWithContains):
|
||||
"An enum class that can get the value of an item with `str(Enum.key)`"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
@classmethod
|
||||
def list(cls):
|
||||
"Method to list all the possible items in `cls`"
|
||||
return list(map(lambda item: str(item), cls))
|
||||
|
||||
|
||||
class LoggerType(BaseEnum):
|
||||
ALL = "all"
|
||||
TENSORBOARD = "tensorboard"
|
||||
WANDB = "wandb"
|
||||
COMETML = "comet_ml"
|
||||
|
||||
|
||||
class PrecisionType(BaseEnum):
|
||||
NO = "no"
|
||||
FP16 = "fp16"
|
||||
BF16 = "bf16"
|
||||
|
||||
|
||||
class RNGType(BaseEnum):
|
||||
TORCH = "torch"
|
||||
CUDA = "cuda"
|
||||
XLA = "xla"
|
||||
GENERATOR = "generator"
|
||||
|
||||
|
||||
# data classes
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorInformation:
|
||||
shape: torch.Size
|
||||
dtype: torch.dtype
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepSpeedPlugin:
|
||||
|
||||
gradient_accumulation_steps: int = field(
|
||||
default=None, metadata={"help": "Number of steps to accumulate gradients before updating optimizer states"}
|
||||
)
|
||||
zero_stage: int = field(
|
||||
default=None,
|
||||
metadata={"help": "Possible options are 0,1,2,3; Default will be taken from environment variable"},
|
||||
)
|
||||
is_train_batch_min: str = field(
|
||||
default=True,
|
||||
metadata={"help": "If both train & eval dataloaders are specified, this will decide the train_batch_size"},
|
||||
)
|
||||
|
||||
auto_opt_mapping: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "whether to map torch.adam to deepspeed optimizer version of adam based on config"},
|
||||
)
|
||||
|
||||
offload_optimizer_device: bool = field(default=None, metadata={"help": "Possible options are none|cpu|nvme"})
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
if self.gradient_accumulation_steps is None:
|
||||
self.gradient_accumulation_steps = int(os.environ.get("GRADIENT_ACCUMULATION_STEPS", 1))
|
||||
|
||||
if self.zero_stage is None:
|
||||
self.zero_stage = int(os.environ.get("DEEPSPEED_ZERO_STAGE", 2))
|
||||
|
||||
if self.offload_optimizer_device is None:
|
||||
self.offload_optimizer_device = os.environ.get("DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE", "none")
|
||||
|
||||
self.deepspeed_config = {
|
||||
"train_batch_size": None,
|
||||
"gradient_accumulation_steps": self.gradient_accumulation_steps,
|
||||
"zero_optimization": {
|
||||
"stage": self.zero_stage,
|
||||
"offload_optimizer": {
|
||||
"device": self.offload_optimizer_device,
|
||||
},
|
||||
},
|
||||
"steps_per_print": float("inf"), # this will stop deepspeed from logging @ stdout
|
||||
"zero_allow_untested_optimizer": True,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class FullyShardedDataParallelPlugin:
|
||||
"""
|
||||
This plugin is used to enable fully sharded data parallelism.
|
||||
"""
|
||||
|
||||
sharding_strategy: "typing.Any" = field(
|
||||
default=None,
|
||||
metadata={"help": "Possible options are [1] FULL_SHARD, [2] SHARD_GRAD_OP"},
|
||||
)
|
||||
backward_prefetch: "typing.Any" = field(
|
||||
default=None,
|
||||
metadata={"help": "Possible options are [1] BACKWARD_PRE, [2] BACKWARD_POST"},
|
||||
)
|
||||
auto_wrap_policy: "typing.Any" = field(
|
||||
default=None,
|
||||
metadata={"help": "A callable specifying a policy to recursively wrap layers with FSDP"},
|
||||
)
|
||||
cpu_offload: Optional[Callable] = field(
|
||||
default=None,
|
||||
metadata={"help": "Decides Whether to offload parameters and gradients to CPU."},
|
||||
)
|
||||
min_num_params: int = field(
|
||||
default=None, metadata={"help": "FSDP's minimum number of parameters for Default Auto Wrapping."}
|
||||
)
|
||||
ignored_modules: Optional[Iterable[torch.nn.Module]] = field(
|
||||
default=None,
|
||||
metadata={"help": "A list of modules to ignore for FSDP."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, ShardingStrategy
|
||||
from torch.distributed.fsdp.wrap import default_auto_wrap_policy
|
||||
|
||||
if self.sharding_strategy is None:
|
||||
self.sharding_strategy = ShardingStrategy(int(os.environ.get("FSDP_SHARDING_STRATEGY", 1)))
|
||||
|
||||
if self.cpu_offload is None:
|
||||
if os.environ.get("FSDP_OFFLOAD_PARAMS", "false") == "true":
|
||||
self.cpu_offload = CPUOffload(offload_params=True)
|
||||
else:
|
||||
self.cpu_offload = CPUOffload(offload_params=False)
|
||||
|
||||
if self.min_num_params is None:
|
||||
self.min_num_params = int(os.environ.get("FSDP_MIN_NUM_PARAMS", 0))
|
||||
|
||||
if self.auto_wrap_policy is None:
|
||||
if self.min_num_params > 0:
|
||||
self.auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=self.min_num_params)
|
||||
@ -12,8 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .optimizer import AcceleratedOptimizer
|
||||
from .state import is_apex_available, is_deepspeed_available
|
||||
from ..optimizer import AcceleratedOptimizer
|
||||
from .imports import is_apex_available, is_deepspeed_available
|
||||
|
||||
|
||||
if is_deepspeed_available():
|
||||
72
src/accelerate/utils/imports.py
Normal file
72
src/accelerate/utils/imports.py
Normal file
@ -0,0 +1,72 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import importlib
|
||||
|
||||
|
||||
try:
|
||||
import torch_ccl # noqa: F401
|
||||
|
||||
_ccl_available = True
|
||||
except ImportError:
|
||||
_ccl_available = False
|
||||
|
||||
|
||||
try:
|
||||
import torch_xla.core.xla_model as xm # noqa: F401
|
||||
|
||||
_tpu_available = True
|
||||
except ImportError:
|
||||
_tpu_available = False
|
||||
|
||||
|
||||
def is_ccl_available():
|
||||
return _ccl_available
|
||||
|
||||
|
||||
def is_apex_available():
|
||||
return importlib.util.find_spec("apex") is not None
|
||||
|
||||
|
||||
def is_tpu_available():
|
||||
return _tpu_available
|
||||
|
||||
|
||||
def is_deepspeed_available():
|
||||
return importlib.util.find_spec("deepspeed") is not None
|
||||
|
||||
|
||||
def is_tensorflow_available():
|
||||
return importlib.util.find_spec("tensorflow") is not None
|
||||
|
||||
|
||||
def is_tensorboard_available():
|
||||
return importlib.util.find_spec("tensorboard") is not None or importlib.util.find_spec("tensorboardX") is not None
|
||||
|
||||
|
||||
def is_wandb_available():
|
||||
return importlib.util.find_spec("wandb") is not None
|
||||
|
||||
|
||||
def is_comet_ml_available():
|
||||
return importlib.util.find_spec("comet_ml") is not None
|
||||
|
||||
|
||||
def is_boto3_available():
|
||||
return importlib.util.find_spec("boto3") is not None
|
||||
|
||||
|
||||
def is_sagemaker_available():
|
||||
return importlib.util.find_spec("sagemaker") is not None
|
||||
55
src/accelerate/utils/launch.py
Normal file
55
src/accelerate/utils/launch.py
Normal file
@ -0,0 +1,55 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from .dataclasses import DistributedType
|
||||
|
||||
|
||||
class PrepareForLaunch:
|
||||
"""
|
||||
Prepare a function that will launched in a distributed setup.
|
||||
|
||||
Args:
|
||||
launcher (`Callable`):
|
||||
The function to launch.
|
||||
distributed_type ([`~state.DistributedType`]):
|
||||
The distributed type to prepare for.
|
||||
debug (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not this is a debug launch.
|
||||
"""
|
||||
|
||||
def __init__(self, launcher, distributed_type="NO", debug=False):
|
||||
self.launcher = launcher
|
||||
self.distributed_type = DistributedType(distributed_type)
|
||||
self.debug = debug
|
||||
|
||||
def __call__(self, index, *args):
|
||||
if self.debug:
|
||||
world_size = int(os.environ.get("WORLD_SIZE"))
|
||||
rdv_file = os.environ.get("ACCELERATE_DEBUG_RDV_FILE")
|
||||
torch.distributed.init_process_group(
|
||||
"gloo",
|
||||
rank=index,
|
||||
store=torch.distributed.FileStore(rdv_file, world_size),
|
||||
world_size=world_size,
|
||||
)
|
||||
elif self.distributed_type == DistributedType.MULTI_GPU or self.distributed_type == DistributedType.MULTI_CPU:
|
||||
# Prepare the environment for torch.distributed
|
||||
os.environ["LOCAL_RANK"] = str(index)
|
||||
os.environ["RANK"] = str(index)
|
||||
|
||||
self.launcher(*args)
|
||||
88
src/accelerate/utils/memory.py
Normal file
88
src/accelerate/utils/memory.py
Normal file
@ -0,0 +1,88 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
A collection of utilities for ensuring that training can always occur. Heavily influenced by the
|
||||
[toma](https://github.com/BlackHC/toma) library.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import gc
|
||||
import inspect
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def should_reduce_batch_size(exception: Exception) -> bool:
|
||||
"""
|
||||
Checks if `exception` relates to CUDA out-of-memory, CUDNN not supported, or CPU out-of-memory
|
||||
|
||||
Args:
|
||||
exception (`Exception`):
|
||||
An exception
|
||||
"""
|
||||
_statements = [
|
||||
"CUDA out of memory.", # CUDA OOM
|
||||
"cuDNN error: CUDNN_STATUS_NOT_SUPPORTED.", # CUDNN SNAFU
|
||||
"DefaultCPUAllocator: can't allocate memory", # CPU OOM
|
||||
]
|
||||
if isinstance(exception, RuntimeError) and len(exception.args) == 1:
|
||||
return any(err in exception.args[0] for err in _statements)
|
||||
return False
|
||||
|
||||
|
||||
def find_executable_batch_size(function: callable = None, starting_batch_size: int = 128):
|
||||
"""
|
||||
A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
|
||||
CUDNN, the batch size is cut in half and passed to `function`
|
||||
|
||||
`function` must take in a `batch_size` parameter as its first argument.
|
||||
|
||||
Args:
|
||||
function (`callable`, *optional*):
|
||||
A function to wrap
|
||||
starting_batch_size (`int`, *optional*):
|
||||
The batch size to try and fit into memory
|
||||
"""
|
||||
if function is None:
|
||||
return functools.partial(find_executable_batch_size, starting_batch_size=starting_batch_size)
|
||||
|
||||
batch_size = starting_batch_size
|
||||
|
||||
def decorator(*args, **kwargs):
|
||||
nonlocal batch_size
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
params = list(inspect.signature(function).parameters.keys())
|
||||
# Guard against user error
|
||||
if len(params) < (len(args) + 1):
|
||||
arg_str = ", ".join([f"{arg}={value}" for arg, value in zip(params[1:], args[1:])])
|
||||
raise TypeError(
|
||||
f"Batch size was passed into `{function.__name__}` as the first argument when called."
|
||||
f"Remove this as the decorator already does so: `{function.__name__}({arg_str})`"
|
||||
)
|
||||
while True:
|
||||
if batch_size == 0:
|
||||
raise RuntimeError("No executable batch size found, reached zero.")
|
||||
try:
|
||||
return function(batch_size, *args, **kwargs)
|
||||
except Exception as e:
|
||||
if should_reduce_batch_size(e):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
batch_size //= 2
|
||||
else:
|
||||
raise
|
||||
|
||||
return decorator
|
||||
624
src/accelerate/utils/modeling.py
Normal file
624
src/accelerate/utils/modeling.py
Normal file
@ -0,0 +1,624 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
|
||||
|
||||
|
||||
def convert_file_size_to_int(size: Union[int, str]):
|
||||
"""
|
||||
Converts a size expressed as a string with digits an unit (like `"5MB"`) to an integer (in bytes).
|
||||
|
||||
Args:
|
||||
size (`int` or `str`): The size to convert. Will be directly returned if an `int`.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
>>> convert_file_size_to_int("1MiB")
|
||||
1048576
|
||||
```
|
||||
"""
|
||||
if isinstance(size, int):
|
||||
return size
|
||||
if size.upper().endswith("GIB"):
|
||||
return int(size[:-3]) * (2**30)
|
||||
if size.upper().endswith("MIB"):
|
||||
return int(size[:-3]) * (2**20)
|
||||
if size.upper().endswith("KIB"):
|
||||
return int(size[:-3]) * (2**10)
|
||||
if size.upper().endswith("GB"):
|
||||
int_size = int(size[:-2]) * (10**9)
|
||||
return int_size // 8 if size.endswith("b") else int_size
|
||||
if size.upper().endswith("MB"):
|
||||
int_size = int(size[:-2]) * (10**6)
|
||||
return int_size // 8 if size.endswith("b") else int_size
|
||||
if size.upper().endswith("KB"):
|
||||
int_size = int(size[:-2]) * (10**3)
|
||||
return int_size // 8 if size.endswith("b") else int_size
|
||||
raise ValueError("`size` is not in a valid format. Use an integer followed by the unit, e.g., '5GB'.")
|
||||
|
||||
|
||||
def dtype_byte_size(dtype: torch.dtype):
|
||||
"""
|
||||
Returns the size (in bytes) occupied by one parameter of type `dtype`.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
>>> dtype_byte_size(torch.float32)
|
||||
4
|
||||
```
|
||||
"""
|
||||
if dtype == torch.bool:
|
||||
return 1 / 8
|
||||
bit_search = re.search("[^\d](\d+)$", str(dtype))
|
||||
if bit_search is None:
|
||||
raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
|
||||
bit_size = int(bit_search.groups()[0])
|
||||
return bit_size // 8
|
||||
|
||||
|
||||
def set_module_tensor_to_device(
|
||||
module: nn.Module, tensor_name: str, device: Union[int, str, torch.device], value: Optional[torch.Tensor] = None
|
||||
):
|
||||
"""
|
||||
A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
|
||||
`param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function).
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`): The module in which the tensor we want to move lives.
|
||||
param_name (`str`): The full name of the parameter/buffer.
|
||||
device (`int`, `str` or `torch.device`): The device on which to set the tensor.
|
||||
value (`torch.Tensor`, *optional*): The value of the tensor (useful when going from the meta device to any
|
||||
other device).
|
||||
"""
|
||||
# Recurse if needed
|
||||
if "." in tensor_name:
|
||||
splits = tensor_name.split(".")
|
||||
for split in splits[:-1]:
|
||||
new_module = getattr(module, split)
|
||||
if new_module is None:
|
||||
raise ValueError(f"{module} has no attribute {split}.")
|
||||
module = new_module
|
||||
tensor_name = splits[-1]
|
||||
|
||||
if tensor_name not in module._parameters and tensor_name not in module._buffers:
|
||||
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
|
||||
is_buffer = tensor_name in module._buffers
|
||||
old_value = getattr(module, tensor_name)
|
||||
|
||||
if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
|
||||
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
|
||||
|
||||
with torch.no_grad():
|
||||
if value is None:
|
||||
new_value = old_value.to(device)
|
||||
elif isinstance(value, torch.Tensor):
|
||||
new_value = value.to(device)
|
||||
else:
|
||||
new_value = torch.tensor(value, device=device)
|
||||
if is_buffer:
|
||||
module._buffers[tensor_name] = new_value
|
||||
else:
|
||||
new_value = nn.Parameter(new_value, requires_grad=old_value.requires_grad)
|
||||
module._parameters[tensor_name] = new_value
|
||||
|
||||
|
||||
def named_module_tensors(module: nn.Module, include_buffers: bool = True, recurse: bool = False):
|
||||
"""
|
||||
A helper function that gathers all the tensors (parameters + buffers) of a given module. If `include_buffers=True`
|
||||
it's the same as doing `module.named_parameters(recurse=recurse) + module.named_buffers(recurse=recurse)`.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`): The module we want the tensors or.
|
||||
include_buffer (`bool`, *optional*, defaults to `True`): Whether or not to include the buffers in the result.
|
||||
recurse (`bool`, *optional`, defaults to `False`):
|
||||
Whether or not to go look in every submodule or just return the direct parameters and buffers.
|
||||
"""
|
||||
for named_parameter in module.named_parameters(recurse=recurse):
|
||||
yield named_parameter
|
||||
|
||||
if include_buffers:
|
||||
for named_buffer in module.named_buffers(recurse=recurse):
|
||||
yield named_buffer
|
||||
|
||||
|
||||
def find_tied_parameters(model: nn.Module, **kwargs):
|
||||
"""
|
||||
Find the tied parameters in a given model.
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Module`): The model to inspect.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
|
||||
them.
|
||||
|
||||
</Tip>
|
||||
|
||||
Example:
|
||||
|
||||
|
||||
```py
|
||||
>>> from collections import OrderedDict
|
||||
>>> import torch.nn as nn
|
||||
|
||||
>>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
|
||||
>>> model.linear2.weight = test_model.linear1.weight
|
||||
>>> find_tied_parameters(test_model)
|
||||
{'linear1.weight': 'linear2.weight'}
|
||||
```
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: A dictionary mapping tied parameter names to the name of the parameter they are tied to.
|
||||
"""
|
||||
# Initialize result and named_parameters before recursing.
|
||||
named_parameters = kwargs.get("named_parameters", None)
|
||||
prefix = kwargs.get("prefix", "")
|
||||
result = kwargs.get("result", {})
|
||||
|
||||
if named_parameters is None:
|
||||
named_parameters = {n: p for n, p in model.named_parameters()}
|
||||
else:
|
||||
# A tied parameter will not be in the full `named_parameters` seen above but will be in the `named_parameters`
|
||||
# of the submodule it belongs to. So while recursing we track the names that are not in the initial
|
||||
# `named_parameters`.
|
||||
for name, parameter in model.named_parameters():
|
||||
full_name = name if prefix == "" else f"{prefix}.{name}"
|
||||
if full_name not in named_parameters:
|
||||
# When we find one, it has to be one of the existing parameters.
|
||||
for new_name, new_param in named_parameters.items():
|
||||
if new_param is parameter:
|
||||
result[new_name] = full_name
|
||||
|
||||
# Once we have treated direct parameters, we move to the child modules.
|
||||
for name, child in model.named_children():
|
||||
child_name = name if prefix == "" else f"{prefix}.{name}"
|
||||
find_tied_parameters(child, named_parameters=named_parameters, prefix=child_name, result=result)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def compute_module_sizes(model: nn.Module, dtype: Optional[Union[str, torch.device]] = None):
|
||||
"""
|
||||
Compute the size of each submodule of a given model.
|
||||
"""
|
||||
if isinstance(dtype, str):
|
||||
# We accept "torch.float16" or just "float16"
|
||||
dtype = dtype.replace("torch.", "")
|
||||
dtype = getattr(torch, dtype)
|
||||
if dtype is not None:
|
||||
dtype_size = dtype_byte_size(dtype)
|
||||
module_sizes = defaultdict(int)
|
||||
for name, tensor in named_module_tensors(model, recurse=True):
|
||||
if dtype is None:
|
||||
size = tensor.numel() * dtype_byte_size(tensor.dtype)
|
||||
else:
|
||||
size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
|
||||
name_parts = name.split(".")
|
||||
for idx in range(len(name_parts) + 1):
|
||||
module_sizes[".".join(name_parts[:idx])] += size
|
||||
|
||||
return module_sizes
|
||||
|
||||
|
||||
def get_max_layer_size(
|
||||
modules: List[Tuple[str, torch.nn.Module]], module_sizes: Dict[str, int], no_split_module_classes: List[str]
|
||||
):
|
||||
"""
|
||||
Utility function that will scan a list of named modules and return the maximum size used by one full layer. The
|
||||
definition of a layer being:
|
||||
- a module with no direct children (just parameters and buffers)
|
||||
- a module whose class name is in the list `no_split_module_classes`
|
||||
|
||||
Args:
|
||||
modules (`List[Tuple[str, torch.nn.Module]]`):
|
||||
The list of named modules where we want to determine the maximum layer size.
|
||||
module_sizes (`Dict[str, int]`):
|
||||
A dictionary mapping each layer name to its size (as generated by `compute_module_sizes`).
|
||||
no_split_module_classes (`List[str]`):
|
||||
A list of class names for layers we don't want to be split.
|
||||
|
||||
Returns:
|
||||
`Tuple[int, List[str]]`: The maximum size of a layer with the list of layer names realizing that maximum size.
|
||||
"""
|
||||
max_size = 0
|
||||
layer_names = []
|
||||
modules_to_treat = modules.copy()
|
||||
while len(modules_to_treat) > 0:
|
||||
module_name, module = modules_to_treat.pop(0)
|
||||
modules_children = list(module.named_children())
|
||||
if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
|
||||
# No splitting this one so we compare to the max_size
|
||||
size = module_sizes[module_name]
|
||||
if size > max_size:
|
||||
max_size = size
|
||||
layer_names = [module_name]
|
||||
elif size == max_size:
|
||||
layer_names.append(module_name)
|
||||
else:
|
||||
modules_to_treat = [(f"{module_name}.{n}", v) for n, v in modules_children] + modules_to_treat
|
||||
return max_size, layer_names
|
||||
|
||||
|
||||
def get_max_memory(max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None):
|
||||
"""
|
||||
Get the maximum memory available if nothing is passed, converts string to int otherwise.
|
||||
"""
|
||||
import psutil
|
||||
|
||||
if max_memory is None:
|
||||
if not torch.cuda.is_available():
|
||||
max_memory = {}
|
||||
else:
|
||||
# Make sure CUDA is initialized on each GPU to have the right memory info.
|
||||
for i in range(torch.cuda.device_count()):
|
||||
_ = torch.tensor([0], device=i)
|
||||
max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count())}
|
||||
max_memory["cpu"] = psutil.virtual_memory().available
|
||||
return max_memory
|
||||
|
||||
for key in max_memory:
|
||||
if isinstance(max_memory[key], str):
|
||||
max_memory[key] = convert_file_size_to_int(max_memory[key])
|
||||
return max_memory
|
||||
|
||||
|
||||
def clean_device_map(device_map: Dict[str, Union[int, str, torch.device]], module_name: str = ""):
|
||||
"""
|
||||
Cleans a device_map by grouping all submodules that go on the same device together.
|
||||
"""
|
||||
# Get the value of the current module and if there is only one split across several keys, regroup it.
|
||||
prefix = "" if module_name == "" else f"{module_name}."
|
||||
values = [v for k, v in device_map.items() if k.startswith(prefix)]
|
||||
if len(set(values)) == 1 and len(values) > 1:
|
||||
for k in [k for k in device_map if k.startswith(prefix)]:
|
||||
del device_map[k]
|
||||
device_map[module_name] = values[0]
|
||||
|
||||
# Recurse over the children
|
||||
children_modules = [k for k in device_map.keys() if k.startswith(module_name) and len(k) > len(module_name)]
|
||||
idx = len(module_name.split(".")) + 1 if len(module_name) > 0 else 1
|
||||
children_modules = set(".".join(k.split(".")[:idx]) for k in children_modules)
|
||||
for child in children_modules:
|
||||
clean_device_map(device_map, module_name=child)
|
||||
|
||||
return device_map
|
||||
|
||||
|
||||
def infer_auto_device_map(
|
||||
model: nn.Module,
|
||||
max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
|
||||
no_split_module_classes: Optional[List[str]] = None,
|
||||
dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
):
|
||||
"""
|
||||
Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk,
|
||||
such that:
|
||||
- we don't exceed the memory available of any of the GPU.
|
||||
- if offload to the CPU is needed, there is always room left on GPU 0 to put back the layer offloaded on CPU that
|
||||
has the largest size.
|
||||
- if offload to the CPU is needed,we don't exceed the RAM available on the CPU.
|
||||
- if offload to the disk is needed, there is always room left on the CPU to put back the layer offloaded on disk
|
||||
that has the largest size.
|
||||
|
||||
<Tip>
|
||||
|
||||
All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the
|
||||
meta device (as it would if initialized within the `init_empty_weights` context manager).
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Module`): The model to analyze.
|
||||
max_memory (`Dict`, *optional*):
|
||||
A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset.
|
||||
no_split_module_classes (`List[str]`, *optional*):
|
||||
A list of layer class names that should never be split across device (for instance any layer that has a
|
||||
residual connection).
|
||||
dtype (`str` or `torch.dtype`, *optional*):
|
||||
If provided, the weights will be converted to that type when loaded.
|
||||
"""
|
||||
# Get default / clean up max_memory
|
||||
max_memory = get_max_memory(max_memory)
|
||||
if no_split_module_classes is None:
|
||||
no_split_module_classes = []
|
||||
elif not isinstance(no_split_module_classes, (list, tuple)):
|
||||
no_split_module_classes = [no_split_module_classes]
|
||||
|
||||
devices = list(max_memory.keys())
|
||||
gpus = [device for device in devices if device != "cpu"]
|
||||
if "disk" not in devices:
|
||||
devices.append("disk")
|
||||
|
||||
# Devices that need to keep space for a potential offloaded layer.
|
||||
main_devices = [gpus[0], "cpu"] if len(gpus) > 0 else ["cpu"]
|
||||
|
||||
module_sizes = compute_module_sizes(model, dtype=dtype)
|
||||
tied_parameters = find_tied_parameters(model)
|
||||
|
||||
device_map = {}
|
||||
current_device = 0
|
||||
current_memory_used = 0
|
||||
|
||||
# Direct submodules and parameters
|
||||
modules_to_treat = list(model.named_parameters(recurse=False)) + list(model.named_children())
|
||||
# Initialize maximum largest layer, to know which space to keep in memory
|
||||
max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes)
|
||||
|
||||
# Ready ? This is going to be a bit messy.
|
||||
while len(modules_to_treat) > 0:
|
||||
name, module = modules_to_treat.pop(0)
|
||||
# Max size in the remaining layers may have changed since we took one, so we maybe update it.
|
||||
max_layer_names = [n for n in max_layer_names if not n.startswith(name)]
|
||||
if len(max_layer_names) == 0:
|
||||
max_layer_size, max_layer_names = get_max_layer_size(
|
||||
[(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
|
||||
module_sizes,
|
||||
no_split_module_classes,
|
||||
)
|
||||
# Assess size needed
|
||||
module_size = module_sizes[name]
|
||||
tied_params = [v for k, v in tied_parameters.items() if name in k]
|
||||
# We ignore parameters that are tied when they're tied to > 1 one
|
||||
tied_param = tied_params[0] if len(tied_params) == 1 else None
|
||||
|
||||
device = devices[current_device]
|
||||
current_max_size = max_memory[device] if device != "disk" else None
|
||||
# Reduce max size available by the largest layer.
|
||||
if devices[current_device] in main_devices:
|
||||
current_max_size = current_max_size - max_layer_size
|
||||
# Case 1 -> We're too big!
|
||||
if current_max_size is not None and current_memory_used + module_size > current_max_size:
|
||||
# Split or not split?
|
||||
modules_children = list(module.named_children())
|
||||
if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
|
||||
# -> no split, we go to the next device
|
||||
current_device += 1
|
||||
modules_to_treat = [(name, module)] + modules_to_treat
|
||||
current_memory_used = 0
|
||||
else:
|
||||
# -> split, we replace the module studied by its children + parameters
|
||||
modules_children = list(module.named_parameters(recurse=False)) + modules_children
|
||||
modules_to_treat = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_treat
|
||||
# Update the max layer size.
|
||||
max_layer_size, max_layer_names = get_max_layer_size(
|
||||
[(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
|
||||
module_sizes,
|
||||
no_split_module_classes,
|
||||
)
|
||||
|
||||
# Case 2, it fits! We're not entirely out of the wood though, because we may have some tied parameters.
|
||||
elif tied_param is not None:
|
||||
# Determine the sized occupied by this module + the module containing the tied parameter
|
||||
tied_module_size = module_size
|
||||
tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n in tied_param][0]
|
||||
tied_module_name, tied_module = modules_to_treat[tied_module_index]
|
||||
tied_module_size += module_sizes[tied_module_name] - module_sizes[tied_param]
|
||||
if current_max_size is not None and current_memory_used + tied_module_size > current_max_size:
|
||||
# Split or not split?
|
||||
tied_module_children = list(tied_module.named_children())
|
||||
if len(tied_module_children) == 0 or tied_module.__class__.__name__ in no_split_module_classes:
|
||||
# If the tied module is not split, we go to the next device
|
||||
current_device += 1
|
||||
modules_to_treat = [(name, module)] + modules_to_treat
|
||||
current_memory_used = 0
|
||||
else:
|
||||
# Otherwise, we replace the tied module by its children.
|
||||
tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children
|
||||
tied_module_children = [(f"{tied_module_name}.{n}", v) for n, v in tied_module_children]
|
||||
modules_to_treat = (
|
||||
[(name, module)]
|
||||
+ modules_to_treat[:tied_module_index]
|
||||
+ tied_module_children
|
||||
+ modules_to_treat[tied_module_index + 1 :]
|
||||
)
|
||||
# Update the max layer size.
|
||||
max_layer_size, max_layer_names = get_max_layer_size(
|
||||
[(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
|
||||
module_sizes,
|
||||
no_split_module_classes,
|
||||
)
|
||||
else:
|
||||
# We really really fit!
|
||||
current_memory_used += tied_module_size
|
||||
device_map[name] = devices[current_device]
|
||||
modules_to_treat.pop(tied_module_index)
|
||||
device_map[tied_module_name] = devices[current_device]
|
||||
else:
|
||||
current_memory_used += module_size
|
||||
device_map[name] = devices[current_device]
|
||||
|
||||
return clean_device_map(device_map)
|
||||
|
||||
|
||||
def check_device_map(model: nn.Module, device_map: Dict[str, Union[int, str, torch.device]]):
|
||||
"""
|
||||
Checks a device map covers everything in a given model.
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Module`): The model to check the device map against.
|
||||
device_map (`Dict[str, Union[int, str, torch.device]]`): The device map to check.
|
||||
"""
|
||||
all_model_tensors = [name for name, _ in model.state_dict().items()]
|
||||
for module_name in device_map.keys():
|
||||
all_model_tensors = [name for name in all_model_tensors if not name.startswith(module_name)]
|
||||
if len(all_model_tensors) > 0:
|
||||
non_covered_params = ", ".join(all_model_tensors)
|
||||
raise ValueError(
|
||||
f"The device_map provided does not give any device for the following parameters: {non_covered_params}"
|
||||
)
|
||||
|
||||
|
||||
def load_checkpoint_in_model(
|
||||
model: nn.Module,
|
||||
checkpoint: Union[str, os.PathLike],
|
||||
device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None,
|
||||
offload_folder: Optional[Union[str, os.PathLike]] = None,
|
||||
dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
offload_state_dict: bool = False,
|
||||
):
|
||||
"""
|
||||
Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are
|
||||
loaded.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Once loaded across devices, you still need to call [`dispatch_model`] on your model to make it able to run. To
|
||||
group the checkpoint loading and dispatch in one single call, use [`load_checkpoint_and_dispatch`].
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Module`): The model in which we want to load a checkpoint.
|
||||
checkpoint (`str` or `os.PathLike`):
|
||||
The folder checkpoint to load. It can be:
|
||||
- a path to a file containing a whole model state dict
|
||||
- a path to a `.json` file containing the index to a sharded checkpoint
|
||||
- a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.
|
||||
device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
|
||||
A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
|
||||
name, once a given module name is inside, every submodule of it will be sent to the same device.
|
||||
offload_folder (`str` or `os.PathLike`, *optional*):
|
||||
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
|
||||
dtype (`str` or `torch.dtype`, *optional*):
|
||||
If provided, the weights will be converted to that type when loaded.
|
||||
offload_state_dict (`bool`, *optional*, defaults to `False`):
|
||||
If `True`, will temporarily offload the CPU state dict on the hard drive to avoig getting out of CPU RAM if
|
||||
the weight of the CPU state dict + the biggest shard does not fit.
|
||||
"""
|
||||
if offload_folder is None and device_map is not None and "disk" in device_map.values():
|
||||
raise ValueError(
|
||||
"At least one of the model submodule will be offloaded to disk, please pass along an `offload_folder`."
|
||||
)
|
||||
elif offload_folder is not None and device_map is not None and "disk" in device_map.values():
|
||||
os.makedirs(offload_folder, exist_ok=True)
|
||||
|
||||
if isinstance(dtype, str):
|
||||
# We accept "torch.float16" or just "float16"
|
||||
dtype = dtype.replace("torch.", "")
|
||||
dtype = getattr(torch, dtype)
|
||||
|
||||
checkpoint_files = None
|
||||
index_filename = None
|
||||
if os.path.isfile(checkpoint):
|
||||
if str(checkpoint).endswith(".json"):
|
||||
index_filename = checkpoint
|
||||
else:
|
||||
checkpoint_files = [checkpoint]
|
||||
elif os.path.isdir(checkpoint):
|
||||
potential_index = [f for f in os.listdir(checkpoint) if f.endswith(".index.json")]
|
||||
if len(potential_index) == 0:
|
||||
raise ValueError(f"{checkpoint} is not a folder containing a `.index.json` file.")
|
||||
elif len(potential_index) == 1:
|
||||
index_filename = os.path.join(checkpoint, potential_index[0])
|
||||
else:
|
||||
raise ValueError(f"{checkpoint} containing mote than one `.index.json` file, delete the irrelevant ones.")
|
||||
else:
|
||||
raise ValueError(
|
||||
"`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded "
|
||||
f"checkpoint, or a folder containing a sharded checkpoint, but got {checkpoint}."
|
||||
)
|
||||
|
||||
if index_filename is not None:
|
||||
checkpoint_folder = os.path.split(index_filename)[0]
|
||||
with open(index_filename, "r") as f:
|
||||
index = json.loads(f.read())
|
||||
|
||||
if "weight_map" in index:
|
||||
index = index["weight_map"]
|
||||
checkpoint_files = sorted(list(set(index.values())))
|
||||
checkpoint_files = [os.path.join(checkpoint_folder, f) for f in checkpoint_files]
|
||||
|
||||
# Logic for missing/unexepected keys goes here.
|
||||
|
||||
offload_index = {}
|
||||
if offload_state_dict:
|
||||
state_dict_folder = tempfile.mkdtemp()
|
||||
state_dict_index = {}
|
||||
|
||||
for checkpoint_file in checkpoint_files:
|
||||
checkpoint = torch.load(checkpoint_file)
|
||||
if device_map is None:
|
||||
model.load_state_dict(checkpoint, strict=False)
|
||||
else:
|
||||
for param_name, param in checkpoint.items():
|
||||
module_name = param_name
|
||||
if dtype is not None:
|
||||
param = param.to(dtype)
|
||||
while len(module_name) > 0 and module_name not in device_map:
|
||||
module_name = ".".join(module_name.split(".")[:-1])
|
||||
if module_name == "" and "" not in device_map:
|
||||
# TODO: group all errors and raise at the end.
|
||||
raise ValueError(f"{param_name} doesn't have any device set.")
|
||||
param_device = device_map[module_name]
|
||||
|
||||
if param_device == "disk":
|
||||
set_module_tensor_to_device(model, param_name, "meta")
|
||||
tensor_file = os.path.join(offload_folder, f"{param_name}.dat")
|
||||
array = param.numpy()
|
||||
offload_index[param_name] = {"dtype": str(array.dtype), "shape": list(array.shape)}
|
||||
file_array = np.memmap(tensor_file, dtype=array.dtype, mode="w+", shape=array.shape)
|
||||
file_array[:] = array[:]
|
||||
file_array.flush()
|
||||
elif param_device == "cpu" and offload_state_dict:
|
||||
set_module_tensor_to_device(model, param_name, "meta")
|
||||
tensor_file = os.path.join(state_dict_folder, f"{param_name}.dat")
|
||||
array = param.numpy()
|
||||
state_dict_index[param_name] = {"dtype": str(array.dtype), "shape": list(array.shape)}
|
||||
file_array = np.memmap(tensor_file, dtype=array.dtype, mode="w+", shape=array.shape)
|
||||
file_array[:] = array[:]
|
||||
file_array.flush()
|
||||
else:
|
||||
set_module_tensor_to_device(model, param_name, param_device, value=param)
|
||||
|
||||
# Force Python to clean up.
|
||||
del checkpoint
|
||||
gc.collect()
|
||||
|
||||
if len(offload_index) > 0:
|
||||
offload_index_file = os.path.join(offload_folder, "index.json")
|
||||
if os.path.isfile(offload_index_file):
|
||||
with open(offload_index_file, "r", encoding="utf-8") as f:
|
||||
current_offload_index = json.load(f)
|
||||
else:
|
||||
current_offload_index = {}
|
||||
current_offload_index.update(offload_index)
|
||||
|
||||
with open(offload_index_file, "w", encoding="utf-8") as f:
|
||||
json.dump(current_offload_index, f, indent=2)
|
||||
|
||||
# Load back offloaded state dict on CPU
|
||||
if offload_state_dict and len(state_dict_index) > 0:
|
||||
for param_name, metadata in state_dict_index.items():
|
||||
tensor_file = os.path.join(state_dict_folder, f"{param_name}.dat")
|
||||
shape = tuple(metadata["shape"])
|
||||
weight = np.memmap(tensor_file, dtype=metadata["dtype"], mode="r", shape=shape)
|
||||
set_module_tensor_to_device(model, param_name, "cpu", value=torch.tensor(weight))
|
||||
shutil.rmtree(state_dict_folder)
|
||||
143
src/accelerate/utils/offload.py
Normal file
143
src/accelerate/utils/offload.py
Normal file
@ -0,0 +1,143 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def offload_state_dict(save_dir: Union[str, os.PathLike], state_dict: Dict[str, torch.Tensor]):
|
||||
"""
|
||||
Offload a state dict in a given folder.
|
||||
|
||||
Args:
|
||||
save_dir (`str` or `os.PathLike`): The directory in which to offload the state dict.
|
||||
state_dict (`Dict[str, torch.Tensor]`): The dictionary of tensors to offload.
|
||||
"""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
index = {}
|
||||
for name, parameter in state_dict.items():
|
||||
tensor_file = os.path.join(save_dir, f"{name}.dat")
|
||||
array = parameter.numpy()
|
||||
index[name] = {"dtype": str(array.dtype), "shape": list(array.shape)}
|
||||
if array.ndim == 0:
|
||||
array = array[None]
|
||||
file_array = np.memmap(tensor_file, dtype=array.dtype, mode="w+", shape=array.shape)
|
||||
file_array[:] = array[:]
|
||||
file_array.flush()
|
||||
|
||||
# Update index
|
||||
index_file = os.path.join(save_dir, "index.json")
|
||||
if os.path.isfile(index_file):
|
||||
with open(index_file, "r", encoding="utf-8") as f:
|
||||
current_index = json.load(f)
|
||||
else:
|
||||
current_index = {}
|
||||
current_index.update(index)
|
||||
|
||||
with open(index_file, "w", encoding="utf-8") as f:
|
||||
json.dump(current_index, f, indent=2)
|
||||
|
||||
|
||||
class PrefixedDataset(Mapping):
|
||||
"""
|
||||
Will access keys in a given dataset by adding a prefix.
|
||||
|
||||
Args:
|
||||
dataset (`Mapping`): Any map with string keys.
|
||||
prefix (`str`): A prefix to add when trying to access any element in the underlying dataset.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset: Mapping, prefix: str):
|
||||
self.dataset = dataset
|
||||
self.prefix = prefix
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.dataset[f"{self.prefix}{key}"]
|
||||
|
||||
def __iter__(self):
|
||||
return iter([key for key in self.dataset if key.startswith(self.prefix)])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
|
||||
class OffloadedWeightsLoader(Mapping):
|
||||
"""
|
||||
A collection that loads weights stored in a given state dict or memory-mapped on disk.
|
||||
|
||||
Args:
|
||||
state_dict (`Dict[str, torch.Tensor]`, *optional*):
|
||||
A dictionary parameter name to tensor.
|
||||
save_folder (`str` or `os.PathLike`, *optional*):
|
||||
The directory in which the weights are stored (by `offload_state_dict` for instance).
|
||||
index (`Dict`, *optional*):
|
||||
A dictionary from weight name to their information (`dtype` and `shape`). Will default to the index saved
|
||||
in `save_folder`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state_dict: Dict[str, torch.Tensor] = None,
|
||||
save_folder: Optional[Union[str, os.PathLike]] = None,
|
||||
index: Mapping = None,
|
||||
):
|
||||
if state_dict is None and save_folder is None:
|
||||
raise ValueError("Need either a `state_dict` or a `save_folder` containing offloaded weights.")
|
||||
|
||||
self.state_dict = {} if state_dict is None else state_dict
|
||||
self.save_folder = save_folder
|
||||
if index is None and save_folder is not None:
|
||||
with open(os.path.join(save_folder, "index.json")) as f:
|
||||
index = json.load(f)
|
||||
self.index = {} if index is None else index
|
||||
self.all_keys = list(self.state_dict.keys())
|
||||
self.all_keys.extend([key for key in self.index if key not in self.all_keys])
|
||||
|
||||
def __getitem__(self, key: str):
|
||||
# State dict gets priority
|
||||
if key in self.state_dict:
|
||||
return self.state_dict[key]
|
||||
weight_info = self.index[key]
|
||||
weight_file = os.path.join(self.save_folder, f"{key}.dat")
|
||||
shape = tuple(weight_info["shape"])
|
||||
if shape == ():
|
||||
weight = np.memmap(weight_file, dtype=weight_info["dtype"], shape=(1,), mode="r")[0]
|
||||
else:
|
||||
weight = np.memmap(weight_file, dtype=weight_info["dtype"], shape=shape, mode="r")
|
||||
return torch.tensor(weight)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.all_keys)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.all_keys)
|
||||
|
||||
|
||||
def extract_submodules_state_dict(state_dict: Dict[str, torch.Tensor], submodule_names: List[str]):
|
||||
"""
|
||||
Extract the sub state-dict corresponding to a list of given submodules.
|
||||
|
||||
Args:
|
||||
state_dict (`Dict[str, torch.Tensor]`): The state dict to extract from.
|
||||
submodule_names (`List[str]`): The list of submodule names we want to extract.
|
||||
"""
|
||||
result = {}
|
||||
for module_name in submodule_names:
|
||||
result.update({key: param for key, param in state_dict.items() if key.startswith(module_name)})
|
||||
return result
|
||||
@ -1,4 +1,4 @@
|
||||
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -12,172 +12,34 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import importlib
|
||||
import os
|
||||
import random
|
||||
import typing
|
||||
from collections.abc import Mapping
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum, EnumMeta
|
||||
from functools import update_wrapper
|
||||
from typing import Any, Callable, Iterable, List, Optional, Union
|
||||
"""
|
||||
A set of basic tensor ops compatible with tpu, gpu, and multigpu
|
||||
"""
|
||||
|
||||
|
||||
from functools import update_wrapper
|
||||
from typing import Any, Mapping
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from packaging import version
|
||||
|
||||
from .state import AcceleratorState, DistributedType, is_deepspeed_available, is_tpu_available
|
||||
from ..state import AcceleratorState
|
||||
from .dataclasses import DistributedType, TensorInformation
|
||||
from .imports import is_tpu_available
|
||||
|
||||
|
||||
if is_tpu_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
|
||||
def is_tensorflow_available():
|
||||
return importlib.util.find_spec("tensorflow") is not None
|
||||
def is_torch_tensor(tensor):
|
||||
return isinstance(tensor, torch.Tensor)
|
||||
|
||||
|
||||
def is_tensorboard_available():
|
||||
return importlib.util.find_spec("tensorboard") is not None or importlib.util.find_spec("tensorboardX") is not None
|
||||
|
||||
|
||||
def is_wandb_available():
|
||||
return importlib.util.find_spec("wandb") is not None
|
||||
|
||||
|
||||
def is_comet_ml_available():
|
||||
return importlib.util.find_spec("comet_ml") is not None
|
||||
|
||||
|
||||
def is_boto3_available():
|
||||
return importlib.util.find_spec("boto3") is not None
|
||||
|
||||
|
||||
def is_sagemaker_available():
|
||||
return importlib.util.find_spec("sagemaker") is not None
|
||||
|
||||
|
||||
if is_deepspeed_available():
|
||||
from deepspeed import DeepSpeedEngine
|
||||
|
||||
SCALER_NAME = "scaler.pt"
|
||||
MODEL_NAME = "pytorch_model"
|
||||
RNG_STATE_NAME = "random_states"
|
||||
OPTIMIZER_NAME = "optimizer"
|
||||
SCHEDULER_NAME = "scheduler"
|
||||
|
||||
|
||||
class EnumWithContains(EnumMeta):
|
||||
"A metaclass that adds the ability to check if `self` contains an item with the `in` operator"
|
||||
|
||||
def __contains__(cls, item):
|
||||
try:
|
||||
cls(item)
|
||||
except ValueError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class BaseEnum(Enum, metaclass=EnumWithContains):
|
||||
"An enum class that can get the value of an item with `str(Enum.key)`"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
@classmethod
|
||||
def list(cls):
|
||||
"Method to list all the possible items in `cls`"
|
||||
return list(map(lambda item: str(item), cls))
|
||||
|
||||
|
||||
class LoggerType(BaseEnum):
|
||||
ALL = "all"
|
||||
TENSORBOARD = "tensorboard"
|
||||
WANDB = "wandb"
|
||||
COMETML = "comet_ml"
|
||||
|
||||
|
||||
class PrecisionType(BaseEnum):
|
||||
NO = "no"
|
||||
FP16 = "fp16"
|
||||
BF16 = "bf16"
|
||||
|
||||
|
||||
class RNGType(BaseEnum):
|
||||
TORCH = "torch"
|
||||
CUDA = "cuda"
|
||||
XLA = "xla"
|
||||
GENERATOR = "generator"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorInformation:
|
||||
shape: torch.Size
|
||||
dtype: torch.dtype
|
||||
|
||||
|
||||
def set_seed(seed: int, device_specific: bool = False):
|
||||
"""
|
||||
Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
|
||||
|
||||
Args:
|
||||
seed (`int`): The seed to set.
|
||||
device_specific (`bool`, *optional*, defaults to `False`):
|
||||
Whether to differ the seed on each device slightly with `self.process_index`.
|
||||
"""
|
||||
if device_specific:
|
||||
seed += AcceleratorState().process_index
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
# ^^ safe to call this function even if cuda is not available
|
||||
if is_tpu_available():
|
||||
xm.set_rng_state(seed)
|
||||
|
||||
|
||||
def synchronize_rng_state(rng_type: Optional[RNGType] = None, generator: Optional[torch.Generator] = None):
|
||||
# Get the proper rng state
|
||||
if rng_type == RNGType.TORCH:
|
||||
rng_state = torch.get_rng_state()
|
||||
elif rng_type == RNGType.CUDA:
|
||||
rng_state = torch.cuda.get_rng_state()
|
||||
elif rng_type == RNGType.XLA:
|
||||
assert is_tpu_available(), "Can't synchronize XLA seeds on an environment without TPUs."
|
||||
rng_state = torch.tensor(xm.get_rng_state())
|
||||
elif rng_type == RNGType.GENERATOR:
|
||||
assert generator is not None, "Need a generator to synchronize its seed."
|
||||
rng_state = generator.get_state()
|
||||
|
||||
# Broadcast the rng state from device 0 to other devices
|
||||
state = AcceleratorState()
|
||||
if state.distributed_type == DistributedType.TPU:
|
||||
rng_state = xm.mesh_reduce("random_seed", rng_state, lambda x: x[0])
|
||||
elif state.distributed_type in [DistributedType.DEEPSPEED, DistributedType.MULTI_GPU]:
|
||||
rng_state = rng_state.to(state.device)
|
||||
torch.distributed.broadcast(rng_state, 0)
|
||||
rng_state = rng_state.cpu()
|
||||
elif state.distributed_type == DistributedType.MULTI_CPU:
|
||||
torch.distributed.broadcast(rng_state, 0)
|
||||
|
||||
# Set the broadcast rng state
|
||||
if rng_type == RNGType.TORCH:
|
||||
torch.set_rng_state(rng_state)
|
||||
elif rng_type == RNGType.CUDA:
|
||||
torch.cuda.set_rng_state(rng_state)
|
||||
elif rng_type == RNGType.XLA:
|
||||
xm.set_rng_state(rng_state.item())
|
||||
elif rng_type == RNGType.GENERATOR:
|
||||
generator.set_state(rng_state)
|
||||
|
||||
|
||||
def synchronize_rng_states(rng_types: List[Union[str, RNGType]], generator: Optional[torch.Generator] = None):
|
||||
for rng_type in rng_types:
|
||||
synchronize_rng_state(RNGType(rng_type), generator=generator)
|
||||
def is_tensor_information(tensor_info):
|
||||
return isinstance(tensor_info, TensorInformation)
|
||||
|
||||
|
||||
def honor_type(obj, generator):
|
||||
@ -191,14 +53,6 @@ def honor_type(obj, generator):
|
||||
return type(obj)(*list(generator))
|
||||
|
||||
|
||||
def is_torch_tensor(tensor):
|
||||
return isinstance(tensor, torch.Tensor)
|
||||
|
||||
|
||||
def is_tensor_information(tensor_info):
|
||||
return isinstance(tensor_info, TensorInformation)
|
||||
|
||||
|
||||
def recursively_apply(func, data, *args, test_type=is_torch_tensor, error_on_other_type=False, **kwargs):
|
||||
"""
|
||||
Recursively apply a function on a data structure that is a nested list/tuple/dictionary of a given base type.
|
||||
@ -305,73 +159,24 @@ def initialize_tensors(data_structure):
|
||||
return recursively_apply(_initialize_tensor, data_structure, test_type=is_tensor_information)
|
||||
|
||||
|
||||
def convert_to_fp32(tensor):
|
||||
def find_batch_size(data):
|
||||
"""
|
||||
Recursively converts the elements nested list/tuple/dictionary of tensors in FP16/BF16 precision to FP32.
|
||||
Recursively finds the batch size in a nested list/tuple/dictionary of lists of tensors.
|
||||
|
||||
Args:
|
||||
tensor (nested list/tuple/dictionary of `torch.Tensor`):
|
||||
The data to convert from FP16/BF16 to FP32.
|
||||
data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to find the batch size.
|
||||
|
||||
Returns:
|
||||
The same data structure as `tensor` with all tensors that were in FP16/BF16 precision converted to FP32.
|
||||
`int`: The batch size.
|
||||
"""
|
||||
|
||||
def _convert_to_fp32(tensor):
|
||||
return tensor.float()
|
||||
|
||||
def _is_fp16_bf16_tensor(tensor):
|
||||
return hasattr(tensor, "dtype") and (
|
||||
tensor.dtype == torch.float16
|
||||
or (version.parse(torch.__version__) >= version.parse("1.10") and tensor.dtype == torch.bfloat16)
|
||||
)
|
||||
|
||||
return recursively_apply(_convert_to_fp32, tensor, test_type=_is_fp16_bf16_tensor)
|
||||
|
||||
|
||||
class ConvertOutputsToFp32:
|
||||
"""
|
||||
Decorator to apply to a function outputing tensors (like a model forward pass) that ensures the outputs in FP16
|
||||
precision will be convert back to FP32.
|
||||
|
||||
Use a class instead of a decorator because otherwise, the prepared model can no longer be pickled (issue #273).
|
||||
|
||||
Args:
|
||||
model_forward (`Callable`):
|
||||
The function which outputs we want to treat.
|
||||
|
||||
Returns:
|
||||
The same function as `model_forward` but with converted outputs.
|
||||
"""
|
||||
|
||||
def __init__(self, model_forward):
|
||||
self.model_forward = model_forward
|
||||
update_wrapper(self, model_forward)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return convert_to_fp32(self.model_forward(*args, **kwargs))
|
||||
|
||||
|
||||
convert_outputs_to_fp32 = ConvertOutputsToFp32
|
||||
|
||||
|
||||
def extract_model_from_parallel(model):
|
||||
"""
|
||||
Extract a model from its distributed containers.
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Module`): The model to extract.
|
||||
|
||||
Returns:
|
||||
`torch.nn.Module`: The extracted model.
|
||||
"""
|
||||
options = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)
|
||||
if is_deepspeed_available():
|
||||
options += (DeepSpeedEngine,)
|
||||
|
||||
while isinstance(model, options):
|
||||
model = model.module
|
||||
return model
|
||||
if isinstance(data, (tuple, list)):
|
||||
return find_batch_size(data[0])
|
||||
elif isinstance(data, Mapping):
|
||||
for k in data.keys():
|
||||
return find_batch_size(data[k])
|
||||
elif not isinstance(data, torch.Tensor):
|
||||
raise TypeError(f"Can only find the batch size of tensors but got {type(data)}.")
|
||||
return data.shape[0]
|
||||
|
||||
|
||||
def _tpu_gather(tensor, name="gather tensor"):
|
||||
@ -536,26 +341,6 @@ def slice_tensors(data, tensor_slice):
|
||||
return recursively_apply(_slice_tensor, data, tensor_slice)
|
||||
|
||||
|
||||
def find_batch_size(data):
|
||||
"""
|
||||
Recursively finds the batch size in a nested list/tuple/dictionary of lists of tensors.
|
||||
|
||||
Args:
|
||||
data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to find the batch size.
|
||||
|
||||
Returns:
|
||||
`int`: The batch size.
|
||||
"""
|
||||
if isinstance(data, (tuple, list)):
|
||||
return find_batch_size(data[0])
|
||||
elif isinstance(data, Mapping):
|
||||
for k in data.keys():
|
||||
return find_batch_size(data[k])
|
||||
elif not isinstance(data, torch.Tensor):
|
||||
raise TypeError(f"Can only find the batch size of tensors but got {type(data)}.")
|
||||
return data.shape[0]
|
||||
|
||||
|
||||
def concatenate(data, dim=0):
|
||||
"""
|
||||
Recursively concatenate the tensors in a nested list/tuple/dictionary of lists of tensors with the same shape.
|
||||
@ -657,198 +442,72 @@ def reduce(tensor, reduction="mean"):
|
||||
return recursively_apply(_reduce_across_processes, tensor, error_on_other_type=True, reduction=reduction)
|
||||
|
||||
|
||||
def wait_for_everyone():
|
||||
def convert_to_fp32(tensor):
|
||||
"""
|
||||
Introduces a blocking point in the script, making sure all processes have reached this point before continuing.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Make sure all processes will reach this instruction otherwise one of your processes will hang forever.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
if (
|
||||
AcceleratorState().distributed_type == DistributedType.MULTI_GPU
|
||||
or AcceleratorState().distributed_type == DistributedType.MULTI_CPU
|
||||
or AcceleratorState().distributed_type == DistributedType.DEEPSPEED
|
||||
):
|
||||
torch.distributed.barrier()
|
||||
elif AcceleratorState().distributed_type == DistributedType.TPU:
|
||||
xm.rendezvous("accelerate.utils.wait_for_everyone")
|
||||
|
||||
|
||||
def save(obj, f):
|
||||
"""
|
||||
Save the data to disk. Use in place of `torch.save()`.
|
||||
Recursively converts the elements nested list/tuple/dictionary of tensors in FP16/BF16 precision to FP32.
|
||||
|
||||
Args:
|
||||
obj: The data to save
|
||||
f: The file (or file-like object) to use to save the data
|
||||
tensor (nested list/tuple/dictionary of `torch.Tensor`):
|
||||
The data to convert from FP16/BF16 to FP32.
|
||||
|
||||
Returns:
|
||||
The same data structure as `tensor` with all tensors that were in FP16/BF16 precision converted to FP32.
|
||||
"""
|
||||
if AcceleratorState().distributed_type == DistributedType.TPU:
|
||||
xm.save(obj, f)
|
||||
elif AcceleratorState().local_process_index == 0:
|
||||
torch.save(obj, f)
|
||||
|
||||
def _convert_to_fp32(tensor):
|
||||
return tensor.float()
|
||||
|
||||
def _is_fp16_bf16_tensor(tensor):
|
||||
return hasattr(tensor, "dtype") and (
|
||||
tensor.dtype == torch.float16
|
||||
or (version.parse(torch.__version__) >= version.parse("1.10") and tensor.dtype == torch.bfloat16)
|
||||
)
|
||||
|
||||
return recursively_apply(_convert_to_fp32, tensor, test_type=_is_fp16_bf16_tensor)
|
||||
|
||||
|
||||
class PrepareForLaunch:
|
||||
class ConvertOutputsToFp32:
|
||||
"""
|
||||
Prepare a function that will launched in a distributed setup.
|
||||
Decorator to apply to a function outputing tensors (like a model forward pass) that ensures the outputs in FP16
|
||||
precision will be convert back to FP32.
|
||||
|
||||
Use a class instead of a decorator because otherwise, the prepared model can no longer be pickled (issue #273).
|
||||
|
||||
Args:
|
||||
launcher (`Callable`):
|
||||
The function to launch.
|
||||
distributed_type ([`~state.DistributedType`]):
|
||||
The distributed type to prepare for.
|
||||
debug (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not this is a debug launch.
|
||||
model_forward (`Callable`):
|
||||
The function which outputs we want to treat.
|
||||
|
||||
Returns:
|
||||
The same function as `model_forward` but with converted outputs.
|
||||
"""
|
||||
|
||||
def __init__(self, launcher, distributed_type="NO", debug=False):
|
||||
self.launcher = launcher
|
||||
self.distributed_type = DistributedType(distributed_type)
|
||||
self.debug = debug
|
||||
def __init__(self, model_forward):
|
||||
self.model_forward = model_forward
|
||||
update_wrapper(self, model_forward)
|
||||
|
||||
def __call__(self, index, *args):
|
||||
if self.debug:
|
||||
world_size = int(os.environ.get("WORLD_SIZE"))
|
||||
rdv_file = os.environ.get("ACCELERATE_DEBUG_RDV_FILE")
|
||||
torch.distributed.init_process_group(
|
||||
"gloo",
|
||||
rank=index,
|
||||
store=torch.distributed.FileStore(rdv_file, world_size),
|
||||
world_size=world_size,
|
||||
)
|
||||
elif self.distributed_type == DistributedType.MULTI_GPU or self.distributed_type == DistributedType.MULTI_CPU:
|
||||
# Prepare the environment for torch.distributed
|
||||
os.environ["LOCAL_RANK"] = str(index)
|
||||
os.environ["RANK"] = str(index)
|
||||
|
||||
self.launcher(*args)
|
||||
def __call__(self, *args, **kwargs):
|
||||
return convert_to_fp32(self.model_forward(*args, **kwargs))
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepSpeedPlugin:
|
||||
|
||||
gradient_accumulation_steps: int = field(
|
||||
default=None, metadata={"help": "Number of steps to accumulate gradients before updating optimizer states"}
|
||||
)
|
||||
zero_stage: int = field(
|
||||
default=None,
|
||||
metadata={"help": "Possible options are 0,1,2,3; Default will be taken from environment variable"},
|
||||
)
|
||||
is_train_batch_min: str = field(
|
||||
default=True,
|
||||
metadata={"help": "If both train & eval dataloaders are specified, this will decide the train_batch_size"},
|
||||
)
|
||||
|
||||
auto_opt_mapping: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "whether to map torch.adam to deepspeed optimizer version of adam based on config"},
|
||||
)
|
||||
|
||||
offload_optimizer_device: bool = field(default=None, metadata={"help": "Possible options are none|cpu|nvme"})
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
if self.gradient_accumulation_steps is None:
|
||||
self.gradient_accumulation_steps = int(os.environ.get("GRADIENT_ACCUMULATION_STEPS", 1))
|
||||
|
||||
if self.zero_stage is None:
|
||||
self.zero_stage = int(os.environ.get("DEEPSPEED_ZERO_STAGE", 2))
|
||||
|
||||
if self.offload_optimizer_device is None:
|
||||
self.offload_optimizer_device = os.environ.get("DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE", "none")
|
||||
|
||||
self.deepspeed_config = {
|
||||
"train_batch_size": None,
|
||||
"gradient_accumulation_steps": self.gradient_accumulation_steps,
|
||||
"zero_optimization": {
|
||||
"stage": self.zero_stage,
|
||||
"offload_optimizer": {
|
||||
"device": self.offload_optimizer_device,
|
||||
},
|
||||
},
|
||||
"steps_per_print": float("inf"), # this will stop deepspeed from logging @ stdout
|
||||
"zero_allow_untested_optimizer": True,
|
||||
}
|
||||
convert_outputs_to_fp32 = ConvertOutputsToFp32
|
||||
|
||||
|
||||
@dataclass
|
||||
class FullyShardedDataParallelPlugin:
|
||||
def find_device(data):
|
||||
"""
|
||||
This plugin is used to enable fully sharded data parallelism.
|
||||
Finds the device on which a nested dict/list/tuple of tensors lies (assuming they are all on the same device).
|
||||
|
||||
Args:
|
||||
(nested list/tuple/dictionary of `torch.Tensor`): The data we want to know the device of.
|
||||
"""
|
||||
|
||||
sharding_strategy: "typing.Any" = field(
|
||||
default=None,
|
||||
metadata={"help": "Possible options are [1] FULL_SHARD, [2] SHARD_GRAD_OP"},
|
||||
)
|
||||
backward_prefetch: "typing.Any" = field(
|
||||
default=None,
|
||||
metadata={"help": "Possible options are [1] BACKWARD_PRE, [2] BACKWARD_POST"},
|
||||
)
|
||||
auto_wrap_policy: "typing.Any" = field(
|
||||
default=None,
|
||||
metadata={"help": "A callable specifying a policy to recursively wrap layers with FSDP"},
|
||||
)
|
||||
cpu_offload: Optional[Callable] = field(
|
||||
default=None,
|
||||
metadata={"help": "Decides Whether to offload parameters and gradients to CPU."},
|
||||
)
|
||||
min_num_params: int = field(
|
||||
default=None, metadata={"help": "FSDP's minimum number of parameters for Default Auto Wrapping."}
|
||||
)
|
||||
ignored_modules: Optional[Iterable[torch.nn.Module]] = field(
|
||||
default=None,
|
||||
metadata={"help": "A list of modules to ignore for FSDP."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, ShardingStrategy
|
||||
from torch.distributed.fsdp.wrap import default_auto_wrap_policy
|
||||
|
||||
if self.sharding_strategy is None:
|
||||
self.sharding_strategy = ShardingStrategy(int(os.environ.get("FSDP_SHARDING_STRATEGY", 1)))
|
||||
|
||||
if self.cpu_offload is None:
|
||||
if os.environ.get("FSDP_OFFLOAD_PARAMS", "false") == "true":
|
||||
self.cpu_offload = CPUOffload(offload_params=True)
|
||||
else:
|
||||
self.cpu_offload = CPUOffload(offload_params=False)
|
||||
|
||||
if self.min_num_params is None:
|
||||
self.min_num_params = int(os.environ.get("FSDP_MIN_NUM_PARAMS", 0))
|
||||
|
||||
if self.auto_wrap_policy is None:
|
||||
if self.min_num_params > 0:
|
||||
self.auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=self.min_num_params)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def patch_environment(**kwargs):
|
||||
"""
|
||||
A context manager that will add each keyword argument passed to `os.environ` and remove them when exiting.
|
||||
|
||||
Will convert the values in `kwargs` to strings and upper-case all the keys.
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
os.environ[key.upper()] = str(value)
|
||||
|
||||
yield
|
||||
|
||||
for key in kwargs:
|
||||
del os.environ[key.upper()]
|
||||
|
||||
|
||||
def get_pretty_name(obj):
|
||||
"""
|
||||
Gets a pretty name from `obj`.
|
||||
"""
|
||||
if not hasattr(obj, "__qualname__") and not hasattr(obj, "__name__"):
|
||||
obj = getattr(obj, "__class__", obj)
|
||||
if hasattr(obj, "__qualname__"):
|
||||
return obj.__qualname__
|
||||
if hasattr(obj, "__name__"):
|
||||
return obj.__name__
|
||||
return str(obj)
|
||||
if isinstance(data, Mapping):
|
||||
for obj in data.values():
|
||||
device = find_device(obj)
|
||||
if device is not None:
|
||||
return device
|
||||
elif isinstance(data, (tuple, list)):
|
||||
for obj in data:
|
||||
device = find_device(obj)
|
||||
if device is not None:
|
||||
return device
|
||||
elif isinstance(data, torch.Tensor):
|
||||
return data.device
|
||||
111
src/accelerate/utils/other.py
Normal file
111
src/accelerate/utils/other.py
Normal file
@ -0,0 +1,111 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
|
||||
from ..state import AcceleratorState
|
||||
from .dataclasses import DistributedType
|
||||
from .imports import is_deepspeed_available, is_tpu_available
|
||||
|
||||
|
||||
if is_deepspeed_available():
|
||||
from deepspeed import DeepSpeedEngine
|
||||
|
||||
if is_tpu_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
|
||||
def extract_model_from_parallel(model):
|
||||
"""
|
||||
Extract a model from its distributed containers.
|
||||
|
||||
Args:
|
||||
model (`torch.nn.Module`): The model to extract.
|
||||
|
||||
Returns:
|
||||
`torch.nn.Module`: The extracted model.
|
||||
"""
|
||||
options = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)
|
||||
if is_deepspeed_available():
|
||||
options += (DeepSpeedEngine,)
|
||||
|
||||
while isinstance(model, options):
|
||||
model = model.module
|
||||
return model
|
||||
|
||||
|
||||
def wait_for_everyone():
|
||||
"""
|
||||
Introduces a blocking point in the script, making sure all processes have reached this point before continuing.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Make sure all processes will reach this instruction otherwise one of your processes will hang forever.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
if (
|
||||
AcceleratorState().distributed_type == DistributedType.MULTI_GPU
|
||||
or AcceleratorState().distributed_type == DistributedType.MULTI_CPU
|
||||
or AcceleratorState().distributed_type == DistributedType.DEEPSPEED
|
||||
):
|
||||
torch.distributed.barrier()
|
||||
elif AcceleratorState().distributed_type == DistributedType.TPU:
|
||||
xm.rendezvous("accelerate.utils.wait_for_everyone")
|
||||
|
||||
|
||||
def save(obj, f):
|
||||
"""
|
||||
Save the data to disk. Use in place of `torch.save()`.
|
||||
|
||||
Args:
|
||||
obj: The data to save
|
||||
f: The file (or file-like object) to use to save the data
|
||||
"""
|
||||
if AcceleratorState().distributed_type == DistributedType.TPU:
|
||||
xm.save(obj, f)
|
||||
elif AcceleratorState().local_process_index == 0:
|
||||
torch.save(obj, f)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def patch_environment(**kwargs):
|
||||
"""
|
||||
A context manager that will add each keyword argument passed to `os.environ` and remove them when exiting.
|
||||
|
||||
Will convert the values in `kwargs` to strings and upper-case all the keys.
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
os.environ[key.upper()] = str(value)
|
||||
|
||||
yield
|
||||
|
||||
for key in kwargs:
|
||||
del os.environ[key.upper()]
|
||||
|
||||
|
||||
def get_pretty_name(obj):
|
||||
"""
|
||||
Gets a pretty name from `obj`.
|
||||
"""
|
||||
if not hasattr(obj, "__qualname__") and not hasattr(obj, "__name__"):
|
||||
obj = getattr(obj, "__class__", obj)
|
||||
if hasattr(obj, "__qualname__"):
|
||||
return obj.__qualname__
|
||||
if hasattr(obj, "__name__"):
|
||||
return obj.__name__
|
||||
return str(obj)
|
||||
87
src/accelerate/utils/random.py
Normal file
87
src/accelerate/utils/random.py
Normal file
@ -0,0 +1,87 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..state import AcceleratorState
|
||||
from .dataclasses import DistributedType, RNGType
|
||||
from .imports import is_tpu_available
|
||||
|
||||
|
||||
if is_tpu_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
|
||||
def set_seed(seed: int, device_specific: bool = False):
|
||||
"""
|
||||
Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
|
||||
|
||||
Args:
|
||||
seed (`int`): The seed to set.
|
||||
device_specific (`bool`, *optional*, defaults to `False`):
|
||||
Whether to differ the seed on each device slightly with `self.process_index`.
|
||||
"""
|
||||
if device_specific:
|
||||
seed += AcceleratorState().process_index
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
# ^^ safe to call this function even if cuda is not available
|
||||
if is_tpu_available():
|
||||
xm.set_rng_state(seed)
|
||||
|
||||
|
||||
def synchronize_rng_state(rng_type: Optional[RNGType] = None, generator: Optional[torch.Generator] = None):
|
||||
# Get the proper rng state
|
||||
if rng_type == RNGType.TORCH:
|
||||
rng_state = torch.get_rng_state()
|
||||
elif rng_type == RNGType.CUDA:
|
||||
rng_state = torch.cuda.get_rng_state()
|
||||
elif rng_type == RNGType.XLA:
|
||||
assert is_tpu_available(), "Can't synchronize XLA seeds on an environment without TPUs."
|
||||
rng_state = torch.tensor(xm.get_rng_state())
|
||||
elif rng_type == RNGType.GENERATOR:
|
||||
assert generator is not None, "Need a generator to synchronize its seed."
|
||||
rng_state = generator.get_state()
|
||||
|
||||
# Broadcast the rng state from device 0 to other devices
|
||||
state = AcceleratorState()
|
||||
if state.distributed_type == DistributedType.TPU:
|
||||
rng_state = xm.mesh_reduce("random_seed", rng_state, lambda x: x[0])
|
||||
elif state.distributed_type in [DistributedType.DEEPSPEED, DistributedType.MULTI_GPU]:
|
||||
rng_state = rng_state.to(state.device)
|
||||
torch.distributed.broadcast(rng_state, 0)
|
||||
rng_state = rng_state.cpu()
|
||||
elif state.distributed_type == DistributedType.MULTI_CPU:
|
||||
torch.distributed.broadcast(rng_state, 0)
|
||||
|
||||
# Set the broadcast rng state
|
||||
if rng_type == RNGType.TORCH:
|
||||
torch.set_rng_state(rng_state)
|
||||
elif rng_type == RNGType.CUDA:
|
||||
torch.cuda.set_rng_state(rng_state)
|
||||
elif rng_type == RNGType.XLA:
|
||||
xm.set_rng_state(rng_state.item())
|
||||
elif rng_type == RNGType.GENERATOR:
|
||||
generator.set_state(rng_state)
|
||||
|
||||
|
||||
def synchronize_rng_states(rng_types: List[Union[str, RNGType]], generator: Optional[torch.Generator] = None):
|
||||
for rng_type in rng_types:
|
||||
synchronize_rng_state(RNGType(rng_type), generator=generator)
|
||||
276
tests/test_big_modeling.py
Normal file
276
tests/test_big_modeling.py
Normal file
@ -0,0 +1,276 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from accelerate.big_modeling import (
|
||||
cpu_offload,
|
||||
disk_offload,
|
||||
dispatch_model,
|
||||
init_empty_weights,
|
||||
load_checkpoint_and_dispatch,
|
||||
)
|
||||
from accelerate.hooks import remove_hook_from_submodules
|
||||
from accelerate.test_utils import require_cuda, require_multi_gpu, slow
|
||||
from accelerate.utils import offload_state_dict
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
class ModelForTest(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = nn.Linear(3, 4)
|
||||
self.batchnorm = nn.BatchNorm1d(4)
|
||||
self.linear2 = nn.Linear(4, 5)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear2(self.batchnorm(self.linear1(x)))
|
||||
|
||||
|
||||
class BiggerModelForTest(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = nn.Linear(3, 4)
|
||||
self.linear2 = nn.Linear(4, 5)
|
||||
self.batchnorm = nn.BatchNorm1d(5)
|
||||
self.linear3 = nn.Linear(5, 6)
|
||||
self.linear4 = nn.Linear(6, 5)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear4(self.linear3(self.batchnorm(self.linear2(self.linear1(x)))))
|
||||
|
||||
|
||||
class BigModelingTester(unittest.TestCase):
|
||||
def test_init_empty_weights(self):
|
||||
# base use
|
||||
with init_empty_weights():
|
||||
module = nn.Linear(4, 5)
|
||||
self.assertEqual(module.weight.device, torch.device("meta"))
|
||||
|
||||
# base use with buffers, they are not touched
|
||||
with init_empty_weights():
|
||||
module = nn.BatchNorm1d(4)
|
||||
self.assertEqual(module.weight.device, torch.device("meta"))
|
||||
self.assertEqual(module.running_mean.device, torch.device("cpu"))
|
||||
|
||||
# Use with include_buffers=True
|
||||
with init_empty_weights(include_buffers=True):
|
||||
module = nn.BatchNorm1d(4)
|
||||
self.assertEqual(module.weight.device, torch.device("meta"))
|
||||
self.assertEqual(module.running_mean.device, torch.device("meta"))
|
||||
|
||||
# Double check we didn't break PyTorch
|
||||
module = nn.BatchNorm1d(4)
|
||||
self.assertEqual(module.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(module.running_mean.device, torch.device("cpu"))
|
||||
|
||||
def test_init_empty_weights_very_large_model(self):
|
||||
# This is a 100 billion parameters model.
|
||||
with init_empty_weights():
|
||||
_ = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
|
||||
|
||||
def test_cpu_offload(self):
|
||||
model = ModelForTest()
|
||||
x = torch.randn(2, 3)
|
||||
expected = model(x)
|
||||
|
||||
device = torch.device(0 if torch.cuda.is_available() else "cpu")
|
||||
|
||||
cpu_offload(model, execution_device=device)
|
||||
output = model(x)
|
||||
self.assertTrue(torch.allclose(expected, output.cpu()))
|
||||
|
||||
# Clean up for next test.
|
||||
remove_hook_from_submodules(model)
|
||||
|
||||
cpu_offload(model, execution_device=device, offload_buffers=True)
|
||||
output = model(x)
|
||||
self.assertTrue(torch.allclose(expected, output.cpu()))
|
||||
|
||||
@slow
|
||||
@require_cuda
|
||||
def test_cpu_offload_gpt2(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
inputs = tokenizer("Hello world! My name is", return_tensors="pt").to(0)
|
||||
|
||||
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
cpu_offload(gpt2, execution_device=0)
|
||||
outputs = gpt2.generate(inputs["input_ids"])
|
||||
self.assertEqual(
|
||||
tokenizer.decode(outputs[0].tolist()),
|
||||
"Hello world! My name is Kiyoshi, and I'm a student at the University of Tokyo",
|
||||
)
|
||||
|
||||
def test_disk_offload(self):
|
||||
model = ModelForTest()
|
||||
x = torch.randn(2, 3)
|
||||
expected = model(x)
|
||||
|
||||
device = torch.device(0 if torch.cuda.is_available() else "cpu")
|
||||
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
disk_offload(model, tmp_dir, execution_device=device)
|
||||
output = model(x)
|
||||
self.assertTrue(torch.allclose(expected, output.cpu()))
|
||||
|
||||
# Clean up for next test.
|
||||
remove_hook_from_submodules(model)
|
||||
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
disk_offload(model, tmp_dir, execution_device=device, offload_buffers=True)
|
||||
output = model(x)
|
||||
self.assertTrue(torch.allclose(expected, output.cpu()))
|
||||
|
||||
@slow
|
||||
@require_cuda
|
||||
def test_disk_offload_gpt2(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
inputs = tokenizer("Hello world! My name is", return_tensors="pt").to(0)
|
||||
|
||||
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
disk_offload(gpt2, tmp_dir, execution_device=0)
|
||||
outputs = gpt2.generate(inputs["input_ids"])
|
||||
self.assertEqual(
|
||||
tokenizer.decode(outputs[0].tolist()),
|
||||
"Hello world! My name is Kiyoshi, and I'm a student at the University of Tokyo",
|
||||
)
|
||||
|
||||
@require_cuda
|
||||
def test_dispatch_model(self):
|
||||
model = ModelForTest()
|
||||
device_map = {"linear1": "disk", "batchnorm": "cpu", "linear2": 0}
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
expected = model(x)
|
||||
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
dispatch_model(model, device_map, offload_dir=tmp_dir)
|
||||
output = model(x)
|
||||
self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))
|
||||
|
||||
@require_multi_gpu
|
||||
def test_dispatch_model_multi_gpu(self):
|
||||
model = BiggerModelForTest()
|
||||
device_map = {"linear1": "cpu", "linear2": "disk", "batchnorm": "cpu", "linear3": 0, "linear4": 1}
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
expected = model(x)
|
||||
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
dispatch_model(model, device_map, offload_dir=tmp_dir)
|
||||
output = model(x)
|
||||
self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))
|
||||
|
||||
@slow
|
||||
@require_multi_gpu
|
||||
def test_dispatch_model_gpt2_on_two_gpus(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
inputs = tokenizer("Hello world! My name is", return_tensors="pt").to(0)
|
||||
|
||||
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
# Dispatch on GPUs 0 and 1
|
||||
device_map = {
|
||||
"transformer.wte": 0,
|
||||
"transformer.wpe": 0,
|
||||
"transformer.ln_f": 1,
|
||||
"lm_head": 1,
|
||||
}
|
||||
for i in range(12):
|
||||
device_map[f"transformer.h.{i}"] = 0 if i <= 5 else 1
|
||||
|
||||
gpt2 = dispatch_model(gpt2, device_map)
|
||||
outputs = gpt2.generate(inputs["input_ids"])
|
||||
self.assertEqual(
|
||||
tokenizer.decode(outputs[0].tolist()),
|
||||
"Hello world! My name is Kiyoshi, and I'm a student at the University of Tokyo",
|
||||
)
|
||||
|
||||
# Dispatch with a bit of CPU offload
|
||||
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
for i in range(4):
|
||||
device_map[f"transformer.h.{i}"] = "cpu"
|
||||
gpt2 = dispatch_model(gpt2, device_map)
|
||||
outputs = gpt2.generate(inputs["input_ids"])
|
||||
self.assertEqual(
|
||||
tokenizer.decode(outputs[0].tolist()),
|
||||
"Hello world! My name is Kiyoshi, and I'm a student at the University of Tokyo",
|
||||
)
|
||||
# Dispatch with a bit of CPU and disk offload
|
||||
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
for i in range(2):
|
||||
device_map[f"transformer.h.{i}"] = "disk"
|
||||
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
state_dict = {
|
||||
k: p for k, p in gpt2.state_dict().items() if "transformer.h.0" in k or "transformer.h.1" in k
|
||||
}
|
||||
offload_state_dict(tmp_dir, state_dict)
|
||||
gpt2 = dispatch_model(gpt2, device_map, offload_dir=tmp_dir)
|
||||
outputs = gpt2.generate(inputs["input_ids"])
|
||||
self.assertEqual(
|
||||
tokenizer.decode(outputs[0].tolist()),
|
||||
"Hello world! My name is Kiyoshi, and I'm a student at the University of Tokyo",
|
||||
)
|
||||
|
||||
@require_cuda
|
||||
def test_load_checkpoint_and_dispatch(self):
|
||||
model = ModelForTest()
|
||||
device_map = {"linear1": "cpu", "batchnorm": "cpu", "linear2": 0}
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
expected = model(x)
|
||||
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
checkpoint = os.path.join(tmp_dir, "pt_model.bin")
|
||||
torch.save(model.state_dict(), checkpoint)
|
||||
|
||||
new_model = ModelForTest()
|
||||
new_model = load_checkpoint_and_dispatch(new_model, checkpoint, device_map=device_map)
|
||||
|
||||
# CPU-offloaded weights are on the meta device while waiting for the forward pass.
|
||||
self.assertEqual(new_model.linear1.weight.device, torch.device("meta"))
|
||||
self.assertEqual(new_model.linear2.weight.device, torch.device(0))
|
||||
|
||||
output = new_model(x)
|
||||
self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))
|
||||
|
||||
@require_multi_gpu
|
||||
def test_load_checkpoint_and_dispatch_multi_gpu(self):
|
||||
model = BiggerModelForTest()
|
||||
device_map = {"linear1": "cpu", "linear2": "cpu", "batchnorm": 0, "linear3": 0, "linear4": 1}
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
expected = model(x)
|
||||
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
checkpoint = os.path.join(tmp_dir, "pt_model.bin")
|
||||
torch.save(model.state_dict(), checkpoint)
|
||||
|
||||
new_model = BiggerModelForTest()
|
||||
new_model = load_checkpoint_and_dispatch(new_model, checkpoint, device_map=device_map)
|
||||
|
||||
# CPU-offloaded weights are on the meta device while waiting for the forward pass.
|
||||
self.assertEqual(new_model.linear1.weight.device, torch.device("meta"))
|
||||
self.assertEqual(new_model.linear2.weight.device, torch.device("meta"))
|
||||
self.assertEqual(new_model.linear3.weight.device, torch.device(0))
|
||||
self.assertEqual(new_model.linear4.weight.device, torch.device(1))
|
||||
|
||||
output = new_model(x)
|
||||
self.assertTrue(torch.allclose(expected, output.cpu(), atol=1e-5))
|
||||
@ -40,7 +40,7 @@ if SRC_DIRS is not None:
|
||||
# Should mock `{script_name}.get_dataloaders` via:
|
||||
# @mock.patch("{script_name}.get_dataloaders", mocked_dataloaders)
|
||||
|
||||
EXCLUDE_EXAMPLES = ["cross_validation.py", "multi_process_metrics.py", "memory.py"]
|
||||
EXCLUDE_EXAMPLES = ["cross_validation.py", "multi_process_metrics.py", "memory.py", "fsdp_with_peak_mem_tracking.py"]
|
||||
|
||||
|
||||
def mocked_dataloaders(accelerator, batch_size: int = 16):
|
||||
|
||||
330
tests/test_hooks.py
Normal file
330
tests/test_hooks.py
Normal file
@ -0,0 +1,330 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from accelerate.hooks import (
|
||||
AlignDevicesHook,
|
||||
ModelHook,
|
||||
SequentialHook,
|
||||
add_hook_to_module,
|
||||
attach_align_device_hook,
|
||||
remove_hook_from_module,
|
||||
remove_hook_from_submodules,
|
||||
)
|
||||
from accelerate.test_utils import require_multi_gpu
|
||||
|
||||
|
||||
class ModelForTest(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = nn.Linear(3, 4)
|
||||
self.batchnorm = nn.BatchNorm1d(4)
|
||||
self.linear2 = nn.Linear(4, 5)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear2(self.batchnorm(self.linear1(x)))
|
||||
|
||||
|
||||
class PreForwardHook(ModelHook):
|
||||
def pre_forward(self, module, *args, **kwargs):
|
||||
return (args[0] + 1,) + args[1:], kwargs
|
||||
|
||||
|
||||
class PostForwardHook(ModelHook):
|
||||
def post_forward(self, module, output):
|
||||
return output + 1
|
||||
|
||||
|
||||
class HooksModelTester(unittest.TestCase):
|
||||
def test_add_and_remove_hooks(self):
|
||||
test_model = ModelForTest()
|
||||
test_hook = ModelHook()
|
||||
|
||||
add_hook_to_module(test_model, test_hook)
|
||||
self.assertEqual(test_model._hf_hook, test_hook)
|
||||
self.assertTrue(hasattr(test_model, "_old_forward"))
|
||||
|
||||
# Check adding the hook did not change the name or the signature
|
||||
self.assertEqual(test_model.forward.__name__, "forward")
|
||||
self.assertListEqual(list(inspect.signature(test_model.forward).parameters), ["x"])
|
||||
|
||||
remove_hook_from_module(test_model)
|
||||
self.assertFalse(hasattr(test_model, "_hf_hook"))
|
||||
self.assertFalse(hasattr(test_model, "_old_forward"))
|
||||
|
||||
def test_pre_forward_hook_is_executed(self):
|
||||
test_model = ModelForTest()
|
||||
x = torch.randn(2, 3)
|
||||
expected = test_model(x + 1)
|
||||
expected2 = test_model(x + 2)
|
||||
|
||||
test_hook = PreForwardHook()
|
||||
add_hook_to_module(test_model, test_hook)
|
||||
output1 = test_model(x)
|
||||
self.assertTrue(torch.allclose(output1, expected))
|
||||
|
||||
# Attaching a hook to a model when it already has one replaces, does not chain
|
||||
test_hook = PreForwardHook()
|
||||
add_hook_to_module(test_model, test_hook)
|
||||
output1 = test_model(x)
|
||||
self.assertTrue(torch.allclose(output1, expected))
|
||||
|
||||
# You need to use the sequential hook to chain two or more hooks
|
||||
test_hook = SequentialHook(PreForwardHook(), PreForwardHook())
|
||||
add_hook_to_module(test_model, test_hook)
|
||||
|
||||
output2 = test_model(x)
|
||||
assert torch.allclose(output2, expected2)
|
||||
|
||||
def test_post_forward_hook_is_executed(self):
|
||||
test_model = ModelForTest()
|
||||
x = torch.randn(2, 3)
|
||||
output = test_model(x)
|
||||
|
||||
test_hook = PostForwardHook()
|
||||
add_hook_to_module(test_model, test_hook)
|
||||
output1 = test_model(x)
|
||||
self.assertTrue(torch.allclose(output1, output + 1))
|
||||
|
||||
# Attaching a hook to a model when it already has one replaces, does not chain
|
||||
test_hook = PostForwardHook()
|
||||
add_hook_to_module(test_model, test_hook)
|
||||
output1 = test_model(x)
|
||||
self.assertTrue(torch.allclose(output1, output + 1))
|
||||
|
||||
# You need to use the sequential hook to chain two or more hooks
|
||||
test_hook = SequentialHook(PostForwardHook(), PostForwardHook())
|
||||
add_hook_to_module(test_model, test_hook)
|
||||
|
||||
output2 = test_model(x)
|
||||
assert torch.allclose(output2, output + 2)
|
||||
|
||||
def test_no_grad_in_hook(self):
|
||||
test_model = ModelForTest()
|
||||
x = torch.randn(2, 3)
|
||||
output = test_model(x)
|
||||
|
||||
test_hook = PostForwardHook()
|
||||
add_hook_to_module(test_model, test_hook)
|
||||
output1 = test_model(x)
|
||||
self.assertTrue(torch.allclose(output1, output + 1))
|
||||
self.assertTrue(output1.requires_grad)
|
||||
|
||||
test_hook.no_grad = True
|
||||
output1 = test_model(x)
|
||||
self.assertFalse(output1.requires_grad)
|
||||
|
||||
@require_multi_gpu
|
||||
def test_align_devices_as_model_parallelism(self):
|
||||
model = ModelForTest()
|
||||
# Everything is on CPU
|
||||
self.assertEqual(model.linear1.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.linear2.weight.device, torch.device("cpu"))
|
||||
|
||||
# This will move each submodule on different devices
|
||||
add_hook_to_module(model.linear1, AlignDevicesHook(execution_device=0))
|
||||
add_hook_to_module(model.batchnorm, AlignDevicesHook(execution_device=0))
|
||||
add_hook_to_module(model.linear2, AlignDevicesHook(execution_device=1))
|
||||
|
||||
self.assertEqual(model.linear1.weight.device, torch.device(0))
|
||||
self.assertEqual(model.batchnorm.weight.device, torch.device(0))
|
||||
self.assertEqual(model.batchnorm.running_mean.device, torch.device(0))
|
||||
self.assertEqual(model.linear2.weight.device, torch.device(1))
|
||||
|
||||
# We can still make a forward pass. The input does not need to be on any particular device
|
||||
x = torch.randn(2, 3)
|
||||
output = model(x)
|
||||
self.assertEqual(output.device, torch.device(1))
|
||||
|
||||
# We can add a general hook to put back output on same device as input.
|
||||
add_hook_to_module(model, AlignDevicesHook(io_same_device=True))
|
||||
x = torch.randn(2, 3).to(0)
|
||||
output = model(x)
|
||||
self.assertEqual(output.device, torch.device(0))
|
||||
|
||||
def test_align_devices_as_cpu_offload(self):
|
||||
model = ModelForTest()
|
||||
|
||||
# Everything is on CPU
|
||||
self.assertEqual(model.linear1.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.linear2.weight.device, torch.device("cpu"))
|
||||
|
||||
# This will move each submodule on different devices
|
||||
hook_kwargs = {"execution_device": 0 if torch.cuda.is_available() else "cpu", "offload": True}
|
||||
|
||||
add_hook_to_module(model.linear1, AlignDevicesHook(**hook_kwargs))
|
||||
add_hook_to_module(model.batchnorm, AlignDevicesHook(**hook_kwargs))
|
||||
add_hook_to_module(model.linear2, AlignDevicesHook(**hook_kwargs))
|
||||
|
||||
# Parameters have been offloaded, so on the meta device
|
||||
self.assertEqual(model.linear1.weight.device, torch.device("meta"))
|
||||
self.assertEqual(model.batchnorm.weight.device, torch.device("meta"))
|
||||
self.assertEqual(model.linear2.weight.device, torch.device("meta"))
|
||||
# Buffers are not included in the offload by default, so are on the execution device
|
||||
device = torch.device(hook_kwargs["execution_device"])
|
||||
self.assertEqual(model.batchnorm.running_mean.device, device)
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
output = model(x)
|
||||
self.assertEqual(output.device, device)
|
||||
|
||||
# Removing hooks loads back the weights in the model.
|
||||
remove_hook_from_module(model.linear1)
|
||||
remove_hook_from_module(model.batchnorm)
|
||||
remove_hook_from_module(model.linear2)
|
||||
self.assertEqual(model.linear1.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.linear2.weight.device, torch.device("cpu"))
|
||||
|
||||
# Now test with buffers included in the offload
|
||||
hook_kwargs = {
|
||||
"execution_device": 0 if torch.cuda.is_available() else "cpu",
|
||||
"offload": True,
|
||||
"offload_buffers": True,
|
||||
}
|
||||
|
||||
add_hook_to_module(model.linear1, AlignDevicesHook(**hook_kwargs))
|
||||
add_hook_to_module(model.batchnorm, AlignDevicesHook(**hook_kwargs))
|
||||
add_hook_to_module(model.linear2, AlignDevicesHook(**hook_kwargs))
|
||||
|
||||
# Parameters have been offloaded, so on the meta device, buffers included
|
||||
self.assertEqual(model.linear1.weight.device, torch.device("meta"))
|
||||
self.assertEqual(model.batchnorm.weight.device, torch.device("meta"))
|
||||
self.assertEqual(model.linear2.weight.device, torch.device("meta"))
|
||||
self.assertEqual(model.batchnorm.running_mean.device, torch.device("meta"))
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
output = model(x)
|
||||
self.assertEqual(output.device, device)
|
||||
|
||||
# Removing hooks loads back the weights in the model.
|
||||
remove_hook_from_module(model.linear1)
|
||||
remove_hook_from_module(model.batchnorm)
|
||||
remove_hook_from_module(model.linear2)
|
||||
self.assertEqual(model.linear1.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.linear2.weight.device, torch.device("cpu"))
|
||||
|
||||
def test_attach_align_device_hook_as_cpu_offload(self):
|
||||
model = ModelForTest()
|
||||
|
||||
# Everything is on CPU
|
||||
self.assertEqual(model.linear1.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.linear2.weight.device, torch.device("cpu"))
|
||||
|
||||
# This will move each submodule on different devices
|
||||
execution_device = 0 if torch.cuda.is_available() else "cpu"
|
||||
attach_align_device_hook(model, execution_device=execution_device, offload=True)
|
||||
|
||||
# Parameters have been offloaded, so on the meta device
|
||||
self.assertEqual(model.linear1.weight.device, torch.device("meta"))
|
||||
self.assertEqual(model.batchnorm.weight.device, torch.device("meta"))
|
||||
self.assertEqual(model.linear2.weight.device, torch.device("meta"))
|
||||
# Buffers are not included in the offload by default, so are on the execution device
|
||||
device = torch.device(execution_device)
|
||||
self.assertEqual(model.batchnorm.running_mean.device, device)
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
output = model(x)
|
||||
self.assertEqual(output.device, device)
|
||||
|
||||
# Removing hooks loads back the weights in the model.
|
||||
remove_hook_from_submodules(model)
|
||||
self.assertEqual(model.linear1.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.linear2.weight.device, torch.device("cpu"))
|
||||
|
||||
# Now test with buffers included in the offload
|
||||
attach_align_device_hook(model, execution_device=execution_device, offload=True, offload_buffers=True)
|
||||
|
||||
# Parameters have been offloaded, so on the meta device, buffers included
|
||||
self.assertEqual(model.linear1.weight.device, torch.device("meta"))
|
||||
self.assertEqual(model.batchnorm.weight.device, torch.device("meta"))
|
||||
self.assertEqual(model.linear2.weight.device, torch.device("meta"))
|
||||
self.assertEqual(model.batchnorm.running_mean.device, torch.device("meta"))
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
output = model(x)
|
||||
self.assertEqual(output.device, device)
|
||||
|
||||
# Removing hooks loads back the weights in the model.
|
||||
remove_hook_from_submodules(model)
|
||||
self.assertEqual(model.linear1.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.linear2.weight.device, torch.device("cpu"))
|
||||
|
||||
def test_attach_align_device_hook_as_cpu_offload_with_weight_map(self):
|
||||
model = ModelForTest()
|
||||
|
||||
# Everything is on CPU
|
||||
self.assertEqual(model.linear1.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.linear2.weight.device, torch.device("cpu"))
|
||||
|
||||
# This will move each submodule on different devices
|
||||
execution_device = 0 if torch.cuda.is_available() else "cpu"
|
||||
attach_align_device_hook(
|
||||
model, execution_device=execution_device, offload=True, weights_map=model.state_dict()
|
||||
)
|
||||
|
||||
# Parameters have been offloaded, so on the meta device
|
||||
self.assertEqual(model.linear1.weight.device, torch.device("meta"))
|
||||
self.assertEqual(model.batchnorm.weight.device, torch.device("meta"))
|
||||
self.assertEqual(model.linear2.weight.device, torch.device("meta"))
|
||||
# Buffers are not included in the offload by default, so are on the execution device
|
||||
device = torch.device(execution_device)
|
||||
self.assertEqual(model.batchnorm.running_mean.device, device)
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
output = model(x)
|
||||
self.assertEqual(output.device, device)
|
||||
|
||||
# Removing hooks loads back the weights in the model.
|
||||
remove_hook_from_submodules(model)
|
||||
self.assertEqual(model.linear1.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.linear2.weight.device, torch.device("cpu"))
|
||||
|
||||
# Now test with buffers included in the offload
|
||||
attach_align_device_hook(
|
||||
model,
|
||||
execution_device=execution_device,
|
||||
offload=True,
|
||||
weights_map=model.state_dict(),
|
||||
offload_buffers=True,
|
||||
)
|
||||
|
||||
# Parameters have been offloaded, so on the meta device, buffers included
|
||||
self.assertEqual(model.linear1.weight.device, torch.device("meta"))
|
||||
self.assertEqual(model.batchnorm.weight.device, torch.device("meta"))
|
||||
self.assertEqual(model.linear2.weight.device, torch.device("meta"))
|
||||
self.assertEqual(model.batchnorm.running_mean.device, torch.device("meta"))
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
output = model(x)
|
||||
self.assertEqual(output.device, device)
|
||||
|
||||
# Removing hooks loads back the weights in the model.
|
||||
remove_hook_from_submodules(model)
|
||||
self.assertEqual(model.linear1.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.linear2.weight.device, torch.device("cpu"))
|
||||
@ -21,8 +21,8 @@ from dataclasses import dataclass
|
||||
import torch
|
||||
|
||||
from accelerate import Accelerator, DistributedDataParallelKwargs, GradScalerKwargs
|
||||
from accelerate.kwargs_handlers import KwargsHandler
|
||||
from accelerate.test_utils import execute_subprocess_async, require_cuda, require_multi_gpu
|
||||
from accelerate.utils import KwargsHandler
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -14,7 +14,7 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from accelerate.memory_utils import find_executable_batch_size
|
||||
from accelerate.utils.memory import find_executable_batch_size
|
||||
|
||||
|
||||
def raise_fake_out_of_memory():
|
||||
@ -50,6 +50,26 @@ class MemoryTest(unittest.TestCase):
|
||||
self.assertListEqual(batch_sizes, [128, 64, 32, 16, 8])
|
||||
self.assertListEqual([bs, arg1], [8, "hello"])
|
||||
|
||||
def test_start_zero(self):
|
||||
@find_executable_batch_size(starting_batch_size=0)
|
||||
def mock_training_loop_function(batch_size):
|
||||
pass
|
||||
|
||||
with self.assertRaises(RuntimeError) as cm:
|
||||
mock_training_loop_function()
|
||||
self.assertIn("No executable batch size found, reached zero.", cm.exception.args[0])
|
||||
|
||||
def test_approach_zero(self):
|
||||
@find_executable_batch_size(starting_batch_size=16)
|
||||
def mock_training_loop_function(batch_size):
|
||||
if batch_size > 0:
|
||||
raise_fake_out_of_memory()
|
||||
pass
|
||||
|
||||
with self.assertRaises(RuntimeError) as cm:
|
||||
mock_training_loop_function()
|
||||
self.assertIn("No executable batch size found, reached zero.", cm.exception.args[0])
|
||||
|
||||
def test_verbose_guard(self):
|
||||
@find_executable_batch_size(starting_batch_size=128)
|
||||
def mock_training_loop_function(batch_size, arg1, arg2):
|
||||
@ -60,3 +80,12 @@ class MemoryTest(unittest.TestCase):
|
||||
mock_training_loop_function(128, "hello", "world")
|
||||
self.assertIn("Batch size was passed into `f`", cm.exception.args[0])
|
||||
self.assertIn("`f(arg1='hello', arg2='world')", cm.exception.args[0])
|
||||
|
||||
def test_any_other_error(self):
|
||||
@find_executable_batch_size(starting_batch_size=16)
|
||||
def mock_training_loop_function(batch_size):
|
||||
raise ValueError("Oops, we had an error!")
|
||||
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
mock_training_loop_function()
|
||||
self.assertIn("Oops, we had an error!", cm.exception.args[0])
|
||||
|
||||
360
tests/test_modeling_utils.py
Normal file
360
tests/test_modeling_utils.py
Normal file
@ -0,0 +1,360 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from accelerate.test_utils import require_cuda, require_multi_gpu
|
||||
from accelerate.utils.modeling import (
|
||||
check_device_map,
|
||||
clean_device_map,
|
||||
compute_module_sizes,
|
||||
find_tied_parameters,
|
||||
infer_auto_device_map,
|
||||
load_checkpoint_in_model,
|
||||
named_module_tensors,
|
||||
set_module_tensor_to_device,
|
||||
)
|
||||
|
||||
|
||||
class ModelForTest(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = nn.Linear(3, 4)
|
||||
self.batchnorm = nn.BatchNorm1d(4)
|
||||
self.linear2 = nn.Linear(4, 5)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear2(self.batchnorm(self.linear1(x)))
|
||||
|
||||
|
||||
class ModelingUtilsTester(unittest.TestCase):
|
||||
def check_set_module_tensor_for_device(self, model, device1, device2):
|
||||
self.assertEqual(model.linear1.weight.device, torch.device(device1))
|
||||
|
||||
with self.subTest("Access by submodule and direct name for a parameter"):
|
||||
set_module_tensor_to_device(model.linear1, "weight", device2)
|
||||
self.assertEqual(model.linear1.weight.device, torch.device(device2))
|
||||
|
||||
if torch.device(device2) == torch.device("meta"):
|
||||
with self.assertRaises(ValueError):
|
||||
# We need a `value` to set the weight back on device1
|
||||
set_module_tensor_to_device(model.linear1, "weight", device1)
|
||||
|
||||
set_module_tensor_to_device(model.linear1, "weight", device1, value=torch.randn(4, 3))
|
||||
else:
|
||||
set_module_tensor_to_device(model.linear1, "weight", device1)
|
||||
self.assertEqual(model.linear1.weight.device, torch.device(device1))
|
||||
|
||||
with self.subTest("Access by module and full name for a parameter"):
|
||||
set_module_tensor_to_device(model, "linear1.weight", device2)
|
||||
self.assertEqual(model.linear1.weight.device, torch.device(device2))
|
||||
|
||||
if torch.device(device2) == torch.device("meta"):
|
||||
with self.assertRaises(ValueError):
|
||||
# We need a `value` to set the weight back on device1
|
||||
set_module_tensor_to_device(model, "linear1.weight", device1)
|
||||
set_module_tensor_to_device(model, "linear1.weight", device1, value=torch.randn(4, 3))
|
||||
else:
|
||||
set_module_tensor_to_device(model, "linear1.weight", device1)
|
||||
self.assertEqual(model.linear1.weight.device, torch.device(device1))
|
||||
|
||||
self.assertEqual(model.batchnorm.running_mean.device, torch.device(device1))
|
||||
|
||||
with self.subTest("Access by submodule and direct name for a buffer"):
|
||||
set_module_tensor_to_device(model.batchnorm, "running_mean", device2)
|
||||
self.assertEqual(model.batchnorm.running_mean.device, torch.device(device2))
|
||||
|
||||
if torch.device(device2) == torch.device("meta"):
|
||||
with self.assertRaises(ValueError):
|
||||
# We need a `value` to set the weight back on device1
|
||||
set_module_tensor_to_device(model.batchnorm, "running_mean", device1)
|
||||
set_module_tensor_to_device(model.batchnorm, "running_mean", device1, value=torch.randn(4))
|
||||
else:
|
||||
set_module_tensor_to_device(model.batchnorm, "running_mean", device1)
|
||||
self.assertEqual(model.batchnorm.running_mean.device, torch.device(device1))
|
||||
|
||||
with self.subTest("Access by module and full name for a parameter"):
|
||||
set_module_tensor_to_device(model, "batchnorm.running_mean", device2)
|
||||
self.assertEqual(model.batchnorm.running_mean.device, torch.device(device2))
|
||||
|
||||
if torch.device(device2) == torch.device("meta"):
|
||||
with self.assertRaises(ValueError):
|
||||
# We need a `value` to set the weight back on CPU
|
||||
set_module_tensor_to_device(model, "batchnorm.running_mean", device1)
|
||||
|
||||
set_module_tensor_to_device(model, "batchnorm.running_mean", device1, value=torch.randn(4))
|
||||
else:
|
||||
set_module_tensor_to_device(model, "batchnorm.running_mean", device1)
|
||||
self.assertEqual(model.batchnorm.running_mean.device, torch.device(device1))
|
||||
|
||||
def test_set_module_tensor_to_meta_and_cpu(self):
|
||||
model = ModelForTest()
|
||||
self.check_set_module_tensor_for_device(model, "cpu", "meta")
|
||||
|
||||
@require_cuda
|
||||
def test_set_module_tensor_to_cpu_and_gpu(self):
|
||||
model = ModelForTest()
|
||||
self.check_set_module_tensor_for_device(model, "cpu", 0)
|
||||
|
||||
@require_cuda
|
||||
def test_set_module_tensor_to_meta_and_gpu(self):
|
||||
model = ModelForTest().to(0)
|
||||
self.check_set_module_tensor_for_device(model, 0, "meta")
|
||||
|
||||
@require_multi_gpu
|
||||
def test_set_module_tensor_between_gpus(self):
|
||||
model = ModelForTest().to(0)
|
||||
self.check_set_module_tensor_for_device(model, 0, 1)
|
||||
|
||||
def test_named_tensors(self):
|
||||
model = nn.BatchNorm1d(4)
|
||||
named_tensors = named_module_tensors(model)
|
||||
self.assertListEqual(
|
||||
[name for name, _ in named_tensors],
|
||||
["weight", "bias", "running_mean", "running_var", "num_batches_tracked"],
|
||||
)
|
||||
|
||||
named_tensors = named_module_tensors(model, include_buffers=False)
|
||||
self.assertListEqual([name for name, _ in named_tensors], ["weight", "bias"])
|
||||
|
||||
model = ModelForTest()
|
||||
named_tensors = named_module_tensors(model)
|
||||
self.assertListEqual([name for name, _ in named_tensors], [])
|
||||
|
||||
named_tensors = named_module_tensors(model, recurse=True)
|
||||
self.assertListEqual(
|
||||
[name for name, _ in named_tensors],
|
||||
[
|
||||
"linear1.weight",
|
||||
"linear1.bias",
|
||||
"batchnorm.weight",
|
||||
"batchnorm.bias",
|
||||
"linear2.weight",
|
||||
"linear2.bias",
|
||||
"batchnorm.running_mean",
|
||||
"batchnorm.running_var",
|
||||
"batchnorm.num_batches_tracked",
|
||||
],
|
||||
)
|
||||
|
||||
named_tensors = named_module_tensors(model, include_buffers=False, recurse=True)
|
||||
self.assertListEqual(
|
||||
[name for name, _ in named_tensors],
|
||||
["linear1.weight", "linear1.bias", "batchnorm.weight", "batchnorm.bias", "linear2.weight", "linear2.bias"],
|
||||
)
|
||||
|
||||
def test_find_tied_parameters(self):
|
||||
model = ModelForTest()
|
||||
self.assertDictEqual(find_tied_parameters(model), {})
|
||||
model.linear2.weight = model.linear1.weight
|
||||
self.assertDictEqual(find_tied_parameters(model), {"linear1.weight": "linear2.weight"})
|
||||
|
||||
def test_compute_module_sizes(self):
|
||||
model = ModelForTest()
|
||||
expected_sizes = {"": 236, "linear1": 64, "linear1.weight": 48, "linear1.bias": 16}
|
||||
expected_sizes.update({"linear2": 100, "linear2.weight": 80, "linear2.bias": 20})
|
||||
expected_sizes.update({"batchnorm": 72, "batchnorm.weight": 16, "batchnorm.bias": 16})
|
||||
expected_sizes.update(
|
||||
{"batchnorm.running_mean": 16, "batchnorm.running_var": 16, "batchnorm.num_batches_tracked": 8}
|
||||
)
|
||||
|
||||
module_sizes = compute_module_sizes(model)
|
||||
self.assertDictEqual(module_sizes, expected_sizes)
|
||||
|
||||
model.half()
|
||||
expected_sizes = {k: s // 2 for k, s in expected_sizes.items()}
|
||||
# This one is not converted to half.
|
||||
expected_sizes["batchnorm.num_batches_tracked"] = 8
|
||||
# This impacts batchnorm and total
|
||||
expected_sizes["batchnorm"] += 4
|
||||
expected_sizes[""] += 4
|
||||
|
||||
module_sizes = compute_module_sizes(model)
|
||||
self.assertDictEqual(module_sizes, expected_sizes)
|
||||
|
||||
def test_check_device_map(self):
|
||||
model = ModelForTest()
|
||||
check_device_map(model, {"": 0})
|
||||
with self.assertRaises(ValueError):
|
||||
check_device_map(model, {"linear1": 0, "linear2": 1})
|
||||
|
||||
check_device_map(model, {"linear1": 0, "linear2": 1, "batchnorm": 1})
|
||||
|
||||
def shard_test_model(self, model, tmp_dir):
|
||||
module_index = {
|
||||
"linear1": "checkpoint_part1.bin",
|
||||
"batchnorm": "checkpoint_part2.bin",
|
||||
"linear2": "checkpoint_part3.bin",
|
||||
}
|
||||
index = {}
|
||||
for name, _ in model.state_dict().items():
|
||||
module = name.split(".")[0]
|
||||
index[name] = module_index[module]
|
||||
|
||||
with open(os.path.join(tmp_dir, "weight_map.index.json"), "w") as f:
|
||||
json.dump(index, f)
|
||||
|
||||
for module, fname in module_index.items():
|
||||
state_dict = {k: v for k, v in model.state_dict().items() if k.startswith(module)}
|
||||
full_fname = os.path.join(tmp_dir, fname)
|
||||
torch.save(state_dict, full_fname)
|
||||
|
||||
def test_load_checkpoint_in_model(self):
|
||||
# Check with whole checkpoint
|
||||
model = ModelForTest()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
fname = os.path.join(tmp_dir, "pt_model.bin")
|
||||
torch.save(model.state_dict(), fname)
|
||||
load_checkpoint_in_model(model, fname)
|
||||
|
||||
# Check with sharded index
|
||||
model = ModelForTest()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
self.shard_test_model(model, tmp_dir)
|
||||
index_file = os.path.join(tmp_dir, "weight_map.index.json")
|
||||
load_checkpoint_in_model(model, index_file)
|
||||
|
||||
# Check with sharded checkpoint
|
||||
model = ModelForTest()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
self.shard_test_model(model, tmp_dir)
|
||||
load_checkpoint_in_model(model, tmp_dir)
|
||||
|
||||
@require_cuda
|
||||
def test_load_checkpoint_in_model_one_gpu(self):
|
||||
device_map = {"linear1": 0, "batchnorm": "cpu", "linear2": "cpu"}
|
||||
|
||||
# Check with whole checkpoint
|
||||
model = ModelForTest()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
fname = os.path.join(tmp_dir, "pt_model.bin")
|
||||
torch.save(model.state_dict(), fname)
|
||||
load_checkpoint_in_model(model, fname, device_map=device_map)
|
||||
self.assertEqual(model.linear1.weight.device, torch.device(0))
|
||||
self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.linear2.weight.device, torch.device("cpu"))
|
||||
|
||||
# Check with sharded index
|
||||
model = ModelForTest()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
self.shard_test_model(model, tmp_dir)
|
||||
index_file = os.path.join(tmp_dir, "weight_map.index.json")
|
||||
load_checkpoint_in_model(model, index_file, device_map=device_map)
|
||||
|
||||
self.assertEqual(model.linear1.weight.device, torch.device(0))
|
||||
self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.linear2.weight.device, torch.device("cpu"))
|
||||
|
||||
# Check with sharded checkpoint folder
|
||||
model = ModelForTest()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
self.shard_test_model(model, tmp_dir)
|
||||
load_checkpoint_in_model(model, tmp_dir, device_map=device_map)
|
||||
|
||||
self.assertEqual(model.linear1.weight.device, torch.device(0))
|
||||
self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.linear2.weight.device, torch.device("cpu"))
|
||||
|
||||
@require_multi_gpu
|
||||
def test_load_checkpoint_in_model_two_gpu(self):
|
||||
device_map = {"linear1": 0, "batchnorm": "cpu", "linear2": 1}
|
||||
|
||||
# Check with whole checkpoint
|
||||
model = ModelForTest()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
fname = os.path.join(tmp_dir, "pt_model.bin")
|
||||
torch.save(model.state_dict(), fname)
|
||||
load_checkpoint_in_model(model, fname, device_map=device_map)
|
||||
self.assertEqual(model.linear1.weight.device, torch.device(0))
|
||||
self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.linear2.weight.device, torch.device(1))
|
||||
|
||||
# Check with sharded index
|
||||
model = ModelForTest()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
self.shard_test_model(model, tmp_dir)
|
||||
index_file = os.path.join(tmp_dir, "weight_map.index.json")
|
||||
load_checkpoint_in_model(model, index_file, device_map=device_map)
|
||||
|
||||
self.assertEqual(model.linear1.weight.device, torch.device(0))
|
||||
self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.linear2.weight.device, torch.device(1))
|
||||
|
||||
# Check with sharded checkpoint
|
||||
model = ModelForTest()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
self.shard_test_model(model, tmp_dir)
|
||||
load_checkpoint_in_model(model, tmp_dir, device_map=device_map)
|
||||
|
||||
self.assertEqual(model.linear1.weight.device, torch.device(0))
|
||||
self.assertEqual(model.batchnorm.weight.device, torch.device("cpu"))
|
||||
self.assertEqual(model.linear2.weight.device, torch.device(1))
|
||||
|
||||
def test_clean_device_map(self):
|
||||
# Regroup everything if all is on the same device
|
||||
self.assertDictEqual(clean_device_map({"a": 0, "b": 0, "c": 0}), {"": 0})
|
||||
# Regroups children of level 1 on the same device
|
||||
self.assertDictEqual(
|
||||
clean_device_map({"a.x": 0, "a.y": 0, "b.x": 1, "b.y": 1, "c": 1}), {"a": 0, "b": 1, "c": 1}
|
||||
)
|
||||
# Regroups children of level 2 on the same device
|
||||
self.assertDictEqual(
|
||||
clean_device_map({"a.x": 0, "a.y": 0, "b.x.0": 1, "b.x.1": 1, "b.y.0": 2, "b.y.1": 2, "c": 2}),
|
||||
{"a": 0, "b.x": 1, "b.y": 2, "c": 2},
|
||||
)
|
||||
|
||||
def test_infer_auto_device_map(self):
|
||||
model = ModelForTest()
|
||||
# model has size 236: linear1 64, batchnorm 72, linear2 100
|
||||
|
||||
device_map = infer_auto_device_map(model, max_memory={0: 200, 1: 200})
|
||||
# only linear1 fits on device 0 as we keep memory available for the maximum layer in case of offload
|
||||
self.assertDictEqual(device_map, {"linear1": 0, "batchnorm": 1, "linear2": 1})
|
||||
|
||||
device_map = infer_auto_device_map(model, max_memory={0: 200, 1: 172, 2: 200})
|
||||
# On device 1, we don't care about keeping size available for the max layer, so even if there is just the
|
||||
# size available for batchnorm + linear2, they fit here.
|
||||
self.assertDictEqual(device_map, {"linear1": 0, "batchnorm": 1, "linear2": 1})
|
||||
|
||||
model.linear1.weight = model.linear2.weight
|
||||
device_map = infer_auto_device_map(model, max_memory={0: 200, 1: 200})
|
||||
# By tying weights, the whole model fits on device 0
|
||||
self.assertDictEqual(device_map, {"": 0})
|
||||
|
||||
# When splitting a bigger model, the split is done at the layer level
|
||||
model = nn.Sequential(ModelForTest(), ModelForTest(), ModelForTest())
|
||||
device_map = infer_auto_device_map(model, max_memory={0: 500, 1: 500})
|
||||
self.assertDictEqual(device_map, {"0": 0, "1.linear1": 0, "1.batchnorm": 0, "1.linear2": 1, "2": 1})
|
||||
|
||||
# With no_split_module_classes, it's done at that module level
|
||||
model = nn.Sequential(ModelForTest(), ModelForTest(), ModelForTest())
|
||||
device_map = infer_auto_device_map(
|
||||
model, max_memory={0: 500, 1: 500}, no_split_module_classes=["ModelForTest"]
|
||||
)
|
||||
self.assertDictEqual(device_map, {"0": 0, "1": 1, "2": 1})
|
||||
|
||||
# Now if we have weights tied inside submodules, tied weights are on the same device.
|
||||
model = nn.Sequential(ModelForTest(), ModelForTest(), ModelForTest())
|
||||
layer0 = getattr(model, "0")
|
||||
layer2 = getattr(model, "2")
|
||||
layer0.linear2.weight = layer2.linear2.weight
|
||||
device_map = infer_auto_device_map(model, max_memory={0: 400, 1: 500})
|
||||
expected = {"0": 0, "2.linear2": 0, "1": 1, "2.linear1": 1, "2.batchnorm": 1}
|
||||
self.assertDictEqual(device_map, expected)
|
||||
87
tests/test_offload.py
Normal file
87
tests/test_offload.py
Normal file
@ -0,0 +1,87 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from accelerate.utils import OffloadedWeightsLoader, offload_state_dict
|
||||
|
||||
|
||||
class ModelForTest(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = nn.Linear(3, 4)
|
||||
self.batchnorm = nn.BatchNorm1d(4)
|
||||
self.linear2 = nn.Linear(4, 5)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear2(self.batchnorm(self.linear1(x)))
|
||||
|
||||
|
||||
class OffloadTester(unittest.TestCase):
|
||||
def test_offload_state_dict(self):
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
model = ModelForTest()
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
offload_state_dict(tmp_dir, model.state_dict())
|
||||
index_file = os.path.join(tmp_dir, "index.json")
|
||||
self.assertTrue(os.path.isfile(index_file))
|
||||
# TODO: add tests on what is inside the index
|
||||
|
||||
for key in ["linear1.weight", "linear1.bias", "linear2.weight", "linear2.bias"]:
|
||||
weight_file = os.path.join(tmp_dir, f"{key}.dat")
|
||||
self.assertTrue(os.path.isfile(weight_file))
|
||||
# TODO: add tests on the fact weights are properly loaded
|
||||
|
||||
def test_offload_weights_loader(self):
|
||||
model = ModelForTest()
|
||||
state_dict = model.state_dict()
|
||||
cpu_part = {k: v for k, v in state_dict.items() if "linear2" not in k}
|
||||
disk_part = {k: v for k, v in state_dict.items() if "linear2" in k}
|
||||
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
offload_state_dict(tmp_dir, disk_part)
|
||||
weight_map = OffloadedWeightsLoader(state_dict=cpu_part, save_folder=tmp_dir)
|
||||
|
||||
# Every key is there with the right value
|
||||
self.assertEqual(sorted(weight_map), sorted(state_dict.keys()))
|
||||
for key, param in state_dict.items():
|
||||
self.assertTrue(torch.allclose(param, weight_map[key]))
|
||||
|
||||
cpu_part = {k: v for k, v in state_dict.items() if "weight" in k}
|
||||
disk_part = {k: v for k, v in state_dict.items() if "weight" not in k}
|
||||
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
offload_state_dict(tmp_dir, disk_part)
|
||||
weight_map = OffloadedWeightsLoader(state_dict=cpu_part, save_folder=tmp_dir)
|
||||
|
||||
# Every key is there with the right value
|
||||
self.assertEqual(sorted(weight_map), sorted(state_dict.keys()))
|
||||
for key, param in state_dict.items():
|
||||
self.assertTrue(torch.allclose(param, weight_map[key]))
|
||||
|
||||
with TemporaryDirectory() as tmp_dir:
|
||||
offload_state_dict(tmp_dir, state_dict)
|
||||
# Duplicates are removed
|
||||
weight_map = OffloadedWeightsLoader(state_dict=cpu_part, save_folder=tmp_dir)
|
||||
|
||||
# Every key is there with the right value
|
||||
self.assertEqual(sorted(weight_map), sorted(state_dict.keys()))
|
||||
for key, param in state_dict.items():
|
||||
self.assertTrue(torch.allclose(param, weight_map[key]))
|
||||
@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||
import pytest
|
||||
from accelerate.commands.config.config_args import SageMakerConfig
|
||||
from accelerate.commands.launch import _convert_nargs_to_dict
|
||||
from accelerate.state import ComputeEnvironment
|
||||
from accelerate.utils import ComputeEnvironment
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -26,10 +26,19 @@ from unittest import mock
|
||||
|
||||
# We use TF to parse the logs
|
||||
from accelerate import Accelerator
|
||||
from accelerate.test_utils.testing import MockingTestCase, TempDirTestCase, require_tensorflow
|
||||
from accelerate.test_utils.testing import (
|
||||
MockingTestCase,
|
||||
TempDirTestCase,
|
||||
require_comet_ml,
|
||||
require_tensorflow,
|
||||
require_wandb,
|
||||
)
|
||||
from accelerate.tracking import CometMLTracker, GeneralTracker
|
||||
from accelerate.utils import is_tensorflow_available
|
||||
from comet_ml import OfflineExperiment
|
||||
from accelerate.utils import is_comet_ml_available, is_tensorflow_available
|
||||
|
||||
|
||||
if is_comet_ml_available():
|
||||
from comet_ml import OfflineExperiment
|
||||
|
||||
|
||||
if is_tensorflow_available():
|
||||
@ -110,6 +119,7 @@ class TensorBoardTrackingTest(unittest.TestCase):
|
||||
_ = Accelerator(log_with="tensorboard", logging_dir=dirpath)
|
||||
|
||||
|
||||
@require_wandb
|
||||
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
|
||||
class WandBTrackingTest(TempDirTestCase, MockingTestCase):
|
||||
def setUp(self):
|
||||
@ -179,6 +189,7 @@ def offline_init(self, run_name: str, tmpdir: str):
|
||||
logger.info("Make sure to log any initial configurations with `self.store_init_configuration` before training!")
|
||||
|
||||
|
||||
@require_comet_ml
|
||||
@mock.patch.object(CometMLTracker, "__init__", offline_init)
|
||||
class CometMLTest(unittest.TestCase):
|
||||
@staticmethod
|
||||
|
||||
@ -20,7 +20,7 @@ from collections import UserDict, namedtuple
|
||||
import torch
|
||||
|
||||
from accelerate.test_utils.training import RegressionModel
|
||||
from accelerate.utils import convert_outputs_to_fp32, patch_environment, send_to_device
|
||||
from accelerate.utils import convert_outputs_to_fp32, find_device, patch_environment, send_to_device
|
||||
|
||||
|
||||
TestNamedTuple = namedtuple("TestNamedTuple", "a b c")
|
||||
@ -78,3 +78,8 @@ class UtilsTester(unittest.TestCase):
|
||||
model = RegressionModel()
|
||||
model.forward = convert_outputs_to_fp32(model.forward)
|
||||
_ = pickle.dumps(model)
|
||||
|
||||
def test_find_device(self):
|
||||
self.assertEqual(find_device([1, "a", torch.tensor([1, 2, 3])]), torch.device("cpu"))
|
||||
self.assertEqual(find_device({"a": 1, "b": torch.tensor([1, 2, 3])}), torch.device("cpu"))
|
||||
self.assertIsNone(find_device([1, "a"]))
|
||||
|
||||
Reference in New Issue
Block a user