Decouple prepare_data_loader() from Accelerator (#3047)

Sidd Karamcheti
2024-08-26 07:19:59 -07:00
committed by GitHub
parent 726140cad2
commit 2789933938
2 changed files with 126 additions and 17 deletions
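In practical terms, the change means a `DataLoader` can now be prepared without ever constructing an `Accelerator`. A minimal sketch of the decoupled workflow (illustrative only; the dataset and batch size here are arbitrary):

from torch.utils.data import DataLoader

from accelerate import PartialState
from accelerate.data_loader import prepare_data_loader

# A PartialState is enough to describe the process topology; no Accelerator is required.
state = PartialState()

dataloader = DataLoader(list(range(64)), batch_size=8)

# num_processes and process_index now default to the values held by PartialState.
prepared = prepare_data_loader(dataloader)

for batch in prepared:
    pass  # on a multi-process launch each rank sees only its own shard of batches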

View File

@@ -20,7 +20,7 @@ import torch
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler
from .logging import get_logger
from .state import AcceleratorState, DistributedType, GradientState, PartialState, is_torch_xla_available
from .state import DistributedType, GradientState, PartialState, is_torch_xla_available
from .utils import (
RNGType,
broadcast,
@@ -720,7 +720,7 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
self.gradient_state = GradientState()
self.state = AcceleratorState()
self.state = PartialState()
self._drop_last = _drop_last
self._non_blocking = _non_blocking
self.skip_batches = skip_batches
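With `self.state` now backed by `PartialState`, a `DataLoaderDispatcher` can be built in a process that never creates an `Accelerator`. A small sketch mirroring the simplified tests further below:

from accelerate import PartialState
from accelerate.data_loader import DataLoaderDispatcher

PartialState()  # initializes process and device info without an Accelerator

dispatcher = DataLoaderDispatcher(range(16), batch_size=4)
for idx, batch in enumerate(dispatcher):
    # on a single process this yields tensors of 4 consecutive integers
    assert dispatcher.end_of_dataloader == (idx == 3)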
@@ -937,10 +937,9 @@ def prepare_data_loader(
device (`torch.device`):
The target device for the returned `DataLoader`.
num_processes (`int`, *optional*):
The number of processes running concurrently. Will default to the value given by
[`~state.AcceleratorState`].
The number of processes running concurrently. Will default to the value given by [`~state.PartialState`].
process_index (`int`, *optional*):
The index of the current process. Will default to the value given by [`~state.AcceleratorState`].
The index of the current process. Will default to the value given by [`~state.PartialState`].
split_batches (`bool`, *optional*, defaults to `False`):
Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing of
@@ -1009,8 +1008,8 @@ def prepare_data_loader(
if dispatch_batches and not put_on_device:
raise ValueError("Using `dispatch_batches=True` requires `put_on_device=True`.")
# Grab defaults from AcceleratorState
state = AcceleratorState()
# Grab defaults from PartialState
state = PartialState()
if num_processes is None:
num_processes = state.num_processes
if process_index is None:
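The default resolution that follows is straightforward; roughly (a sketch, not the full function body):

from accelerate import PartialState

# PartialState is a lightweight singleton: repeated construction returns the shared state.
state = PartialState()

num_processes = state.num_processes  # e.g. 1 when not launched with a distributed launcher
process_index = state.process_index  # rank of the current process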

View File

@@ -20,7 +20,7 @@ import torch
from parameterized import parameterized
from torch.utils.data import BatchSampler, DataLoader, IterableDataset
from accelerate import Accelerator
from accelerate import Accelerator, PartialState
from accelerate.data_loader import (
BatchSamplerShard,
DataLoaderDispatcher,
@@ -29,11 +29,12 @@ from accelerate.data_loader import (
IterableDatasetShard,
SkipBatchSampler,
SkipDataLoader,
prepare_data_loader,
skip_first_batches,
)
from accelerate.state import GradientState
from accelerate.test_utils.testing import require_torchdata_stateful_dataloader
from accelerate.utils import is_torchdata_stateful_dataloader_available
from accelerate.utils.dataclasses import DataLoaderConfiguration
if is_torchdata_stateful_dataloader_available():
@@ -401,9 +402,8 @@ class DataLoaderTester(unittest.TestCase):
def test_iterable_dataset_using_none_batch_size(self):
dataset = SimpleIterableDataset(100)
accelerator = Accelerator()
dataloader = DataLoader(dataset, batch_size=None)
dataloader = accelerator.prepare(dataloader)
dataloader = prepare_data_loader(dataloader)
for d in dataloader:
assert isinstance(d, torch.Tensor)
@@ -417,7 +417,6 @@ class DataLoaderTester(unittest.TestCase):
`DataLoaderAdapter`'s parent classes are dynamically constructed, assert that subclasses of DataLoaderAdapter
are instances of DataLoader and DataLoaderStateMixin.
"""
Accelerator()
skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2)
dl_shard = DataLoaderShard(range(16), batch_size=4)
dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4)
@@ -454,7 +453,6 @@ class DataLoaderTester(unittest.TestCase):
assert dataloader.end_of_dataloader == (idx == 3)
def test_end_of_dataloader_dispatcher(self):
Accelerator()
dataloader = DataLoaderDispatcher(range(16), batch_size=4)
for idx, _ in enumerate(dataloader):
assert dataloader.end_of_dataloader == (idx == 3)
@@ -492,7 +490,6 @@ class StatefulDataLoaderTester(unittest.TestCase):
@require_torchdata_stateful_dataloader
def test_end_of_dataloader_dispatcher(self):
Accelerator()
dataloader = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True)
assert isinstance(dataloader, StatefulDataLoader)
for idx, _ in enumerate(dataloader):
@@ -535,8 +532,6 @@ class StatefulDataLoaderTester(unittest.TestCase):
"""
Test that saving a stateful dataloader's state, then loading it back, gives the same results.
"""
dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
Accelerator(dataloader_config=dataloader_config)
dataset = list(range(16))
dataloader = DataLoaderDispatcher(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers)
@@ -565,7 +560,6 @@ class StatefulDataLoaderTester(unittest.TestCase):
`DataLoaderAdapter`'s parent classes are dynamically constructed, assert that if use_stateful_dataloader=True,
subclasses of DataLoaderAdapter are instances of StatefulDataLoader and DataLoaderStateMixin.
"""
Accelerator()
skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2, use_stateful_dataloader=True)
dl_shard = DataLoaderShard(range(16), batch_size=4, use_stateful_dataloader=True)
dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True)
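The docstring above refers to `DataLoaderAdapter`'s dynamically constructed base classes; the property being asserted can be illustrated with a few standalone checks (a sketch that assumes torchdata's `StatefulDataLoader` is installed):

from torch.utils.data import DataLoader
from torchdata.stateful_dataloader import StatefulDataLoader

from accelerate.data_loader import DataLoaderShard, DataLoaderStateMixin

# With use_stateful_dataloader=True the adapter derives from StatefulDataLoader,
# otherwise it derives from the regular torch DataLoader.
stateful_shard = DataLoaderShard(range(16), batch_size=4, use_stateful_dataloader=True)
plain_shard = DataLoaderShard(range(16), batch_size=4)

assert isinstance(stateful_shard, StatefulDataLoader)
assert isinstance(stateful_shard, DataLoaderStateMixin)
assert isinstance(plain_shard, DataLoader)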
@@ -689,3 +683,119 @@ class StatefulDataLoaderTester(unittest.TestCase):
assert expected_batch_results[1] == dl_results[1]
assert accelerator.gradient_state.active_dataloader is None
@parameterized.expand([0, 2], name_func=parameterized_custom_name_func)
@require_torchdata_stateful_dataloader
def test_decoupled_stateful_dataloader_adapter_equivalent_to_torchdata_stateful_dataloader(self, num_workers):
"""
Assert that `state_dict()` and `load_state_dict()` for derived subclasses of `DataLoaderAdapter` produce
the same behavior as `state_dict()` and `load_state_dict()` for `StatefulDataLoader` when *not* using
Accelerator (and instead using the decoupled `PartialState` workflow).
"""
dataset = list(range(64))
# Set the seed for reproducibility
def g():
return torch.Generator().manual_seed(42)
state = PartialState()
stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g())
skip_dl = SkipDataLoader(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)
dl_shard = DataLoaderShard(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)
dl_dispatcher = DataLoaderDispatcher(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)
dataloaders_under_test = [skip_dl, dl_shard, dl_dispatcher]
num_batches_to_skip = 8
def get_first_n_batches(dl, n, device):
"""
Iterate over the first `n` batches of a dataloader then break, returning the batches in a list.
"""
batches = []
for idx, batch in enumerate(dl):
if idx == n - 1:
if hasattr(dl, "end"):
dl.end()
break
batches.append(batch.to(device))
return batches
# Iterate over all of the dataloaders identically, expect the same values
expected_batches = get_first_n_batches(stateful_dl, num_batches_to_skip, state.device)
batches_from_dataloaders = [
get_first_n_batches(dl, num_batches_to_skip, state.device) for dl in dataloaders_under_test
]
for dl_batches in batches_from_dataloaders:
for expected, actual in zip(expected_batches, dl_batches):
assert torch.allclose(expected, actual)
# The adapters should all produce the same state_dict as the reference stateful dataloader
expected_state_dict = stateful_dl.state_dict()
skip_dl_state_dict = skip_dl.state_dict()
dl_shard_state_dict = dl_shard.state_dict()
dl_dispatcher_state_dict = dl_dispatcher.state_dict()
assert expected_state_dict == skip_dl_state_dict
assert expected_state_dict == dl_shard_state_dict
assert expected_state_dict == dl_dispatcher_state_dict
# Load the state dict into new dataloaders
manual_skip_dl = SkipDataLoader(
dataset,
batch_size=4,
num_workers=num_workers,
generator=g(),
skip_batches=num_batches_to_skip,
use_stateful_dataloader=True,
)
loaded_stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g())
loaded_stateful_dl.load_state_dict(expected_state_dict)
loaded_skip_dl = SkipDataLoader(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)
loaded_skip_dl.load_state_dict(expected_state_dict)
loaded_dl_shard = DataLoaderShard(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)
loaded_dl_shard.load_state_dict(expected_state_dict)
loaded_dl_dispatcher = DataLoaderDispatcher(
dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
)
loaded_dl_dispatcher.load_state_dict(expected_state_dict)
# Continue the iteration, expecting identical behavior across the board
def get_all_batches(dl, device):
"""
Iterate over all batches of a dataloader, returning (batches, num_batches_yielded)
"""
batches = []
num_batches_yielded = 0
for batch in dl:
batches.append(batch.to(device))
num_batches_yielded += 1
return (batches, num_batches_yielded)
expected_batch_results = get_all_batches(loaded_stateful_dl, state.device)
dataloader_batch_results = [
get_all_batches(dl, state.device)
for dl in [manual_skip_dl, loaded_skip_dl, loaded_dl_shard, loaded_dl_dispatcher]
]
for dl_results in dataloader_batch_results:
for expected, actual in zip(expected_batch_results[0], dl_results[0]):
assert torch.allclose(expected, actual)
assert expected_batch_results[1] == dl_results[1]
# Using the decoupled (`PartialState`) workflow, GradientState should be automatically initialized (with
# default parameters) by `DataLoaderDispatcher`
assert GradientState._shared_state != {}, "GradientState should already be initialized!"
gradient_state = GradientState()
assert gradient_state.active_dataloader is None
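Reduced to its essentials, the workflow the new test validates looks roughly like this (a hedged sketch: it assumes torchdata is installed so `use_stateful_dataloader=True` is available, and it runs on a single process):

from accelerate import PartialState
from accelerate.data_loader import DataLoaderShard

PartialState()  # decoupled workflow: no Accelerator anywhere
dataset = list(range(64))

dl = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True)
it = iter(dl)
consumed = [next(it) for _ in range(8)]  # consume the first 8 batches
checkpoint = dl.state_dict()             # StatefulDataLoader-style snapshot

resumed = DataLoaderShard(dataset, batch_size=4, use_stateful_dataloader=True)
resumed.load_state_dict(checkpoint)
remaining = list(resumed)                # picks up where the checkpoint left off
print(remaining[0])                      # expected: tensor([32, 33, 34, 35])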