fix: specify device_ids in torch.distributed.barrier for PartialState (#3744)

This commit is contained in:
Quentin Gallouédec
2025-08-26 05:05:33 -07:00
committed by GitHub
parent 5dd3d0b690
commit c4460e33ef

View File

@ -400,7 +400,7 @@ class PartialState:
DistributedType.DEEPSPEED,
DistributedType.FSDP,
):
torch.distributed.barrier()
torch.distributed.barrier(device_ids=[self.process_index])
elif self.distributed_type == DistributedType.XLA:
xm.rendezvous("accelerate.utils.wait_for_everyone")