mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-24 15:44:58 +08:00
In Gloo backend use ring reduction by default (#9309)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9309
Ring reduction is faster when you're dealing with a small number of processes. At around the 16-process mark, the halving/doubling algorithm becomes faster, so it is still used from that size upward.
Reviewed By: apaszke
Differential Revision: D8785364
fbshipit-source-id: 4a03326266e473026d943787186e149d0cc489f0
This commit is contained in:
committed by
Facebook Github Bot
parent
00b4b4703e
commit
aeccec755d
@ -1,8 +1,10 @@
|
||||
#include "ProcessGroupGloo.hpp"
|
||||
|
||||
#include <gloo/allreduce_halving_doubling.h>
|
||||
#include <gloo/allreduce_ring_chunked.h>
|
||||
#include <gloo/broadcast_one_to_all.h>
|
||||
#include <gloo/cuda_allreduce_halving_doubling.h>
|
||||
#include <gloo/cuda_allreduce_ring_chunked.h>
|
||||
#include <gloo/cuda_broadcast_one_to_all.h>
|
||||
#include <gloo/rendezvous/context.h>
|
||||
#include <gloo/transport/tcp/device.h>
|
||||
@ -320,22 +322,40 @@ void ProcessGroupGloo::createAllreduce(AlgorithmEntry& entry) {
|
||||
auto& context = contexts_[0];
|
||||
|
||||
if (backend == at::kCPU) {
|
||||
entry.algorithm = std::unique_ptr<::gloo::Algorithm>(
|
||||
new ::gloo::AllreduceHalvingDoubling<T>(
|
||||
context,
|
||||
getDataPointers<T>(entry.src),
|
||||
entry.src[0].numel(),
|
||||
reductionFunction<T>(key.reduceOp)));
|
||||
if (getSize() < 16) {
|
||||
entry.algorithm = std::unique_ptr<::gloo::Algorithm>(
|
||||
new ::gloo::AllreduceRingChunked<T>(
|
||||
context,
|
||||
getDataPointers<T>(entry.src),
|
||||
entry.src[0].numel(),
|
||||
reductionFunction<T>(key.reduceOp)));
|
||||
} else {
|
||||
entry.algorithm = std::unique_ptr<::gloo::Algorithm>(
|
||||
new ::gloo::AllreduceHalvingDoubling<T>(
|
||||
context,
|
||||
getDataPointers<T>(entry.src),
|
||||
entry.src[0].numel(),
|
||||
reductionFunction<T>(key.reduceOp)));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (backend == at::kCUDA) {
|
||||
entry.algorithm = std::unique_ptr<::gloo::Algorithm>(
|
||||
new ::gloo::CudaAllreduceHalvingDoubling<T>(
|
||||
context,
|
||||
getDataPointers<T>(entry.src),
|
||||
entry.src[0].numel(),
|
||||
getStreamVector(entry)));
|
||||
if (getSize() < 16) {
|
||||
entry.algorithm = std::unique_ptr<::gloo::Algorithm>(
|
||||
new ::gloo::CudaAllreduceRingChunked<T>(
|
||||
context,
|
||||
getDataPointers<T>(entry.src),
|
||||
entry.src[0].numel(),
|
||||
getStreamVector(entry)));
|
||||
} else {
|
||||
entry.algorithm = std::unique_ptr<::gloo::Algorithm>(
|
||||
new ::gloo::CudaAllreduceHalvingDoubling<T>(
|
||||
context,
|
||||
getDataPointers<T>(entry.src),
|
||||
entry.src[0].numel(),
|
||||
getStreamVector(entry)));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user