Fixup dataloader state dict bugs + incorporate load/save_state API (#3034)

* v1

* More testing, need to try on H100

* Bigger batch for h100 test

* test tweak

* Fixup all tests!

* Bookmark

* Fix issues, working now

* rm num samples

* Uncomment

* Give stateful dl end of dl

* Make skip DL stateful

* Migrate to update_state_dict

* try/finally

* Add comments to test

* rm comment

* Document

* refactor out for eventual override

* Doc nit

* Brute force it
Author: Zach Mueller
Date: 2024-08-23 15:13:33 -04:00
Committed by: GitHub
Parent: 2d4f1dda7e
Commit: 726140cad2
8 changed files with 219 additions and 75 deletions
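
At a glance, the user-facing surface this commit touches is the `use_stateful_dataloader` flag on `DataLoaderConfiguration` together with `Accelerator.save_state`/`load_state`, which now round-trip dataloader state as well. A minimal sketch of that surface (the checkpoint path is an illustrative placeholder, and `use_stateful_dataloader` assumes the optional torchdata stateful-dataloader backend is available):

from accelerate import Accelerator, DataLoaderConfiguration

# Opt in to stateful dataloaders for everything this Accelerator prepares
dl_config = DataLoaderConfiguration(use_stateful_dataloader=True)
accelerator = Accelerator(dataloader_config=dl_config)

# ... prepare model/optimizer/dataloaders and train as usual ...

accelerator.save_state("ckpt")  # now also writes dl_state_dict.bin for each prepared stateful dataloader
accelerator.load_state("ckpt")  # restores each dataloader mid-epoch, so skip_first_batches is no longer needed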


@@ -19,9 +19,10 @@ import torch
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup
from accelerate import Accelerator, DistributedType
from accelerate import Accelerator, DataLoaderConfiguration, DistributedType
from accelerate.utils import set_seed
########################################################################
@@ -125,7 +126,8 @@ def training_function(config, args):
if os.environ.get("TESTING_MOCKED_DATALOADERS", None) == "1":
config["num_epochs"] = 2
# Initialize accelerator
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=args.use_stateful_dataloader)
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision, dataloader_config=dataloader_config)
# Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
lr = config["lr"]
num_epochs = int(config["num_epochs"])
@@ -217,8 +219,11 @@ def training_function(config, args):
model.train()
# New Code #
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
# We need to skip steps until we reach the resumed step
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
# We need to skip steps until we reach the resumed step only if we are not using a stateful dataloader
if not args.use_stateful_dataloader:
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
else:
active_dataloader = train_dataloader
overall_step += resume_step
else:
# After the first iteration though, we need to go back to the original dataloader
@@ -248,7 +253,6 @@ def training_function(config, args):
if args.output_dir is not None:
output_dir = os.path.join(args.output_dir, output_dir)
accelerator.save_state(output_dir)
model.eval()
for step, batch in enumerate(eval_dataloader):
# We could avoid this line since we set the accelerator with `device_placement=True` (the default).
@@ -261,7 +265,6 @@ def training_function(config, args):
predictions=predictions,
references=references,
)
eval_metric = metric.compute()
# Use accelerator.print to print only on the main process.
accelerator.print(f"epoch {epoch}:", eval_metric)
@@ -309,6 +312,11 @@ def main():
default=None,
help="If the training should continue from a checkpoint folder.",
)
parser.add_argument(
"--use_stateful_dataloader",
action="store_true",
help="If the dataloader should be a resumable stateful dataloader.",
)
args = parser.parse_args()
config = {"lr": 2e-5, "num_epochs": 3, "seed": 42, "batch_size": 16}
training_function(config, args)
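
The resume branch added in this example matters because the two resume mechanisms are mutually exclusive: with a stateful dataloader, `accelerator.load_state` has already restored the dataloader's position, so also calling `skip_first_batches` would skip data twice. A condensed restatement of that branch as a hypothetical helper (the helper name is invented; the variable names follow the example script, and how `resume_step` is derived is outside the hunks shown above):

def choose_active_dataloader(args, accelerator, train_dataloader, resume_step, epoch, starting_epoch):
    """Hypothetical helper distilling the resume logic above; not part of the diff."""
    if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
        if not args.use_stateful_dataloader:
            # Plain DataLoader: fast-forward manually to the checkpointed step
            return accelerator.skip_first_batches(train_dataloader, resume_step)
        # Stateful DataLoader: load_state already restored its position; skipping again would drop data
        return train_dataloader
    # After the first resumed epoch, go back to the original dataloader from the start
    return train_dataloader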


@@ -23,7 +23,7 @@ from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, RandomResizedCrop, Resize, ToTensor
from accelerate import Accelerator
from accelerate import Accelerator, DataLoaderConfiguration
########################################################################
@@ -72,12 +72,19 @@ class PetsDataset(Dataset):
def training_function(config, args):
# Initialize accelerator
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=args.use_stateful_dataloader)
if args.with_tracking:
accelerator = Accelerator(
cpu=args.cpu, mixed_precision=args.mixed_precision, log_with="all", project_dir=args.project_dir
cpu=args.cpu,
mixed_precision=args.mixed_precision,
log_with="all",
project_dir=args.project_dir,
dataloader_config=dataloader_config,
)
else:
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
accelerator = Accelerator(
cpu=args.cpu, mixed_precision=args.mixed_precision, dataloader_config=dataloader_config
)
# Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
lr = config["lr"]
@@ -297,6 +304,11 @@ def main():
default=None,
help="If the training should continue from a checkpoint folder.",
)
parser.add_argument(
"--use_stateful_dataloader",
action="store_true",
help="If the dataloader should be a resumable stateful dataloader.",
)
parser.add_argument(
"--with_tracking",
action="store_true",


@@ -21,7 +21,7 @@ from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
from accelerate import Accelerator, DistributedType
from accelerate import Accelerator, DataLoaderConfiguration, DistributedType
########################################################################
@@ -49,12 +49,19 @@ EVAL_BATCH_SIZE = 32
def training_function(config, args):
# Initialize accelerator
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=args.use_stateful_dataloader)
if args.with_tracking:
accelerator = Accelerator(
cpu=args.cpu, mixed_precision=args.mixed_precision, log_with="all", project_dir=args.project_dir
cpu=args.cpu,
mixed_precision=args.mixed_precision,
dataloader_config=dataloader_config,
log_with="all",
project_dir=args.project_dir,
)
else:
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
accelerator = Accelerator(
cpu=args.cpu, mixed_precision=args.mixed_precision, dataloader_config=dataloader_config
)
if hasattr(args.checkpointing_steps, "isdigit"):
if args.checkpointing_steps == "epoch":
@@ -194,7 +201,10 @@ def training_function(config, args):
total_loss = 0
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
# We need to skip steps until we reach the resumed step
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
if not args.use_stateful_dataloader:
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
else:
active_dataloader = train_dataloader
overall_step += resume_step
else:
# After the first iteration though, we need to go back to the original dataloader
@@ -283,6 +293,11 @@ def main():
default=None,
help="If the training should continue from a checkpoint folder.",
)
parser.add_argument(
"--use_stateful_dataloader",
action="store_true",
help="If the dataloader should be a resumable stateful dataloader.",
)
parser.add_argument(
"--with_tracking",
action="store_true",


@@ -127,6 +127,11 @@ def save_accelerator_state(
sampler = dataloader.get_sampler()
if isinstance(sampler, SeedableRandomSampler):
save(sampler, output_sampler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
if getattr(dataloader, "use_stateful_dataloader", False):
dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
output_dataloader_state_dict_file = output_dir.joinpath(dataloader_state_dict_name)
state_dict = dataloader.state_dict()
torch.save(state_dict, output_dataloader_state_dict_file)
logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")
# GradScaler state
@@ -241,6 +246,12 @@ def load_accelerator_state(
sampler = dataloader.get_sampler()
if isinstance(sampler, SeedableRandomSampler):
sampler = dataloader.set_sampler(torch.load(input_sampler_file))
if getattr(dataloader, "use_stateful_dataloader", False):
dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
input_dataloader_state_dict_file = input_dir.joinpath(dataloader_state_dict_name)
if input_dataloader_state_dict_file.exists():
state_dict = torch.load(input_dataloader_state_dict_file)
dataloader.load_state_dict(state_dict)
logger.info("All dataloader sampler states loaded successfully")
# GradScaler state
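
The save and load hunks above mirror the existing sampler handling: each prepared dataloader that reports `use_stateful_dataloader` gets its own file in the checkpoint directory, keyed by its index. A standalone illustration of that layout (the directory name and the fake state dicts are invented for the sketch; only the file-name pattern and the `torch.save`/`load_state_dict` round trip come from the diff):

from pathlib import Path

import torch

ckpt = Path("checkpoint")
ckpt.mkdir(exist_ok=True)
# Pretend state_dicts from two prepared stateful dataloaders (e.g. train and eval)
fake_states = [{"_num_yielded": 3}, {"_num_yielded": 1}]

for i, state in enumerate(fake_states):
    name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
    torch.save(state, ckpt.joinpath(name))

# load_accelerator_state does the inverse: torch.load the matching file if it
# exists and hand the result to dataloader.load_state_dict(...)
print(torch.load(ckpt.joinpath("dl_state_dict_1.bin")))  # {'_num_yielded': 1}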


@@ -365,6 +365,13 @@ class DataLoaderStateMixin:
- **remainder** (`int`) -- The number of items that are remaining in the last batch, relative to the total
batch size
<Tip warning={true}>
Inheritors of this class should ensure that the class creates a `GradientState()` instance, stored in
`self.gradient_state`.
</Tip>
"""
def __init_subclass__(cls, **kwargs):
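
The new Tip is what the `SkipDataLoader` change further down satisfies: any class mixing in `DataLoaderStateMixin` has to create its own `GradientState`. A minimal sketch of a conforming subclass (the class name and its pass-through behaviour are invented for illustration; the imports reflect where these symbols live per this diff):

import torch
from torch.utils.data import TensorDataset
from accelerate.data_loader import DataLoaderAdapter, DataLoaderStateMixin
from accelerate.state import GradientState

class PassthroughDataLoader(DataLoaderAdapter, DataLoaderStateMixin):
    def __init__(self, dataset, **kwargs):
        super().__init__(dataset, **kwargs)
        self.gradient_state = GradientState()  # required by the Tip above

    def __iter__(self):
        self.begin()  # registers this dataloader with the gradient state
        yield from self.base_dataloader
        self.end()    # unregisters it again

for batch in PassthroughDataLoader(TensorDataset(torch.arange(8)), batch_size=2):
    print(batch)
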
@@ -443,7 +450,29 @@ class DataLoaderAdapter:
def load_state_dict(self, state_dict):
self.base_dataloader.load_state_dict(state_dict)
self.dl_state_dict = self.state_dict
def adjust_state_dict_for_prefetch(self):
"""
Adjusts the state dict to account for prefetching. By default, this reduces each of the "iterations yielded"
counters in `self.dl_state_dict` by `num_processes - 1` (and the samples counter by
`batch_size * (num_processes - 1)`); if a custom correction is needed, this method can be overridden.
This should modify `self.dl_state_dict` directly.
"""
# During DDP, the state dict will be ahead by `n - 1` batches (one extra per additional process),
# so we need to adjust it here
if PartialState().distributed_type != DistributedType.NO:
factor = PartialState().num_processes - 1
if self.dl_state_dict["_sampler_iter_yielded"] > 0:
self.dl_state_dict["_sampler_iter_yielded"] -= factor
if self.dl_state_dict["_num_yielded"] > 0:
self.dl_state_dict["_num_yielded"] -= factor
if self.dl_state_dict["_index_sampler_state"] is not None:
if (
"samples_yielded" in self.dl_state_dict["_index_sampler_state"]
and self.dl_state_dict["_index_sampler_state"]["samples_yielded"] > 0
):
self.dl_state_dict["_index_sampler_state"]["samples_yielded"] -= self.batch_size * factor
def _update_state_dict(self):
# The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded.
@@ -453,6 +482,10 @@ class DataLoaderAdapter:
# _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter.
if hasattr(self.base_dataloader, "state_dict"):
self.dl_state_dict = self.base_dataloader.state_dict()
# Potentially modify the state_dict to adjust for prefetching
self.adjust_state_dict_for_prefetch()
# Then tag if we are at the end of the dataloader
self.dl_state_dict["_iterator_finished"] = self.end_of_dataloader
class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
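
To make the prefetch correction concrete, here it is worked through by hand for a hypothetical two-process run (every number below is invented; the keys are the ones `adjust_state_dict_for_prefetch` touches above):

# 2 processes => factor = num_processes - 1 = 1; assume batch_size = 4
state = {
    "_sampler_iter_yielded": 5,
    "_num_yielded": 5,
    "_index_sampler_state": {"samples_yielded": 20},
}
factor, batch_size = 1, 4
state["_sampler_iter_yielded"] -= factor                                 # 5 -> 4
state["_num_yielded"] -= factor                                          # 5 -> 4
state["_index_sampler_state"]["samples_yielded"] -= batch_size * factor  # 20 -> 16
# The snapshot is rolled back by the batch that prefetching pulled ahead, so reloading
# it resumes on the batch each process is actually about to consume.
print(state)
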
@@ -539,6 +572,7 @@ class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
current_batch = next_batch
except StopIteration:
self.end_of_dataloader = True
self._update_state_dict()
if batch_index >= self.skip_batches:
yield current_batch
break
@@ -809,6 +843,7 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
if stop_iteration:
self.end_of_dataloader = True
self._update_state_dict()
self.remainder = observed_batch_size
if batch_index >= self.skip_batches:
yield batch
@@ -1146,7 +1181,7 @@ class SkipBatchSampler(BatchSampler):
return len(self.batch_sampler) - self.skip_batches
class SkipDataLoader(DataLoaderAdapter):
class SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin):
"""
Subclass of a PyTorch `DataLoader` that will skip the first batches.
@@ -1164,12 +1199,15 @@ class SkipDataLoader(DataLoaderAdapter):
def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs):
super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
self.skip_batches = skip_batches
self.gradient_state = GradientState()
def __iter__(self):
self.begin()
for index, batch in enumerate(self.base_dataloader.__iter__()):
if index >= self.skip_batches:
self._update_state_dict()
yield batch
self.end()
def skip_first_batches(dataloader, num_batches=0):


@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import warnings
from typing import List
from unittest.mock import Mock
@@ -77,12 +77,17 @@ def create_accelerator(even_batches=True):
return accelerator
def create_dataloader(accelerator: Accelerator, dataset_size: int, batch_size: int, iterable: bool = False):
def create_dataloader(
accelerator: Accelerator, dataset_size: int, batch_size: int, iterable: bool = False, shuffle: bool = False
):
"""
Create a simple DataLoader to use during the test cases
"""
values = torch.as_tensor(range(dataset_size))
if shuffle:
values = values[torch.randperm(values.size(0))]
if iterable:
dataset = DummyIterableDataset(torch.as_tensor(range(dataset_size)))
dataset = DummyIterableDataset(values)
else:
dataset = TensorDataset(torch.as_tensor(range(dataset_size)))
@@ -260,6 +265,81 @@ def test_data_loader(data_loader, accelerator):
), "Not all the dataset elements have been iterated in an epoch due to duplication of samples across processes."
def test_stateful_dataloader(accelerator):
"""
Tests that a stateful dataloader can be iterated over, have its state saved partway through via `state_dict`, and
then be resumed from that saved state via `load_state_dict`.
The resumed batches should match the remaining data that was iterated over after the state was saved.
"""
old_dataloader_config = accelerator.dataloader_config
try:
accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
prepared_dl = create_dataloader(
accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=True, shuffle=True
)
untrained_batches = []
# Calculate which step the last batch will be
total_batches = 32 * accelerator.num_processes // (4 * accelerator.num_processes)
last_batch_num = total_batches - 1
for step, batch in enumerate(prepared_dl):
# Step just before
if step == last_batch_num - 1:
state_dict = prepared_dl.state_dict()
if step >= last_batch_num:
# Otherwise grab the "unseen" batches
untrained_batches.append(batch)
not_skipped_batches = accelerator.gather(untrained_batches)
prepared_dl.load_state_dict(state_dict)
resumed_batches = []
for batch in prepared_dl:
resumed_batches.append(batch)
resumed_batches = accelerator.gather(resumed_batches)
for b1, b2 in zip(not_skipped_batches, resumed_batches):
for v1, v2 in zip(b1, b2):
assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"
finally:
accelerator.dataloader_config = old_dataloader_config
def test_stateful_dataloader_save_state(accelerator):
"""
Tests that a stateful dataloader can be iterated over, saved after a few batches using `Accelerator.save_state`,
and then resumed from the saved state.
The resumed batches should match the remaining data that was iterated over after the state was saved.
"""
old_dataloader_config = accelerator.dataloader_config
try:
with tempfile.TemporaryDirectory() as tmpdir:
accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
prepared_dl = create_dataloader(
accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=True, shuffle=True
)
untrained_batches = []
# Calculate which step the last batch will be
total_batches = 32 * accelerator.num_processes // (4 * accelerator.num_processes)
last_batch_num = total_batches - 1
for step, batch in enumerate(prepared_dl):
# Step just before
if step == last_batch_num - 1:
accelerator.save_state(tmpdir)
if step >= last_batch_num:
# Otherwise grab the "unseen" batches
untrained_batches.append(batch)
not_skipped_batches = accelerator.gather(untrained_batches)
accelerator.load_state(tmpdir)
resumed_batches = []
for batch in prepared_dl:
resumed_batches.append(batch)
resumed_batches = accelerator.gather(resumed_batches)
for b1, b2 in zip(not_skipped_batches, resumed_batches):
for v1, v2 in zip(b1, b2):
assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"
finally:
accelerator.dataloader_config = old_dataloader_config
def main():
accelerator = create_accelerator()
torch.manual_seed(accelerator.process_index)
@@ -306,6 +386,8 @@ def main():
sampler = BatchSampler(RandomSampler(dataset), batch_size=BATCH_SIZE, drop_last=False)
loader = DataLoader(dataset, sampler=sampler, batch_size=None, collate_fn=default_collate, num_workers=NUM_WORKERS)
test_data_loader(loader, accelerator)
test_stateful_dataloader(accelerator)
test_stateful_dataloader_save_state(accelerator)
accelerator.end_training()


@@ -27,7 +27,6 @@ from torch.utils.data import DataLoader, TensorDataset
from accelerate import DistributedType, infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch
from accelerate.accelerator import Accelerator
from accelerate.data_loader import skip_first_batches
from accelerate.state import GradientState, PartialState
from accelerate.test_utils import (
require_bnb,
@@ -682,11 +681,12 @@ class AcceleratorTester(AccelerateTestCase):
Test that saving and loading a model with a stateful dataloader returns the same model,
and that the dataloader's iterator is restored properly."""
set_seed(42)
n_train_batches = 64 # Use enough batches to ensure we can get partial iterations on large compute
dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=True)
accelerator = Accelerator(dataloader_config=dataloader_config)
model, optimizer, scheduler, train_dl, valid_dl = create_components(tied_weights)
train_dl, valid_dl = create_dataloaders_for_test(num_workers=num_workers)
train_dl, valid_dl = create_dataloaders_for_test(n_train_batches=n_train_batches, num_workers=num_workers)
model = ModelForTest()
(
@@ -703,77 +703,53 @@ class AcceleratorTester(AccelerateTestCase):
# Perform 3 training iterations to ensure the dataloader's iterator is advanced
num_batches_to_skip = 3
model.train()
for step, batch in enumerate(prepared_train_dl):
x, y = batch
x.to(accelerator.device)
y.to(accelerator.device)
with accelerator.accumulate(prepared_model):
untrained_batches = []
with tempfile.TemporaryDirectory() as tmpdirname:
for step, batch in enumerate(prepared_train_dl):
x, y = batch
outputs = prepared_model(x)
loss = torch.nn.functional.mse_loss(outputs, y)
accelerator.backward(loss)
prepared_optimizer.step()
prepared_scheduler.step()
prepared_optimizer.zero_grad()
if step == num_batches_to_skip - 1:
state_dict = prepared_train_dl.state_dict()
# When breaking out without fully going through the iterator, must call end() to unregister this iterator from gradient state.
# TODO: Maybe this could be done automatically?
prepared_train_dl.end()
break
if step == num_batches_to_skip - 1:
# Save the state once we've gone through a few batches
accelerator.save_state(f"{tmpdirname}/state", safe_serialization=use_safetensors)
if step >= num_batches_to_skip:
untrained_batches.append(batch)
assert accelerator.gradient_state.active_dataloader is None
not_skipped_batches = accelerator.gather(untrained_batches)
# We then unwrap the trained model
unwrapped_model = accelerator.unwrap_model(prepared_model)
with tempfile.TemporaryDirectory() as tmpdirname:
# Save model for later use
accelerator.save_model(model, tmpdirname, safe_serialization=use_safetensors)
original_linear1 = unwrapped_model.linear1.weight.clone()
original_batchnorm = unwrapped_model.batchnorm.weight.clone()
original_linear2 = unwrapped_model.linear2.weight.clone()
# Starting from where we left off, train this model to the end of the DataLoader
prepared_train_dl = skip_first_batches(prepared_train_dl, num_batches_to_skip)
batches_seen_with_original_dl = 0
for batch in prepared_train_dl:
x, y = batch
x.to(accelerator.device)
y.to(accelerator.device)
with accelerator.accumulate(prepared_model):
outputs = prepared_model(x)
loss = torch.nn.functional.mse_loss(outputs, y)
accelerator.backward(loss)
prepared_optimizer.step()
prepared_scheduler.step()
prepared_optimizer.zero_grad()
batches_seen_with_original_dl += 1
original_linear1 = prepared_model.linear1.weight.clone()
original_batchnorm = prepared_model.batchnorm.weight.clone()
original_linear2 = prepared_model.linear2.weight.clone()
# Load the model and state dict
load_checkpoint_in_model(model, tmpdirname)
stateful_train_dl, _ = create_dataloaders_for_test(num_workers=num_workers)
prepared_stateful_train_dl = accelerator.prepare_data_loader(stateful_train_dl)
prepared_stateful_train_dl.load_state_dict(state_dict)
# Resume the state
accelerator.load_state(f"{tmpdirname}/state")
# Train this to the end of the DataLoader
batches_seen_with_loaded_dl = 0
for batch in prepared_stateful_train_dl:
for batch in prepared_train_dl:
x, y = batch
x.to(accelerator.device)
y.to(accelerator.device)
with accelerator.accumulate(prepared_model):
outputs = prepared_model(x)
loss = torch.nn.functional.mse_loss(outputs, y)
accelerator.backward(loss)
prepared_optimizer.step()
prepared_scheduler.step()
prepared_optimizer.zero_grad()
outputs = prepared_model(x)
loss = torch.nn.functional.mse_loss(outputs, y)
accelerator.backward(loss)
prepared_optimizer.step()
prepared_scheduler.step()
prepared_optimizer.zero_grad()
batches_seen_with_loaded_dl += 1
new_linear1 = prepared_model.linear1.weight
new_batchnorm = prepared_model.batchnorm.weight
new_linear2 = prepared_model.linear2.weight
unwrapped_model_2 = accelerator.unwrap_model(prepared_model)
new_linear1 = unwrapped_model_2.linear1.weight
new_batchnorm = unwrapped_model_2.batchnorm.weight
new_linear2 = unwrapped_model_2.linear2.weight
# Assert equalities
assert batches_seen_with_original_dl == batches_seen_with_loaded_dl
assert batches_seen_with_loaded_dl == len(not_skipped_batches)
assert torch.allclose(original_linear1, new_linear1)
assert torch.allclose(original_batchnorm, new_batchnorm)
assert torch.allclose(original_linear2, new_linear2)
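
Because the hunk above interleaves the removed and added versions of the test, here is its core assertion restated as a small self-contained check: after `accelerator.save_state` partway through an epoch and a later `accelerator.load_state`, the stateful dataloader yields exactly the batches that had not yet been consumed. (The toy dataset, single-process setup, and step count are invented; only the save/load pattern comes from the diff, and `use_stateful_dataloader` assumes the optional stateful-dataloader backend is installed.)

import tempfile

import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator, DataLoaderConfiguration

accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True))
dl = accelerator.prepare(DataLoader(TensorDataset(torch.arange(32)), batch_size=4))

num_batches_to_skip = 3
untrained = []
with tempfile.TemporaryDirectory() as tmpdir:
    for step, (batch,) in enumerate(dl):
        if step == num_batches_to_skip - 1:
            accelerator.save_state(tmpdir)  # snapshot after consuming three batches
        if step >= num_batches_to_skip:
            untrained.append(batch)
    accelerator.load_state(tmpdir)
    resumed = [batch for (batch,) in dl]  # resumes from batch 3, not from the start

assert len(resumed) == len(untrained)
for b1, b2 in zip(untrained, resumed):
    assert torch.equal(b1, b2)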


@@ -45,11 +45,13 @@ from accelerate.utils import write_basic_config
EXCLUDE_EXAMPLES = [
"cross_validation.py",
"checkpointing.py",
"gradient_accumulation.py",
"local_sgd.py",
"multi_process_metrics.py",
"memory.py",
"schedule_free.py",
"tracking.py",
"automatic_gradient_accumulation.py",
"fsdp_with_peak_mem_tracking.py",
"deepspeed_with_config_support.py",