Compare commits

...

5 Commits

SHA1        Message            Date
83617de6b0  More prints        2023-04-25 10:49:55 -04:00
3310c53fa8  Wait for everyone  2023-04-25 10:46:47 -04:00
b56a2583f2  Wait for everyone  2023-04-25 10:44:56 -04:00
634e84f519  More debug         2023-04-25 10:39:09 -04:00
2581a2e331  Print              2023-04-25 10:32:50 -04:00

@@ -467,11 +467,15 @@ class DataLoaderDispatcher(DataLoader):
     def _fetch_batches(self, iterator):
         batches, batch = None, None
         # On process 0, we gather the batch to dispatch.
+        print("Starting to dispatch")
         if self.state.process_index == 0:
+            print("In process zero")
             try:
                 if self.split_batches:
                     # One batch of the main iterator is dispatched and split.
+                    print("Getting next batch")
                     batch = next(iterator)
+                    print(f'Batch: {batch}')
                 else:
                     # num_processes batches of the main iterator are concatenated then dispatched and split.
                     # We add the batches one by one so we have the remainder available when drop_last=False.
@@ -482,12 +486,18 @@ class DataLoaderDispatcher(DataLoader):
                 # In both cases, we need to get the structure of the batch that we will broadcast on other
                 # processes to initialize the tensors with the right shape.
                 # data_structure, stop_iteration
+                print("getting batch info")
                 batch_info = [get_data_structure(batch), False]
+                print(f'Batch info: {batch_info}')
             except StopIteration:
+                print("Hit stop iteration")
                 batch_info = [None, True]
         else:
             batch_info = [None, self._stop_iteration]
         # This is inplace, so after this instruction, every process has the same `batch_info` as process 0.
+        print(f'Batch info on process {AcceleratorState().process_index}: {batch_info}')
+        from accelerate.utils import wait_for_everyone
+        wait_for_everyone()
         broadcast_object_list(batch_info)
         self._stop_iteration = batch_info[1]
         if self._stop_iteration:
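
The added prints and the extra wait_for_everyone() barrier all sit around the hand-off where process 0 builds `batch_info` and `broadcast_object_list` copies it to every other rank. Below is a minimal sketch of that pattern in isolation, not the library's own code: the script name, the stand-in payload dict, and the two-process launch command are illustrative assumptions. Run it under `accelerate launch` with more than one process to see every rank print the same list after the broadcast.

# debug_broadcast.py (hypothetical name) -- sketch of the pattern probed above:
# process 0 builds `batch_info`, the other ranks start with a placeholder, and
# `broadcast_object_list` copies process 0's values to every rank in place.
# Run with e.g.: accelerate launch --num_processes 2 debug_broadcast.py
from accelerate import Accelerator
from accelerate.utils import broadcast_object_list, wait_for_everyone

accelerator = Accelerator()

if accelerator.process_index == 0:
    # Stand-in payload; in `_fetch_batches` this is [get_data_structure(batch), False].
    batch_info = [{"example_key": "example structure"}, False]
else:
    # Mirrors the `else` branch above: non-zero ranks only hold a placeholder.
    batch_info = [None, False]

wait_for_everyone()                 # the barrier added in the "Wait for everyone" commits
broadcast_object_list(batch_info)   # in place: every rank now holds process 0's list

print(f"process {accelerator.process_index}: {batch_info}")

Note that `broadcast_object_list` is a collective: if any rank never reaches the call (because it raised or stalled earlier), the ranks that did reach it block there, which is presumably the hang these per-process prints and the extra barrier are meant to surface.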