mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-03 07:24:58 +08:00
Update on "[c10d] Integrate ncclGather into PT"
NCCL 2.28.3 now supports `ncclGather`; this PR integrates it. (release note: https://docs.nvidia.com/deeplearning/nccl/release-notes/rel_2-28-3.html#rel_2-28-3) [ghstack-poisoned]
This commit is contained in:
@ -1092,8 +1092,8 @@ void gather(
|
||||
void* recv_ptr = nullptr;
|
||||
at::Tensor flat; // keep alive until after NCCL call
|
||||
if (cur_rank == root) {
|
||||
TORCH_CHECK(
|
||||
(int)outputs.size() == numranks,
|
||||
TORCH_CHECK_VALUE(
|
||||
static_cast<int>(outputs.size()) == numranks,
|
||||
"root must provide inputs.size()==numranks");
|
||||
// Allocate one flat buffer [world_size * count]
|
||||
flat = at::empty({numranks * count}, inputs.options());
|
||||
|
||||
Reference in New Issue
Block a user