Decouple prepare_data_loader() from Accelerator (#3047)
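This commit lets `prepare_data_loader()` (and the dataloader classes it builds) pull process and device defaults from `PartialState` instead of requiring a fully configured `Accelerator`/`AcceleratorState`. The first hunks below touch the data loader module; the later ones update the dataloader tests to drop the now-unnecessary `Accelerator()` setup and add a decoupled equivalence test. As a rough sketch of the workflow this enables (illustrative values, not code from the diff):

```python
# Minimal sketch of the decoupled workflow: no Accelerator() is constructed.
import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import PartialState
from accelerate.data_loader import prepare_data_loader

state = PartialState()  # joins (or sets up) the distributed environment; single-process otherwise

dataset = TensorDataset(torch.arange(64, dtype=torch.float32))
dataloader = DataLoader(dataset, batch_size=4)

# num_processes / process_index now default to the PartialState values.
dataloader = prepare_data_loader(dataloader, device=state.device, put_on_device=True)

for (batch,) in dataloader:
    pass  # each batch has already been moved to state.device
```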
@@ -20,7 +20,7 @@ import torch
 from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler
 
 from .logging import get_logger
-from .state import AcceleratorState, DistributedType, GradientState, PartialState, is_torch_xla_available
+from .state import DistributedType, GradientState, PartialState, is_torch_xla_available
 from .utils import (
     RNGType,
     broadcast,
@@ -720,7 +720,7 @@ class DataLoaderDispatcher(DataLoaderAdapter, DataLoaderStateMixin):
             torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
 
         self.gradient_state = GradientState()
-        self.state = AcceleratorState()
+        self.state = PartialState()
         self._drop_last = _drop_last
         self._non_blocking = _non_blocking
         self.skip_batches = skip_batches
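`DataLoaderDispatcher` now holds a `PartialState` instead of an `AcceleratorState`; the partial state already carries the process and device bookkeeping the dispatcher needs, without the training-specific configuration that `AcceleratorState` adds on top. A minimal sketch of those attributes (printed values are illustrative):

```python
# Sketch: the process/device information DataLoaderDispatcher can read from PartialState.
from accelerate import PartialState

state = PartialState()  # shared-state singleton; cheap to construct anywhere in the process
print(state.device)            # e.g. cuda:0, or cpu on a CPU-only run
print(state.num_processes)     # world size (1 outside a distributed launch)
print(state.process_index)     # global rank of this process
print(state.distributed_type)  # e.g. DistributedType.NO or DistributedType.MULTI_GPU
```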
@@ -937,10 +937,9 @@ def prepare_data_loader(
         device (`torch.device`):
             The target device for the returned `DataLoader`.
         num_processes (`int`, *optional*):
-            The number of processes running concurrently. Will default to the value given by
-            [`~state.AcceleratorState`].
+            The number of processes running concurrently. Will default to the value given by [`~state.PartialState`].
         process_index (`int`, *optional*):
-            The index of the current process. Will default to the value given by [`~state.AcceleratorState`].
+            The index of the current process. Will default to the value given by [`~state.PartialState`].
         split_batches (`bool`, *optional*, defaults to `False`):
             Whether the resulting `DataLoader` should split the batches of the original data loader across devices or
             yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing of
@@ -1009,8 +1008,8 @@ def prepare_data_loader(
 
     if dispatch_batches and not put_on_device:
         raise ValueError("Using `dispatch_batches=True` requires `put_on_device=True`.")
-    # Grab defaults from AcceleratorState
-    state = AcceleratorState()
+    # Grab defaults from PartialState
+    state = PartialState()
     if num_processes is None:
         num_processes = state.num_processes
     if process_index is None:
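Per the updated docstring and the hunk above, `num_processes` and `process_index` only fall back to the `PartialState` values when left as `None`; explicit arguments still take precedence. A hedged sketch (the explicit values below are arbitrary and only meaningful under a matching launch configuration):

```python
# Sketch: default sharding from PartialState vs. explicitly overriding the process layout.
from torch.utils.data import DataLoader

from accelerate.data_loader import prepare_data_loader

dl = DataLoader(list(range(64)), batch_size=8)

# Defaults: shard according to PartialState's world size and rank.
sharded = prepare_data_loader(dl)

# Explicit override: shard as if there were 4 processes and this were rank 2.
sharded_rank2 = prepare_data_loader(dl, num_processes=4, process_index=2)
```

The remaining hunks update the test suite to exercise this decoupled path.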
@@ -20,7 +20,7 @@ import torch
 from parameterized import parameterized
 from torch.utils.data import BatchSampler, DataLoader, IterableDataset
 
-from accelerate import Accelerator
+from accelerate import Accelerator, PartialState
 from accelerate.data_loader import (
     BatchSamplerShard,
     DataLoaderDispatcher,
@@ -29,11 +29,12 @@ from accelerate.data_loader import (
     IterableDatasetShard,
     SkipBatchSampler,
     SkipDataLoader,
     prepare_data_loader,
     skip_first_batches,
 )
+from accelerate.state import GradientState
 from accelerate.test_utils.testing import require_torchdata_stateful_dataloader
 from accelerate.utils import is_torchdata_stateful_dataloader_available
 from accelerate.utils.dataclasses import DataLoaderConfiguration
 
 
 if is_torchdata_stateful_dataloader_available():
@@ -401,9 +402,8 @@ class DataLoaderTester(unittest.TestCase):
 
     def test_iterable_dataset_using_none_batch_size(self):
         dataset = SimpleIterableDataset(100)
-        accelerator = Accelerator()
         dataloader = DataLoader(dataset, batch_size=None)
-        dataloader = accelerator.prepare(dataloader)
+        dataloader = prepare_data_loader(dataloader)
         for d in dataloader:
             assert isinstance(d, torch.Tensor)
 
@@ -417,7 +417,6 @@ class DataLoaderTester(unittest.TestCase):
         `DataLoaderAdapter`'s parent classes are dynamically constructed, assert that subclasses of DataLoaderAdapter
         are instances of DataLoader and DataLoaderStateMixin.
         """
-        Accelerator()
         skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2)
         dl_shard = DataLoaderShard(range(16), batch_size=4)
         dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4)
@@ -454,7 +453,6 @@ class DataLoaderTester(unittest.TestCase):
             assert dataloader.end_of_dataloader == (idx == 3)
 
     def test_end_of_dataloader_dispatcher(self):
-        Accelerator()
         dataloader = DataLoaderDispatcher(range(16), batch_size=4)
         for idx, _ in enumerate(dataloader):
             assert dataloader.end_of_dataloader == (idx == 3)
@@ -492,7 +490,6 @@ class StatefulDataLoaderTester(unittest.TestCase):
 
     @require_torchdata_stateful_dataloader
     def test_end_of_dataloader_dispatcher(self):
-        Accelerator()
         dataloader = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True)
         assert isinstance(dataloader, StatefulDataLoader)
         for idx, _ in enumerate(dataloader):
@@ -535,8 +532,6 @@ class StatefulDataLoaderTester(unittest.TestCase):
         """
         Test that saving a stateful dataloader's state, then loading it back, gives the same results.
         """
-        dataloader_config = DataLoaderConfiguration(use_stateful_dataloader=True)
-        Accelerator(dataloader_config=dataloader_config)
         dataset = list(range(16))
         dataloader = DataLoaderDispatcher(dataset, batch_size=4, use_stateful_dataloader=True, num_workers=num_workers)
 
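The `DataLoaderConfiguration`/`Accelerator` setup is dropped here because the test passes `use_stateful_dataloader=True` straight to the dataloader class. Both routes still exist; a brief sketch contrasting them, assuming torchdata's `StatefulDataLoader` is installed:

```python
# Sketch: two ways to opt into stateful dataloaders (requires torchdata to be installed).
from accelerate import Accelerator
from accelerate.data_loader import DataLoaderDispatcher
from accelerate.utils.dataclasses import DataLoaderConfiguration

# 1) Decoupled: pass the flag directly to the accelerate dataloader class.
dispatcher = DataLoaderDispatcher(list(range(16)), batch_size=4, use_stateful_dataloader=True)

# 2) Accelerator-driven: configure once, then accelerator.prepare() builds stateful dataloaders.
accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(use_stateful_dataloader=True))
```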
@@ -565,7 +560,6 @@ class StatefulDataLoaderTester(unittest.TestCase):
         `DataLoaderAdapter`'s parent classes are dynamically constructed, assert that if use_stateful_dataloader=True,
         subclasses of DataLoaderAdapter are instances of StatefulDataLoader and DataLoaderStateMixin.
         """
-        Accelerator()
         skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2, use_stateful_dataloader=True)
         dl_shard = DataLoaderShard(range(16), batch_size=4, use_stateful_dataloader=True)
         dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True)
@@ -689,3 +683,119 @@ class StatefulDataLoaderTester(unittest.TestCase):
             assert expected_batch_results[1] == dl_results[1]
 
         assert accelerator.gradient_state.active_dataloader is None
+
+    @parameterized.expand([0, 2], name_func=parameterized_custom_name_func)
+    @require_torchdata_stateful_dataloader
+    def test_decoupled_stateful_dataloader_adapter_equivalent_to_torchdata_stateful_dataloader(self, num_workers):
+        """
+        Assert that `state_dict()` and `load_state_dict()` for derived subclasses of `DataLoaderAdapter` produce
+        the same behavior as `state_dict()` and `load_state_dict()` for `StatefulDataLoader` when *not* using
+        Accelerator (and instead using the decoupled `PartialState` workflow).
+        """
+        dataset = list(range(64))
+
+        # Set the seed for reproducibility
+        def g():
+            return torch.Generator().manual_seed(42)
+
+        state = PartialState()
+        stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g())
+        skip_dl = SkipDataLoader(
+            dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
+        )
+        dl_shard = DataLoaderShard(
+            dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
+        )
+        dl_dispatcher = DataLoaderDispatcher(
+            dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
+        )
+
+        dataloaders_under_test = [skip_dl, dl_shard, dl_dispatcher]
+
+        num_batches_to_skip = 8
+
+        def get_first_n_batches(dl, n, device):
+            """
+            Iterate over the first `n` batches of a dataloader then break, returning the batches in a list.
+            """
+            batches = []
+            for idx, batch in enumerate(dl):
+                if idx == n - 1:
+                    if hasattr(dl, "end"):
+                        dl.end()
+                    break
+                batches.append(batch.to(device))
+            return batches
+
+        # Iterate over all of the dataloaders identically, expect the same values
+        expected_batches = get_first_n_batches(stateful_dl, num_batches_to_skip, state.device)
+        batches_from_dataloaders = [
+            get_first_n_batches(dl, num_batches_to_skip, state.device) for dl in dataloaders_under_test
+        ]
+
+        for dl_batches in batches_from_dataloaders:
+            for expected, actual in zip(expected_batches, dl_batches):
+                assert torch.allclose(expected, actual)
+
+        # The adapters should all produce the same state_dict as the reference stateful dataloader
+        expected_state_dict = stateful_dl.state_dict()
+        skip_dl_state_dict = skip_dl.state_dict()
+        dl_shard_state_dict = dl_shard.state_dict()
+        dl_dispatcher_state_dict = dl_dispatcher.state_dict()
+
+        assert expected_state_dict == skip_dl_state_dict
+        assert expected_state_dict == dl_shard_state_dict
+        assert expected_state_dict == dl_dispatcher_state_dict
+
+        # Load the state dict into new dataloaders
+        manual_skip_dl = SkipDataLoader(
+            dataset,
+            batch_size=4,
+            num_workers=num_workers,
+            generator=g(),
+            skip_batches=num_batches_to_skip,
+            use_stateful_dataloader=True,
+        )
+        loaded_stateful_dl = StatefulDataLoader(dataset, batch_size=4, num_workers=num_workers, generator=g())
+        loaded_stateful_dl.load_state_dict(expected_state_dict)
+        loaded_skip_dl = SkipDataLoader(
+            dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
+        )
+        loaded_skip_dl.load_state_dict(expected_state_dict)
+        loaded_dl_shard = DataLoaderShard(
+            dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
+        )
+        loaded_dl_shard.load_state_dict(expected_state_dict)
+        loaded_dl_dispatcher = DataLoaderDispatcher(
+            dataset, batch_size=4, num_workers=num_workers, generator=g(), use_stateful_dataloader=True
+        )
+        loaded_dl_dispatcher.load_state_dict(expected_state_dict)
+
+        # Continue the iteration, expecting identical behavior across the board
+        def get_all_batches(dl, device):
+            """
+            Iterate over all batches of a dataloader, returning (batches, num_batches_yielded)
+            """
+            batches = []
+            num_batches_yielded = 0
+            for batch in dl:
+                batches.append(batch.to(device))
+                num_batches_yielded += 1
+            return (batches, num_batches_yielded)
+
+        expected_batch_results = get_all_batches(loaded_stateful_dl, state.device)
+        dataloader_batch_results = [
+            get_all_batches(dl, state.device)
+            for dl in [manual_skip_dl, loaded_skip_dl, loaded_dl_shard, loaded_dl_dispatcher]
+        ]
+        for dl_results in dataloader_batch_results:
+            for expected, actual in zip(expected_batches, dl_batches):
+                assert torch.allclose(expected[0], actual[0])
+            assert expected_batch_results[1] == dl_results[1]
+
+        # Using the decoupled (`PartialState`) workflow, GradientState should be automatically initialized (with
+        # default parameters) by `DataLoaderDispatcher`
+        assert GradientState._shared_state != {}, "GradientState should already be initialized!"
+
+        gradient_state = GradientState()
+        assert gradient_state.active_dataloader is None
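The closing assertions of the new test rely on `GradientState` being a shared-state singleton that `DataLoaderDispatcher` initializes with default parameters when no `Accelerator` has been created. A standalone sketch of that behavior, mirroring the assertions above:

```python
# Sketch: in the decoupled workflow, constructing a DataLoaderDispatcher is enough to
# populate the GradientState singleton with defaults.
from accelerate.data_loader import DataLoaderDispatcher
from accelerate.state import GradientState

dispatcher = DataLoaderDispatcher(range(16), batch_size=4)

assert GradientState._shared_state != {}          # shared state was populated by the dispatcher
assert GradientState().active_dataloader is None  # nothing is actively being iterated yet
```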