Compare commits

...

3 Commits

SHA1         Message      Date
4b5b838185   other modif  2025-06-12 17:37:50 +00:00
7729f44040   some modif   2025-06-12 17:23:23 +00:00
65a3bc0beb   test         2025-06-12 14:53:39 +00:00
3 changed files with 11 additions and 0 deletions

View File

@@ -648,6 +648,10 @@ class Accelerator:
     def dispatch_batches(self):
         return self.dataloader_config.dispatch_batches
 
+    @property
+    def cp(self):
+        return self.dataloader_config.cp
+
     @property
     def even_batches(self):
         return self.dataloader_config.even_batches
@@ -2410,6 +2414,7 @@ class Accelerator:
             non_blocking=self.non_blocking,
             use_stateful_dataloader=self.use_stateful_dataloader,
             torch_device_mesh=device_mesh,
+            cp=self.cp,
         )
         self._dataloaders.append(prepared_data_loader)
         return prepared_data_loader
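
Taken together with the cp field added to DataLoaderConfiguration in the last file of this compare, the Accelerator side of the patch would presumably be driven as in the sketch below. The cp flag itself comes from this diff, while Accelerator(dataloader_config=...) and prepare() are existing accelerate API, so treat the snippet as an assumption about how the pieces fit rather than documented usage.

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from accelerate import Accelerator
    from accelerate.utils import DataLoaderConfiguration

    # Enable the new flag through the dataloader config (field added in this compare)
    # and let Accelerator.prepare() forward it as the cp= keyword shown above.
    accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(cp=True))

    dataset = TensorDataset(torch.arange(32).float())
    dataloader = DataLoader(dataset, batch_size=4)

    print(accelerator.cp)                       # new property, mirrors the config field
    prepared = accelerator.prepare(dataloader)  # internally passes cp=self.cp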

View File

@@ -1007,6 +1007,7 @@ def prepare_data_loader(
     non_blocking: bool = False,
     use_stateful_dataloader: bool = False,
     torch_device_mesh=None,
+    cp=False,
 ) -> DataLoader:
     """
     Wraps a PyTorch `DataLoader` to generate batches for one of the processes only.
@@ -1137,6 +1138,10 @@ def prepare_data_loader(
         process_index = process_index // submesh_tp_size
         num_processes = submesh_fsdp_size * submesh_dp_size
 
+    if cp:
+        process_index = 0
+        num_processes = 1
+
     # Sanity check
     if split_batches:
         if dataloader.batch_size is not None:
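
The new if cp: branch forces the sharding math to behave as if there were a single process, so every rank would receive the full, identical stream of batches instead of a one-per-num_processes slice. A standalone sketch of that arithmetic follows; the shard_indices helper is hypothetical and only illustrates the effect, it is not the library's internal sampler.

    # Hypothetical helper mimicking per-process batch sharding: rank process_index
    # keeps every num_processes-th batch index.
    def shard_indices(batch_indices, num_processes, process_index):
        return batch_indices[process_index::num_processes]

    batches = list(range(8))  # pretend these are 8 batch indices

    # Usual data parallelism: each of 4 ranks gets a disjoint quarter of the batches.
    print([shard_indices(batches, 4, rank) for rank in range(4)])
    # [[0, 4], [1, 5], [2, 6], [3, 7]]

    # With cp=True the code above pins process_index=0 and num_processes=1,
    # so every rank ends up with the same full sequence.
    print(shard_indices(batches, 1, 0))
    # [0, 1, 2, 3, 4, 5, 6, 7]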

View File

@@ -815,6 +815,7 @@ class DataLoaderConfiguration:
             " underlying dataset is an `IterableDataset`, `False` otherwise."
         },
     )
+    cp: bool = field(default=False)
    even_batches: bool = field(
         default=True,
         metadata={
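
The field is also reachable without the Accelerator wrapper by passing it straight to prepare_data_loader from the second file of this compare. Below is a minimal single-process sketch; only the cp= keyword is introduced by this patch, the import path and the rest of the call are the function's existing API.

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from accelerate.data_loader import prepare_data_loader

    dataset = TensorDataset(torch.arange(16).float())
    dataloader = DataLoader(dataset, batch_size=4)

    # cp=True makes the sharding step treat the run as a single process,
    # so the prepared loader yields the same batches on every rank.
    prepared = prepare_data_loader(dataloader, cp=True)
    for (batch,) in prepared:
        print(batch)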