mirror of
https://github.com/huggingface/accelerate.git
synced 2025-11-13 21:59:16 +08:00
Compare commits
5 Commits
v0.32.0
...
dataloader
| Author | SHA1 | Date | |
|---|---|---|---|
| | 83617de6b0 | | |
| | 3310c53fa8 | | |
| | b56a2583f2 | | |
| | 634e84f519 | | |
| | 2581a2e331 | | |
@@ -467,11 +467,15 @@ class DataLoaderDispatcher(DataLoader):
|
||||
def _fetch_batches(self, iterator):
|
||||
batches, batch = None, None
|
||||
# On process 0, we gather the batch to dispatch.
|
||||
print("Starting to dispatch")
|
||||
if self.state.process_index == 0:
|
||||
print("In process zero")
|
||||
try:
|
||||
if self.split_batches:
|
||||
# One batch of the main iterator is dispatched and split.
|
||||
print("Getting next batch")
|
||||
batch = next(iterator)
|
||||
print(f'Batch: {batch}')
|
||||
else:
|
||||
# num_processes batches of the main iterator are concatenated then dispatched and split.
|
||||
# We add the batches one by one so we have the remainder available when drop_last=False.
|
||||
@@ -482,12 +486,18 @@ class DataLoaderDispatcher(DataLoader):
|
||||
# In both cases, we need to get the structure of the batch that we will broadcast on other
|
||||
# processes to initialize the tensors with the right shape.
|
||||
# data_structure, stop_iteration
|
||||
print("getting batch info")
|
||||
batch_info = [get_data_structure(batch), False]
|
||||
print(f'Batch info: {batch_info}')
|
||||
except StopIteration:
|
||||
print("Hit stop iteration")
|
||||
batch_info = [None, True]
|
||||
else:
|
||||
batch_info = [None, self._stop_iteration]
|
||||
# This is inplace, so after this instruction, every process has the same `batch_info` as process 0.
|
||||
print(f'Batch info on process {AcceleratorState().process_index}: {batch_info}')
|
||||
from accelerate.utils import wait_for_everyone
|
||||
wait_for_everyone()
|
||||
broadcast_object_list(batch_info)
|
||||
self._stop_iteration = batch_info[1]
|
||||
if self._stop_iteration:
|
||||
|
||||
Reference in New Issue
Block a user