Use cuda tensors for allgather (#1548)

This commit is contained in:
Olatunji Ruwase
2021-11-10 17:27:00 -08:00
committed by GitHub
parent af443f63f4
commit bd3ebddf36

View File

@ -889,7 +889,7 @@ class Init(InsertPostInitMethodToModuleSubClasses):
if self.use_all_gather_base:
# try the _all_gather_base on PyTorch master branch
handle = dist._all_gather_base(flat_tensor,
param.ds_tensor,
param.ds_tensor.cuda(),
group=self.ds_process_group,
async_op=async_op)
else: