mirror of
https://github.com/huggingface/accelerate.git
synced 2025-11-20 09:34:28 +08:00
Compare commits
3 Commits
v1.10.0
...
cp-dataloa
| Author | SHA1 | Date | |
|---|---|---|---|
| 4b5b838185 | |||
| 7729f44040 | |||
| 65a3bc0beb |
@ -648,6 +648,10 @@ class Accelerator:
|
||||
def dispatch_batches(self):
|
||||
return self.dataloader_config.dispatch_batches
|
||||
|
||||
@property
|
||||
def cp(self):
|
||||
return self.dataloader_config.cp
|
||||
|
||||
@property
|
||||
def even_batches(self):
|
||||
return self.dataloader_config.even_batches
|
||||
@ -2410,6 +2414,7 @@ class Accelerator:
|
||||
non_blocking=self.non_blocking,
|
||||
use_stateful_dataloader=self.use_stateful_dataloader,
|
||||
torch_device_mesh=device_mesh,
|
||||
cp=self.cp,
|
||||
)
|
||||
self._dataloaders.append(prepared_data_loader)
|
||||
return prepared_data_loader
|
||||
|
||||
@ -1007,6 +1007,7 @@ def prepare_data_loader(
|
||||
non_blocking: bool = False,
|
||||
use_stateful_dataloader: bool = False,
|
||||
torch_device_mesh=None,
|
||||
cp=False,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
|
||||
@ -1137,6 +1138,10 @@ def prepare_data_loader(
|
||||
process_index = process_index // submesh_tp_size
|
||||
num_processes = submesh_fsdp_size * submesh_dp_size
|
||||
|
||||
if cp:
|
||||
process_index = 0
|
||||
num_processes = 1
|
||||
|
||||
# Sanity check
|
||||
if split_batches:
|
||||
if dataloader.batch_size is not None:
|
||||
|
||||
@ -815,6 +815,7 @@ class DataLoaderConfiguration:
|
||||
" underlying dataset is an `IterableDataset`, `False` otherwise."
|
||||
},
|
||||
)
|
||||
cp: bool = field(default=False)
|
||||
even_batches: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
|
||||
Reference in New Issue
Block a user