Compare commits

...

1 Commits

2 changed files with 19 additions and 0 deletions

View File

@ -4768,6 +4768,13 @@ class Trainer:
"`non_blocking` is enabled but `dataloader_pin_memory` is not. For the best performance, it's recommended to enable both."
)
dataloader_config.non_blocking = non_blocking
use_stateful_dataloader = accelerator_config.pop("use_stateful_dataloader")
if use_stateful_dataloader:
if not is_accelerate_available("0.34.0"):
raise ImportError(
"`use_stateful_dataloader` is only supported in accelerate v0.34.0 and above. Please upgrade accelerate to use this feature."
)
dataloader_config.use_stateful_dataloader = use_stateful_dataloader
# this would have been updated above, no need for it anymore
accelerator_config.pop("gradient_accumulation_kwargs")

View File

@ -1233,6 +1233,9 @@ class AcceleratorConfig:
training results are fully reproducible using a different sampling technique. While seed-to-seed results
may differ, on average the differences are negligible when using multiple different seeds to compare. Should
also be run with [`~utils.set_seed`] for the best results.
use_stateful_dataloader (`bool`, *optional*, defaults to `False`):
If set to `True`, the dataloader prepared by the Accelerator will be backed by [torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
Doing so makes resuming from a checkpoint faster as the `DataLoader` is configured with a resumable state. Requires `torchdata>=0.8.0` and `accelerate>=0.34.0`.
gradient_accumulation_kwargs (`dict`, *optional*):
Additional kwargs to configure gradient accumulation, see [`accelerate.utils.GradientAccumulationPlugin`].
Any of the following (optional) keys are acceptable:
@ -1288,6 +1291,15 @@ class AcceleratorConfig:
"multiple different seeds to compare. Should also be run with [`~utils.set_seed`] for the best results."
},
)
use_stateful_dataloader: bool = field(
default=False,
metadata={
"help": "If set to `True`, the dataloader prepared by the Accelerator will be backed by "
"[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). "
"Doing so makes resuming from a checkpoint faster as the `DataLoader` is configured with a resumable state. "
"Requires `torchdata>=0.8.0` and `accelerate>=0.34.0`."
},
)
non_blocking: Optional[bool] = field(
default=False,