mirror of
https://github.com/huggingface/accelerate.git
synced 2025-10-20 18:13:46 +08:00
Fixup dataloader state dict bugs + incorporate load/save_state API (#3034)
* v1 * More testing, need to try on H100 * Bigger batch for h100 test * test tweak * Fixup all tests! * Bookmark * Fix issues, working now * rm num samples * Uncomment * Give stateful dl end of dl * Make skip DL stateful * Migrate to update_state_dict * try/finally * Add comments to test * rm comment * Document * refactor out for eventual override * Doc nit * Brute force it
This commit is contained in:
@ -19,9 +19,10 @@ import torch
|
||||
from datasets import load_dataset
|
||||
from torch.optim import AdamW
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup
|
||||
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from accelerate import Accelerator, DataLoaderConfiguration, DistributedType
|
||||
from accelerate.utils import set_seed
|
||||
|
||||
|
||||
########################################################################
|
||||
@ -125,7 +126,8 @@ def training_function(config, args):
|
||||
if os.environ.get("TESTING_MOCKED_DATALOADERS", None) == "1":
|
||||
config["num_epochs"] = 2
|
||||
# Initialize accelerator
|
||||
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
|
||||
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=args.use_stateful_dataloader)
|
||||
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision, dataloader_config=dataloader_config)
|
||||
# Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
|
||||
lr = config["lr"]
|
||||
num_epochs = int(config["num_epochs"])
|
||||
@ -217,8 +219,11 @@ def training_function(config, args):
|
||||
model.train()
|
||||
# New Code #
|
||||
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
|
||||
# We need to skip steps until we reach the resumed step
|
||||
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
|
||||
# We need to skip steps until we reach the resumed step only if we are not using a stateful dataloader
|
||||
if not args.use_stateful_dataloader:
|
||||
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
|
||||
else:
|
||||
active_dataloader = train_dataloader
|
||||
overall_step += resume_step
|
||||
else:
|
||||
# After the first iteration though, we need to go back to the original dataloader
|
||||
@ -248,7 +253,6 @@ def training_function(config, args):
|
||||
if args.output_dir is not None:
|
||||
output_dir = os.path.join(args.output_dir, output_dir)
|
||||
accelerator.save_state(output_dir)
|
||||
|
||||
model.eval()
|
||||
for step, batch in enumerate(eval_dataloader):
|
||||
# We could avoid this line since we set the accelerator with `device_placement=True` (the default).
|
||||
@ -261,7 +265,6 @@ def training_function(config, args):
|
||||
predictions=predictions,
|
||||
references=references,
|
||||
)
|
||||
|
||||
eval_metric = metric.compute()
|
||||
# Use accelerator.print to print only on the main process.
|
||||
accelerator.print(f"epoch {epoch}:", eval_metric)
|
||||
@ -309,6 +312,11 @@ def main():
|
||||
default=None,
|
||||
help="If the training should continue from a checkpoint folder.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_stateful_dataloader",
|
||||
action="store_true",
|
||||
help="If the dataloader should be a resumable stateful dataloader.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
config = {"lr": 2e-5, "num_epochs": 3, "seed": 42, "batch_size": 16}
|
||||
training_function(config, args)
|
||||
|
@ -23,7 +23,7 @@ from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torchvision.transforms import Compose, RandomResizedCrop, Resize, ToTensor
|
||||
|
||||
from accelerate import Accelerator
|
||||
from accelerate import Accelerator, DataLoaderConfiguration
|
||||
|
||||
|
||||
########################################################################
|
||||
@ -72,12 +72,19 @@ class PetsDataset(Dataset):
|
||||
|
||||
def training_function(config, args):
|
||||
# Initialize accelerator
|
||||
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=args.use_stateful_dataloader)
|
||||
if args.with_tracking:
|
||||
accelerator = Accelerator(
|
||||
cpu=args.cpu, mixed_precision=args.mixed_precision, log_with="all", project_dir=args.project_dir
|
||||
cpu=args.cpu,
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with="all",
|
||||
project_dir=args.project_dir,
|
||||
dataloader_config=dataloader_config,
|
||||
)
|
||||
else:
|
||||
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
|
||||
accelerator = Accelerator(
|
||||
cpu=args.cpu, mixed_precision=args.mixed_precision, dataloader_config=dataloader_config
|
||||
)
|
||||
|
||||
# Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
|
||||
lr = config["lr"]
|
||||
@ -297,6 +304,11 @@ def main():
|
||||
default=None,
|
||||
help="If the training should continue from a checkpoint folder.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_stateful_dataloader",
|
||||
action="store_true",
|
||||
help="If the dataloader should be a resumable stateful dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--with_tracking",
|
||||
action="store_true",
|
||||
|
@ -21,7 +21,7 @@ from torch.optim import AdamW
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
|
||||
|
||||
from accelerate import Accelerator, DistributedType
|
||||
from accelerate import Accelerator, DataLoaderConfiguration, DistributedType
|
||||
|
||||
|
||||
########################################################################
|
||||
@ -49,12 +49,19 @@ EVAL_BATCH_SIZE = 32
|
||||
|
||||
def training_function(config, args):
|
||||
# Initialize accelerator
|
||||
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=args.use_stateful_dataloader)
|
||||
if args.with_tracking:
|
||||
accelerator = Accelerator(
|
||||
cpu=args.cpu, mixed_precision=args.mixed_precision, log_with="all", project_dir=args.project_dir
|
||||
cpu=args.cpu,
|
||||
mixed_precision=args.mixed_precision,
|
||||
dataloader_config=dataloader_config,
|
||||
log_with="all",
|
||||
project_dir=args.project_dir,
|
||||
)
|
||||
else:
|
||||
accelerator = Accelerator(cpu=args.cpu, mixed_precision=args.mixed_precision)
|
||||
accelerator = Accelerator(
|
||||
cpu=args.cpu, mixed_precision=args.mixed_precision, dataloader_config=dataloader_config
|
||||
)
|
||||
|
||||
if hasattr(args.checkpointing_steps, "isdigit"):
|
||||
if args.checkpointing_steps == "epoch":
|
||||
@ -194,7 +201,10 @@ def training_function(config, args):
|
||||
total_loss = 0
|
||||
if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
|
||||
# We need to skip steps until we reach the resumed step
|
||||
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
|
||||
if not args.use_stateful_dataloader:
|
||||
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
|
||||
else:
|
||||
active_dataloader = train_dataloader
|
||||
overall_step += resume_step
|
||||
else:
|
||||
# After the first iteration though, we need to go back to the original dataloader
|
||||
@ -283,6 +293,11 @@ def main():
|
||||
default=None,
|
||||
help="If the training should continue from a checkpoint folder.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_stateful_dataloader",
|
||||
action="store_true",
|
||||
help="If the dataloader should be a resumable stateful dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--with_tracking",
|
||||
action="store_true",
|
||||
|
@ -127,6 +127,11 @@ def save_accelerator_state(
|
||||
sampler = dataloader.get_sampler()
|
||||
if isinstance(sampler, SeedableRandomSampler):
|
||||
save(sampler, output_sampler_file, save_on_each_node=save_on_each_node, safe_serialization=False)
|
||||
if getattr(dataloader, "use_stateful_dataloader", False):
|
||||
dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
|
||||
output_dataloader_state_dict_file = output_dir.joinpath(dataloader_state_dict_name)
|
||||
state_dict = dataloader.state_dict()
|
||||
torch.save(state_dict, output_dataloader_state_dict_file)
|
||||
logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}")
|
||||
|
||||
# GradScaler state
|
||||
@ -241,6 +246,12 @@ def load_accelerator_state(
|
||||
sampler = dataloader.get_sampler()
|
||||
if isinstance(sampler, SeedableRandomSampler):
|
||||
sampler = dataloader.set_sampler(torch.load(input_sampler_file))
|
||||
if getattr(dataloader, "use_stateful_dataloader", False):
|
||||
dataloader_state_dict_name = "dl_state_dict.bin" if i == 0 else f"dl_state_dict_{i}.bin"
|
||||
input_dataloader_state_dict_file = input_dir.joinpath(dataloader_state_dict_name)
|
||||
if input_dataloader_state_dict_file.exists():
|
||||
state_dict = torch.load(input_dataloader_state_dict_file)
|
||||
dataloader.load_state_dict(state_dict)
|
||||
logger.info("All dataloader sampler states loaded successfully")
|
||||
|
||||
# GradScaler state
|
||||
|
@ -365,6 +365,13 @@ class DataLoaderStateMixin:
|
||||
- **remainder** (`int`) -- The number of items that are remaining in the last batch, relative to the total
|
||||
batch size
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Inheriters of this class should ensure that the class creates a `GradientState()` instance, stored in
|
||||
`self.gradient_state`.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
@ -443,7 +450,29 @@ class DataLoaderAdapter:
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.base_dataloader.load_state_dict(state_dict)
|
||||
self.dl_state_dict = self.state_dict
|
||||
|
||||
def adjust_state_dict_for_prefetch(self):
|
||||
"""
|
||||
Adjusts the state dict for prefetching. Natively, this will adjust all of the iters yielded keys in
|
||||
`self.dl_state_dict` by a factor of `num_processes - 1`, however if a custom correction is needed, this can be
|
||||
overridden.
|
||||
|
||||
This should modify `self.dl_state_dict` directly
|
||||
"""
|
||||
# The state dict will be off by a factor of `n-1` batch too many during DDP,
|
||||
# so we need to adjust it here
|
||||
if PartialState().distributed_type != DistributedType.NO:
|
||||
factor = PartialState().num_processes - 1
|
||||
if self.dl_state_dict["_sampler_iter_yielded"] > 0:
|
||||
self.dl_state_dict["_sampler_iter_yielded"] -= factor
|
||||
if self.dl_state_dict["_num_yielded"] > 0:
|
||||
self.dl_state_dict["_num_yielded"] -= factor
|
||||
if self.dl_state_dict["_index_sampler_state"] is not None:
|
||||
if (
|
||||
"samples_yielded" in self.dl_state_dict["_index_sampler_state"]
|
||||
and self.dl_state_dict["_index_sampler_state"]["samples_yielded"] > 0
|
||||
):
|
||||
self.dl_state_dict["_index_sampler_state"]["samples_yielded"] -= self.batch_size * factor
|
||||
|
||||
def _update_state_dict(self):
|
||||
# The state_dict of the underlying base_dataloader may be ahead of what is currently being yielded.
|
||||
@ -453,6 +482,10 @@ class DataLoaderAdapter:
|
||||
# _update_state_dict is called to snapshot the state_dict that would properly recover the DataLoaderAdapter.
|
||||
if hasattr(self.base_dataloader, "state_dict"):
|
||||
self.dl_state_dict = self.base_dataloader.state_dict()
|
||||
# Potentially modify the state_dict to adjust for prefetching
|
||||
self.adjust_state_dict_for_prefetch()
|
||||
# Then tag if we are at the end of the dataloader
|
||||
self.dl_state_dict["_iterator_finished"] = self.end_of_dataloader
|
||||
|
||||
|
||||
class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
|
||||
@ -539,6 +572,7 @@ class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
|
||||
current_batch = next_batch
|
||||
except StopIteration:
|
||||
self.end_of_dataloader = True
|
||||
self._update_state_dict()
|
||||
if batch_index >= self.skip_batches:
|
||||
yield current_batch
|
||||
break
|
||||
@ -809,6 +843,7 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
|
||||
|
||||
if stop_iteration:
|
||||
self.end_of_dataloader = True
|
||||
self._update_state_dict()
|
||||
self.remainder = observed_batch_size
|
||||
if batch_index >= self.skip_batches:
|
||||
yield batch
|
||||
@ -1146,7 +1181,7 @@ class SkipBatchSampler(BatchSampler):
|
||||
return len(self.batch_sampler) - self.skip_batches
|
||||
|
||||
|
||||
class SkipDataLoader(DataLoaderAdapter):
|
||||
class SkipDataLoader(DataLoaderAdapter, DataLoaderStateMixin):
|
||||
"""
|
||||
Subclass of a PyTorch `DataLoader` that will skip the first batches.
|
||||
|
||||
@ -1164,12 +1199,15 @@ class SkipDataLoader(DataLoaderAdapter):
|
||||
def __init__(self, dataset, skip_batches=0, use_stateful_dataloader=False, **kwargs):
|
||||
super().__init__(dataset, use_stateful_dataloader=use_stateful_dataloader, **kwargs)
|
||||
self.skip_batches = skip_batches
|
||||
self.gradient_state = GradientState()
|
||||
|
||||
def __iter__(self):
|
||||
self.begin()
|
||||
for index, batch in enumerate(self.base_dataloader.__iter__()):
|
||||
if index >= self.skip_batches:
|
||||
self._update_state_dict()
|
||||
yield batch
|
||||
self.end()
|
||||
|
||||
|
||||
def skip_first_batches(dataloader, num_batches=0):
|
||||
|
@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import tempfile
|
||||
import warnings
|
||||
from typing import List
|
||||
from unittest.mock import Mock
|
||||
@ -77,12 +77,17 @@ def create_accelerator(even_batches=True):
|
||||
return accelerator
|
||||
|
||||
|
||||
def create_dataloader(accelerator: Accelerator, dataset_size: int, batch_size: int, iterable: bool = False):
|
||||
def create_dataloader(
|
||||
accelerator: Accelerator, dataset_size: int, batch_size: int, iterable: bool = False, shuffle: bool = False
|
||||
):
|
||||
"""
|
||||
Create a simple DataLoader to use during the test cases
|
||||
"""
|
||||
values = torch.as_tensor(range(dataset_size))
|
||||
if shuffle:
|
||||
values = values[torch.randperm(values.size(0))]
|
||||
if iterable:
|
||||
dataset = DummyIterableDataset(torch.as_tensor(range(dataset_size)))
|
||||
dataset = DummyIterableDataset(values)
|
||||
else:
|
||||
dataset = TensorDataset(torch.as_tensor(range(dataset_size)))
|
||||
|
||||
@ -260,6 +265,81 @@ def test_data_loader(data_loader, accelerator):
|
||||
), "Not all the dataset elements have been iterated in an epoch due to duplication of samples across processes."
|
||||
|
||||
|
||||
def test_stateful_dataloader(accelerator):
|
||||
"""
|
||||
Tests that a stateful dataloader can be iterated over, saved after a few batches using `load_state_dict`, and then
|
||||
resumed from the saved state.
|
||||
|
||||
The result should be the same as the rest of the data that iterated over after saving.
|
||||
"""
|
||||
old_dataloader_config = accelerator.dataloader_config
|
||||
try:
|
||||
accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
|
||||
prepared_dl = create_dataloader(
|
||||
accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=True, shuffle=True
|
||||
)
|
||||
untrained_batches = []
|
||||
# Calculate what step that will be
|
||||
total_batches = 32 * accelerator.num_processes // (4 * accelerator.num_processes)
|
||||
last_batch_num = total_batches - 1
|
||||
for step, batch in enumerate(prepared_dl):
|
||||
# Step just before
|
||||
if step == last_batch_num - 1:
|
||||
state_dict = prepared_dl.state_dict()
|
||||
if step >= last_batch_num:
|
||||
# Otherwise grab the "unseen" batches
|
||||
untrained_batches.append(batch)
|
||||
not_skipped_batches = accelerator.gather(untrained_batches)
|
||||
prepared_dl.load_state_dict(state_dict)
|
||||
resumed_batches = []
|
||||
for batch in prepared_dl:
|
||||
resumed_batches.append(batch)
|
||||
resumed_batches = accelerator.gather(resumed_batches)
|
||||
for b1, b2 in zip(not_skipped_batches, resumed_batches):
|
||||
for v1, v2 in zip(b1, b2):
|
||||
assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"
|
||||
finally:
|
||||
accelerator.dataloader_config = old_dataloader_config
|
||||
|
||||
|
||||
def test_stateful_dataloader_save_state(accelerator):
|
||||
"""
|
||||
Tests that a stateful dataloader can be iterated over, saved after a few batches using `Accelerator.save_state`,
|
||||
and then resumed from the saved state.
|
||||
|
||||
The result should be the same as the rest of the data that iterated over after saving.
|
||||
"""
|
||||
old_dataloader_config = accelerator.dataloader_config
|
||||
try:
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
accelerator.dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
|
||||
prepared_dl = create_dataloader(
|
||||
accelerator, dataset_size=32 * accelerator.num_processes, batch_size=4, iterable=True, shuffle=True
|
||||
)
|
||||
untrained_batches = []
|
||||
# Calculate what step that will be
|
||||
total_batches = 32 * accelerator.num_processes // (4 * accelerator.num_processes)
|
||||
last_batch_num = total_batches - 1
|
||||
for step, batch in enumerate(prepared_dl):
|
||||
# Step just before
|
||||
if step == last_batch_num - 1:
|
||||
accelerator.save_state(tmpdir)
|
||||
if step >= last_batch_num:
|
||||
# Otherwise grab the "unseen" batches
|
||||
untrained_batches.append(batch)
|
||||
not_skipped_batches = accelerator.gather(untrained_batches)
|
||||
accelerator.load_state(tmpdir)
|
||||
resumed_batches = []
|
||||
for batch in prepared_dl:
|
||||
resumed_batches.append(batch)
|
||||
resumed_batches = accelerator.gather(resumed_batches)
|
||||
for b1, b2 in zip(not_skipped_batches, resumed_batches):
|
||||
for v1, v2 in zip(b1, b2):
|
||||
assert torch.equal(v1, v2), f"Batch {b1} and {b2} are not equal"
|
||||
finally:
|
||||
accelerator.dataloader_config = old_dataloader_config
|
||||
|
||||
|
||||
def main():
|
||||
accelerator = create_accelerator()
|
||||
torch.manual_seed(accelerator.process_index)
|
||||
@ -306,6 +386,8 @@ def main():
|
||||
sampler = BatchSampler(RandomSampler(dataset), batch_size=BATCH_SIZE, drop_last=False)
|
||||
loader = DataLoader(dataset, sampler=sampler, batch_size=None, collate_fn=default_collate, num_workers=NUM_WORKERS)
|
||||
test_data_loader(loader, accelerator)
|
||||
test_stateful_dataloader(accelerator)
|
||||
test_stateful_dataloader_save_state(accelerator)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
@ -27,7 +27,6 @@ from torch.utils.data import DataLoader, TensorDataset
|
||||
|
||||
from accelerate import DistributedType, infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch
|
||||
from accelerate.accelerator import Accelerator
|
||||
from accelerate.data_loader import skip_first_batches
|
||||
from accelerate.state import GradientState, PartialState
|
||||
from accelerate.test_utils import (
|
||||
require_bnb,
|
||||
@ -682,11 +681,12 @@ class AcceleratorTester(AccelerateTestCase):
|
||||
Test that saving and loading a model with a stateful dataloader returns the same model,
|
||||
and that the dataloader's iterator is restored properly."""
|
||||
set_seed(42)
|
||||
n_train_batches = 64 # Use enough batches to ensure we can get partial iterations on large compute
|
||||
dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=True)
|
||||
accelerator = Accelerator(dataloader_config=dataloader_config)
|
||||
|
||||
model, optimizer, scheduler, train_dl, valid_dl = create_components(tied_weights)
|
||||
train_dl, valid_dl = create_dataloaders_for_test(num_workers=num_workers)
|
||||
train_dl, valid_dl = create_dataloaders_for_test(n_train_batches=n_train_batches, num_workers=num_workers)
|
||||
model = ModelForTest()
|
||||
|
||||
(
|
||||
@ -703,77 +703,53 @@ class AcceleratorTester(AccelerateTestCase):
|
||||
# Perform 3 training iterations to ensure the dataloader's iterator is advanced
|
||||
num_batches_to_skip = 3
|
||||
model.train()
|
||||
for step, batch in enumerate(prepared_train_dl):
|
||||
x, y = batch
|
||||
x.to(accelerator.device)
|
||||
y.to(accelerator.device)
|
||||
with accelerator.accumulate(prepared_model):
|
||||
untrained_batches = []
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
for step, batch in enumerate(prepared_train_dl):
|
||||
x, y = batch
|
||||
outputs = prepared_model(x)
|
||||
loss = torch.nn.functional.mse_loss(outputs, y)
|
||||
accelerator.backward(loss)
|
||||
prepared_optimizer.step()
|
||||
prepared_scheduler.step()
|
||||
prepared_optimizer.zero_grad()
|
||||
if step == num_batches_to_skip - 1:
|
||||
state_dict = prepared_train_dl.state_dict()
|
||||
# When breaking out without fully going through the iterator, must call end() to unregister this iterator from gradient state.
|
||||
# TODO: Maybe this could be done automatically?
|
||||
prepared_train_dl.end()
|
||||
break
|
||||
if step == num_batches_to_skip - 1:
|
||||
# Save the state once we've gone through a few batches
|
||||
accelerator.save_state(f"{tmpdirname}/state", safe_serialization=use_safetensors)
|
||||
if step >= num_batches_to_skip:
|
||||
untrained_batches.append(batch)
|
||||
|
||||
assert accelerator.gradient_state.active_dataloader is None
|
||||
not_skipped_batches = accelerator.gather(untrained_batches)
|
||||
# We then unwrap the trained model
|
||||
unwrapped_model = accelerator.unwrap_model(prepared_model)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
# Save model for later use
|
||||
accelerator.save_model(model, tmpdirname, safe_serialization=use_safetensors)
|
||||
original_linear1 = unwrapped_model.linear1.weight.clone()
|
||||
original_batchnorm = unwrapped_model.batchnorm.weight.clone()
|
||||
original_linear2 = unwrapped_model.linear2.weight.clone()
|
||||
|
||||
# Starting from where we left off, train this model to the end of the DataLoader
|
||||
prepared_train_dl = skip_first_batches(prepared_train_dl, num_batches_to_skip)
|
||||
batches_seen_with_original_dl = 0
|
||||
for batch in prepared_train_dl:
|
||||
x, y = batch
|
||||
x.to(accelerator.device)
|
||||
y.to(accelerator.device)
|
||||
with accelerator.accumulate(prepared_model):
|
||||
outputs = prepared_model(x)
|
||||
loss = torch.nn.functional.mse_loss(outputs, y)
|
||||
accelerator.backward(loss)
|
||||
prepared_optimizer.step()
|
||||
prepared_scheduler.step()
|
||||
prepared_optimizer.zero_grad()
|
||||
batches_seen_with_original_dl += 1
|
||||
|
||||
original_linear1 = prepared_model.linear1.weight.clone()
|
||||
original_batchnorm = prepared_model.batchnorm.weight.clone()
|
||||
original_linear2 = prepared_model.linear2.weight.clone()
|
||||
|
||||
# Load the model and state dict
|
||||
load_checkpoint_in_model(model, tmpdirname)
|
||||
stateful_train_dl, _ = create_dataloaders_for_test(num_workers=num_workers)
|
||||
prepared_stateful_train_dl = accelerator.prepare_data_loader(stateful_train_dl)
|
||||
prepared_stateful_train_dl.load_state_dict(state_dict)
|
||||
# Resume the state
|
||||
accelerator.load_state(f"{tmpdirname}/state")
|
||||
|
||||
# Train this to the end of the DataLoader
|
||||
batches_seen_with_loaded_dl = 0
|
||||
for batch in prepared_stateful_train_dl:
|
||||
for batch in prepared_train_dl:
|
||||
x, y = batch
|
||||
x.to(accelerator.device)
|
||||
y.to(accelerator.device)
|
||||
with accelerator.accumulate(prepared_model):
|
||||
outputs = prepared_model(x)
|
||||
loss = torch.nn.functional.mse_loss(outputs, y)
|
||||
accelerator.backward(loss)
|
||||
prepared_optimizer.step()
|
||||
prepared_scheduler.step()
|
||||
prepared_optimizer.zero_grad()
|
||||
outputs = prepared_model(x)
|
||||
loss = torch.nn.functional.mse_loss(outputs, y)
|
||||
accelerator.backward(loss)
|
||||
prepared_optimizer.step()
|
||||
prepared_scheduler.step()
|
||||
prepared_optimizer.zero_grad()
|
||||
batches_seen_with_loaded_dl += 1
|
||||
|
||||
new_linear1 = prepared_model.linear1.weight
|
||||
new_batchnorm = prepared_model.batchnorm.weight
|
||||
new_linear2 = prepared_model.linear2.weight
|
||||
unwrapped_model_2 = accelerator.unwrap_model(prepared_model)
|
||||
|
||||
new_linear1 = unwrapped_model_2.linear1.weight
|
||||
new_batchnorm = unwrapped_model_2.batchnorm.weight
|
||||
new_linear2 = unwrapped_model_2.linear2.weight
|
||||
|
||||
# Assert equalities
|
||||
assert batches_seen_with_original_dl == batches_seen_with_loaded_dl
|
||||
assert batches_seen_with_loaded_dl == len(not_skipped_batches)
|
||||
assert torch.allclose(original_linear1, new_linear1)
|
||||
assert torch.allclose(original_batchnorm, new_batchnorm)
|
||||
assert torch.allclose(original_linear2, new_linear2)
|
||||
|
@ -45,11 +45,13 @@ from accelerate.utils import write_basic_config
|
||||
|
||||
EXCLUDE_EXAMPLES = [
|
||||
"cross_validation.py",
|
||||
"checkpointing.py",
|
||||
"gradient_accumulation.py",
|
||||
"local_sgd.py",
|
||||
"multi_process_metrics.py",
|
||||
"memory.py",
|
||||
"schedule_free.py",
|
||||
"tracking.py",
|
||||
"automatic_gradient_accumulation.py",
|
||||
"fsdp_with_peak_mem_tracking.py",
|
||||
"deepspeed_with_config_support.py",
|
||||
|
Reference in New Issue
Block a user