Mirror of https://github.com/huggingface/accelerate.git (synced 2025-11-14 14:14:32 +08:00)

Compare commits: v0.32.0...dataloader (5 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 83617de6b0 | |
| | 3310c53fa8 | |
| | b56a2583f2 | |
| | 634e84f519 | |
| | 2581a2e331 | |
```diff
@@ -467,11 +467,15 @@ class DataLoaderDispatcher(DataLoader):
     def _fetch_batches(self, iterator):
         batches, batch = None, None
         # On process 0, we gather the batch to dispatch.
+        print("Starting to dispatch")
         if self.state.process_index == 0:
+            print("In process zero")
             try:
                 if self.split_batches:
                     # One batch of the main iterator is dispatched and split.
+                    print("Getting next batch")
                     batch = next(iterator)
+                    print(f'Batch: {batch}')
                 else:
                     # num_processes batches of the main iterator are concatenated then dispatched and split.
                     # We add the batches one by one so we have the remainder available when drop_last=False.
```
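The `split_batches` branch that this first hunk instruments boils down to the following. This is only a minimal sketch of the fetch step (the `StopIteration` handling appears in the next hunk), with `iterator`, `split_batches`, and `num_processes` standing in for the dataloader's own attributes:

```python
# Hypothetical helper mirroring the happy path of DataLoaderDispatcher._fetch_batches.
from accelerate.utils import concatenate

def fetch_batch(iterator, split_batches, num_processes):
    if split_batches:
        # One batch of the main iterator is dispatched and later split across processes.
        return next(iterator)
    # Otherwise, num_processes batches are collected one by one (so the remainder
    # is still available when drop_last=False) and concatenated before dispatch.
    batches = [next(iterator) for _ in range(num_processes)]
    return concatenate(batches, dim=0)
```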
```diff
@@ -482,12 +486,18 @@ class DataLoaderDispatcher(DataLoader):
                 # In both cases, we need to get the structure of the batch that we will broadcast on other
                 # processes to initialize the tensors with the right shape.
                 # data_structure, stop_iteration
+                print("getting batch info")
                 batch_info = [get_data_structure(batch), False]
+                print(f'Batch info: {batch_info}')
             except StopIteration:
+                print("Hit stop iteration")
                 batch_info = [None, True]
         else:
             batch_info = [None, self._stop_iteration]
         # This is inplace, so after this instruction, every process has the same `batch_info` as process 0.
+        print(f'Batch info on process {AcceleratorState().process_index}: {batch_info}')
+        from accelerate.utils import wait_for_everyone
+        wait_for_everyone()
         broadcast_object_list(batch_info)
         self._stop_iteration = batch_info[1]
         if self._stop_iteration:
```
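The prints around the broadcast are easiest to read against a self-contained sketch of the same handshake. The snippet below is an illustration under stated assumptions (the dummy `batch` dict stands in for whatever the wrapped iterator yields) and is meant to be run under `accelerate launch --num_processes 2 script.py`:

```python
import torch
from accelerate import Accelerator
from accelerate.utils import broadcast_object_list, get_data_structure, wait_for_everyone

accelerator = Accelerator()  # initializes the distributed state under `accelerate launch`

if accelerator.process_index == 0:
    batch = {"input_ids": torch.zeros(8, 16)}        # stand-in for next(iterator)
    batch_info = [get_data_structure(batch), False]  # [data_structure, stop_iteration]
else:
    batch_info = [None, False]                       # placeholder on non-zero ranks

wait_for_everyone()                # the barrier these commits add before the broadcast
broadcast_object_list(batch_info)  # in place: every rank now holds process 0's batch_info
print(f"Batch info on process {accelerator.process_index}: {batch_info}")
```

Before the broadcast only process 0 has a non-`None` structure; afterwards every rank should print the same `batch_info`, which is what the added per-process `print` is checking. (`DataLoaderDispatcher` is the wrapper `accelerator.prepare` returns for a dataloader when `dispatch_batches=True`.)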