Compare commits

...

11 Commits

3 changed files with 35 additions and 12 deletions

View File

@ -217,6 +217,10 @@ class Accelerator:
If set to `True`, in cases where the total batch size across all processes does not exactly divide the
dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
all workers.
use_seedable_sampler (`bool`, *optional*, defaults to `False`):
Whether or not use a fully seedable random sampler ([`~data_loader.SeedableRandomSampler`]). Comes at a
cost of potentially different performances due to different shuffling algorithms, but will ensure the
training results are fully reproducible.
step_scheduler_with_optimizer (`bool`, *optional`, defaults to `True`):
Set `True` if the learning rate scheduler is stepped at the same time as the optimizer, `False` if only
done under certain circumstances (at the end of each epoch, for instance).
@ -262,6 +266,7 @@ class Accelerator:
gradient_accumulation_plugin: GradientAccumulationPlugin | None = None,
dispatch_batches: bool | None = None,
even_batches: bool = True,
use_seedable_sampler: bool = False,
step_scheduler_with_optimizer: bool = True,
kwargs_handlers: list[KwargsHandler] | None = None,
dynamo_backend: DynamoBackend | str | None = None,
@ -417,6 +422,7 @@ class Accelerator:
self.split_batches = split_batches
self.dispatch_batches = dispatch_batches
self.even_batches = even_batches
self.use_seedable_sampler = use_seedable_sampler
self.step_scheduler_with_optimizer = step_scheduler_with_optimizer
# Mixed precision attributes
@ -1811,7 +1817,10 @@ class Accelerator:
return tuple(result)
def prepare_data_loader(
self, data_loader: torch.utils.data.DataLoader, device_placement=None, slice_fn_for_dispatch=None
self,
data_loader: torch.utils.data.DataLoader,
device_placement=None,
slice_fn_for_dispatch=None,
):
"""
Prepares a PyTorch DataLoader for training in any distributed setup. It is recommended to use
@ -1828,6 +1837,7 @@ class Accelerator:
[`~utils.slice_tensors`]. This argument is used only when `dispatch_batches` is set to `True` and will
be ignored otherwise.
Example:
```python
@ -1857,6 +1867,7 @@ class Accelerator:
dispatch_batches=self.dispatch_batches,
even_batches=self.even_batches,
slice_fn_for_dispatch=slice_fn_for_dispatch,
use_seedable_sampler=self.use_seedable_sampler,
)
self._dataloaders.append(prepared_data_loader)
return prepared_data_loader

View File

@ -744,6 +744,7 @@ def prepare_data_loader(
dispatch_batches: Optional[bool] = None,
even_batches: bool = True,
slice_fn_for_dispatch: Optional[Callable] = None,
use_seedable_sampler: bool = False,
) -> DataLoader:
"""
Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
@ -797,6 +798,10 @@ def prepare_data_loader(
If passed, this function will be used to slice tensors across `num_processes`. Will default to
[`~utils.slice_tensors`]. This argument is used only when `dispatch_batches` is set to `True` and will be
ignored otherwise.
use_seedable_sampler (`bool`, *optional*, defaults to `False`):
Whether to use the [`~data_loader.SeedableRandomSampler`] instead of a `RandomSampler` for better
reproducability. Comes at a cost of potentially different performances due to different shuffling
algorithms.
Returns:
`torch.utils.data.dataloader.DataLoader`: A new data loader that will yield the portion of the batches
@ -840,7 +845,7 @@ def prepare_data_loader(
sampler = getattr(dataloader.sampler, "sampler", None)
else:
sampler = getattr(dataloader.batch_sampler, "sampler", None)
if isinstance(sampler, RandomSampler):
if isinstance(sampler, RandomSampler) and use_seedable_sampler:
# When iterating through the dataloader during distributed processes
# we want to ensure that on each process we are iterating through the same
# samples in the same order if a seed is set. This requires a tweak
@ -899,7 +904,7 @@ def prepare_data_loader(
kwargs["batch_size"] = (
dataloader.batch_size // num_processes if split_batches and not dispatch_batches else dataloader.batch_size
)
if isinstance(sampler, SeedableRandomSampler):
if isinstance(sampler, SeedableRandomSampler) and use_seedable_sampler:
if sampler_is_batch_sampler:
dataloader.sampler.sampler = sampler
else:

View File

@ -335,19 +335,22 @@ def custom_sampler_check():
), "Custom sampler was changed after calling `prepare_data_loader`"
def mock_training(length, batch_size, generator):
def mock_training(length, batch_size, generator, use_seedable_sampler=False):
set_seed(42)
generator.manual_seed(42)
train_set = RegressionDataset(length=length, seed=42)
# The SeedableRandomSampler is needed during distributed setups
# for full reproducability across processes with the `DataLoader`
sampler = SeedableRandomSampler(
generator=generator,
data_source=train_set,
num_samples=len(train_set),
)
train_dl = DataLoader(train_set, batch_size=batch_size, sampler=sampler)
if use_seedable_sampler:
# The SeedableRandomSampler is needed during distributed setups
# for full reproducability across processes with the `DataLoader`
sampler = SeedableRandomSampler(
generator=generator,
data_source=train_set,
num_samples=len(train_set),
)
train_dl = DataLoader(train_set, batch_size=batch_size, sampler=sampler)
else:
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for epoch in range(3):
@ -370,6 +373,10 @@ def training_check():
assert are_the_same_tensors(old_model.a), "Did not obtain the same model on both processes."
assert are_the_same_tensors(old_model.b), "Did not obtain the same model on both processes."
train_set, old_model = mock_training(length, batch_size * state.num_processes, generator, True)
assert are_the_same_tensors(old_model.a), "Did not obtain the same model on both processes."
assert are_the_same_tensors(old_model.b), "Did not obtain the same model on both processes."
accelerator = Accelerator()
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
model = RegressionModel()