diff --git a/src/accelerate/state.py b/src/accelerate/state.py index 9353c422..e7542638 100644 --- a/src/accelerate/state.py +++ b/src/accelerate/state.py @@ -400,7 +400,7 @@ class PartialState: DistributedType.DEEPSPEED, DistributedType.FSDP, ): - torch.distributed.barrier() + torch.distributed.barrier(device_ids=[self.process_index]) elif self.distributed_type == DistributedType.XLA: xm.rendezvous("accelerate.utils.wait_for_everyone")