[nccl-pg] Store PG global rank information in tracing logs (#115730)

Storing the list of global ranks associated with each PG allows us to correlate traces across different ranks.
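
For readers of the trace output: the (globalRankStart, globalRankStride) pair is a compact encoding of a PG's membership. When the group's global ranks form an arithmetic sequence, the i-th member's global rank is globalRankStart + i * globalRankStride; a stride of -1 means the ranks are not evenly strided, and both fields are recorded as -1 at call sites that have no PG information. The snippet below is an illustrative, standalone sketch (not part of this change, and the helper name is invented) that mirrors the detection logic added to the ProcessGroupNCCL constructor:

    #include <cstdint>
    #include <iostream>
    #include <utility>
    #include <vector>

    // Illustrative helper (not part of this commit): derive the (start, stride)
    // pair that the tracing logs store for a process group. A stride of -1
    // signals that the global ranks are not an arithmetic sequence.
    std::pair<int, int> encodeGlobalRanks(const std::vector<uint64_t>& ranks) {
      if (ranks.empty()) {
        return {0, 1}; // empty list == default PG spanning all ranks
      }
      int start = static_cast<int>(ranks[0]);
      if (ranks.size() == 1) {
        return {start, 0}; // single-member group: stride is meaningless
      }
      int stride = static_cast<int>(ranks[1]) - static_cast<int>(ranks[0]);
      for (size_t i = 0; i < ranks.size(); ++i) {
        if (static_cast<int>(ranks[i]) != start + static_cast<int>(i) * stride) {
          return {start, -1}; // not evenly strided
        }
      }
      return {start, stride};
    }

    int main() {
      // A PG built from global ranks {2, 4, 6, 8} encodes as start=2, stride=2.
      auto enc = encodeGlobalRanks({2, 4, 6, 8});
      std::cout << enc.first << " " << enc.second << "\n"; // prints "2 2"
    }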

Test Plan:

OSS CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115730
Approved by: https://github.com/fduwjj
Author: Pavan Balaji
Date: 2023-12-14 00:59:12 +00:00
Committed by: PyTorch MergeBot
Parent: b38e14c12a
Commit: ffc826bf10

6 changed files with 112 additions and 1 deletion

View File

@@ -15,6 +15,8 @@ ParamCommsDebugInfo::ParamCommsDebugInfo(
at::ScalarType dType,
std::vector<int64_t> inSplitSizes,
std::vector<int64_t> outSplitSizes,
int globalRankStart,
int globalRankStride,
int worldSize)
: rank_(rank),
worldSize_(worldSize),
@@ -23,6 +25,8 @@ ParamCommsDebugInfo::ParamCommsDebugInfo(
outMessageNelems_(outNelems),
dType_(dType),
inputSplitSizes_(std::move(inSplitSizes)),
outputSplitSizes_(std::move(outSplitSizes)) {}
outputSplitSizes_(std::move(outSplitSizes)),
globalRankStart_(globalRankStart),
globalRankStride_(globalRankStride) {}
} // namespace torch

View File

@@ -20,6 +20,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
at::ScalarType dType,
std::vector<int64_t> inSplitSizes,
std::vector<int64_t> outSplitSizes,
int globalRankStart,
int globalRankStride,
int worldSize);
~ParamCommsDebugInfo() override = default;
@@ -32,6 +34,14 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
return worldSize_;
}
int getGlobalRankStart() const {
return globalRankStart_;
}
int getGlobalRankStride() const {
return globalRankStride_;
}
const std::string getColumnName() const {
return columnName_;
}
@@ -65,6 +75,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
at::ScalarType dType_ = at::kByte;
std::vector<int64_t> inputSplitSizes_;
std::vector<int64_t> outputSplitSizes_;
int globalRankStart_;
int globalRankStride_;
};
#define RECORD_PARAM_COMMS( \
@@ -77,6 +89,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
dType, \
inSplitSizes, \
outSplitSizes, \
globalRankStart, \
globalRankStride, \
worldSize) \
auto paramCommsInfo = std::make_shared<torch::ParamCommsDebugInfo>( \
rank, \
@@ -86,6 +100,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
dType, \
inSplitSizes, \
outSplitSizes, \
globalRankStart, \
globalRankStride, \
worldSize); \
c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \
std::initializer_list<const c10::IValue> paramList = { \
@@ -95,6 +111,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
colName, \
inSplitSizes, \
outSplitSizes, \
globalRankStart, \
globalRankStride, \
worldSize}; \
c10::ArrayRef<const c10::IValue> paramInputs(paramList); \
RECORD_FUNCTION(at::kParamCommsCallName, paramInputs);
@@ -111,6 +129,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
dType, \
inSplitSizes, \
outSplitSizes, \
globalRankStart, \
globalRankStride, \
worldSize) \
auto paramCommsInfo = std::make_shared<torch::ParamCommsDebugInfo>( \
rank, \
@@ -120,6 +140,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
dType, \
inSplitSizes, \
outSplitSizes, \
globalRankStart, \
globalRankStride, \
worldSize); \
c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \
std::initializer_list<const c10::IValue> paramList = { \
@@ -130,6 +152,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
colName, \
inSplitSizes, \
outSplitSizes, \
globalRankStart, \
globalRankStride, \
worldSize}; \
c10::ArrayRef<const c10::IValue> paramInputs(paramList); \
RECORD_FUNCTION_WITH_INPUTS_OUTPUTS( \

View File

@@ -641,6 +641,8 @@ bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) {
at::kByte, // dType
std::vector<int64_t>(), // inSplitSizes
std::vector<int64_t>(), // outSplitSizes
-1,
-1,
static_cast<int>(devices_.size())); // worldSize
synchronizeInternal(timeout);
// Always return true, because abort API is not implemented.
@@ -806,6 +808,38 @@ ProcessGroupNCCL::ProcessGroupNCCL(
<< ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << ncclTraceBufferSize_
<< ", NCCL_DEBUG: " << nccl_debug << ", ID=" << this->getID();
if (options_->global_ranks_in_group.empty()) {
this->globalRankStart = 0;
} else {
this->globalRankStart = options_->global_ranks_in_group[0];
}
if (options_->global_ranks_in_group.empty()) {
this->globalRankStride = 1;
} else if (options_->global_ranks_in_group.size() == 1) {
this->globalRankStride = 0;
} else {
bool ranksAreStrided = true;
int startRank = options_->global_ranks_in_group[0];
int stride =
options_->global_ranks_in_group[1] - options_->global_ranks_in_group[0];
for (std::vector<uint64_t>::size_type i = 0;
i < options_->global_ranks_in_group.size();
i++) {
if (options_->global_ranks_in_group[i] != startRank + i * stride) {
ranksAreStrided = false;
break;
}
}
if (ranksAreStrided) {
this->globalRankStride = options_->global_ranks_in_group[1] -
options_->global_ranks_in_group[0];
} else {
this->globalRankStride = -1;
}
}
RECORD_PARAM_COMMS(
0, // seq
this->getID(),
@@ -816,6 +850,8 @@ ProcessGroupNCCL::ProcessGroupNCCL(
at::kByte, // dType
std::vector<int64_t>(), // inSplitSizes
std::vector<int64_t>(), // outSplitSizes
globalRankStart, // globalRankStart
globalRankStride, // globalRankStride
size_); // worldSize
// Attach hooks to cache allocator to trigger the hooks whenever a traced
@@ -2774,6 +2810,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce(
tensor.scalar_type(), // dType
std::vector<int64_t>(), // inSplitSizes
std::vector<int64_t>(), // outSplitSizes
globalRankStart, // globalRankStart
globalRankStride, // globalRankStride
this->getSize()); // worldSize
// avoidRecordStreams_ note: collective() will stash tensors.
@@ -2800,6 +2838,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_coalesced(
// I'm not sure what in,outSplitSizes mean here.
std::vector<int64_t>(), // inSplitSizes
std::vector<int64_t>(), // outSplitSizes
globalRankStart, // globalRankStart
globalRankStride, // globalRankStride
this->getSize()); // worldSize
// avoidRecordStreams_ note: collective() will stash tensors.
@@ -2826,6 +2866,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::broadcast(
tensor.scalar_type(), // dType
std::vector<int64_t>(), // inSplitSizes
std::vector<int64_t>(), // outSplitSizes
globalRankStart, // globalRankStart
globalRankStride, // globalRankStride
this->getSize()); // worldSize
// avoidRecordStreams_ note: collective() will stash tensors.
@@ -2889,6 +2931,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_broadcast_oop(
tensor.scalar_type(), // dType
std::vector<int64_t>(), // inSplitSizes
std::vector<int64_t>(), // outSplitSizes
globalRankStart, // globalRankStart
globalRankStride, // globalRankStride
this->getSize()); // worldSize
return collective(
@@ -2931,6 +2975,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce(
tensor.scalar_type(), // dType
std::vector<int64_t>(), // inSplitSizes
std::vector<int64_t>(), // outSplitSizes
globalRankStart, // globalRankStart
globalRankStride, // globalRankStride
this->getSize()); // worldSize
int dev_in_group = 0;
@@ -2995,6 +3041,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_oop(
tensor.scalar_type(), // dType
std::vector<int64_t>(), // inSplitSizes
std::vector<int64_t>(), // outSplitSizes
globalRankStart, // globalRankStart
globalRankStride, // globalRankStride
this->getSize()); // worldSize
int dev_in_group{0};
@@ -3053,6 +3101,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
tensor.scalar_type(), // dType
std::vector<int64_t>(), // inSplitSizes
std::vector<int64_t>(), // outSplitSize
globalRankStart, // globalRankStart
globalRankStride, // globalRankStride
this->getSize()); // worldSize
return collective(
@@ -3196,6 +3246,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
tensor.scalar_type(), // dType
std::vector<int64_t>(), // inSplitSizes
std::vector<int64_t>(), // outSplitSizes
globalRankStart, // globalRankStart
globalRankStride, // globalRankStride
this->getSize()); // worldSize
return collective(
@@ -3315,6 +3367,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(
tensor.scalar_type(), // dtype
std::vector<int64_t>(), // inSplitSizes
std::vector<int64_t>(), // outSplitSizes
globalRankStart, // globalRankStart
globalRankStride, // globalRankStride
this->getSize()); // worldSize
auto inputs = std::vector<at::Tensor>{inputTensor};
@@ -3403,6 +3457,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::barrier(const BarrierOptions& opts) {
at::kByte, // dType
std::vector<int64_t>(), // inSplitSizes
std::vector<int64_t>(), // outSplitSizes
globalRankStart, // globalRankStart
globalRankStride, // globalRankStride
this->getSize()); // worldSize
std::vector<at::Device> devices;
@@ -3484,6 +3540,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
inputTensor.scalar_type(), // dType
std::vector<int64_t>(), // inSplitSizes
std::vector<int64_t>(), // outSplitSizes
globalRankStart, // globalRankStart
globalRankStride, // globalRankStride
this->getSize()); // worldSize
// avoidRecordStreams_ note: collective() will stash inputTensors and
@@ -3526,6 +3584,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
inputTensor.scalar_type(), // dType
inputSplitSizes, // inSplitSizes
outputSplitSizes, // outSplitSizes
globalRankStart, // globalRankStart
globalRankStride, // globalRankStride
this->getSize()); // worldSize
// avoidRecordStreams_ note: collective() will stash inputTensors and
@@ -3602,6 +3662,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall(
inputTensors.front().scalar_type(), // dType
inSplitSizes, // inSplitSizes
outSplitSizes, // outSplitSizes
globalRankStart, // globalRankStart
globalRankStride, // globalRankStride
this->getSize()); // worldSize
std::vector<at::Tensor> inputTensor0 = {inputTensors[0]};
@@ -3653,6 +3715,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::send(
tensor.scalar_type(), // dType
std::vector<int64_t>(), // inSplitSizes
std::vector<int64_t>(), // outSplitSizes
globalRankStart, // globalRankStart
globalRankStride, // globalRankStride
this->getSize()); // worldSize
auto ret = pointToPoint(
@@ -3691,6 +3755,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::recv(
tensor.scalar_type(), // dType
std::vector<int64_t>(), // inSplitSizes
std::vector<int64_t>(), // outSplitSizes
globalRankStart, // globalRankStart
globalRankStride, // globalRankStride
this->getSize()); // worldSize
auto ret = pointToPoint(
@@ -3851,6 +3917,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::gather(
tensor.scalar_type(), // dType
std::vector<int64_t>(), // inSplitSizes
std::vector<int64_t>(), // outSplitSize
globalRankStart, // globalRankStart
globalRankStride, // globalRankStride
this->getSize()); // worldSize
// avoidRecordStreams_ note: collective() will stash inputTensors and
@@ -3938,6 +4006,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::scatter(
tensor.scalar_type(), // dType
std::vector<int64_t>(), // inSplitSizes
std::vector<int64_t>(), // outSplitSize
globalRankStart, // globalRankStart
globalRankStride, // globalRankStride
this->getSize()); // worldSize
// avoidRecordStreams_ note: collective() will stash outputTensors and
@@ -4009,6 +4079,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base(
tensor.scalar_type(), // dType
std::vector<int64_t>(), // inSplitSizes
std::vector<int64_t>(), // outSplitSize
globalRankStart, // globalRankStart
globalRankStride, // globalRankStride
this->getSize()); // worldSize
// just a wrapper to fit the collective interface

View File

@@ -590,6 +590,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
OpType opType);
private:
int globalRankStart;
int globalRankStride;
// Helper that encapsulates work shared across all collective communication
// primitives. The callbacks have the following signatures:
//

View File

@@ -158,6 +158,8 @@ void CommTraceLogger::recordComms(
dtype,
curInSplitSizes_,
curOutSplitSizes_,
-1,
-1,
world_size);
++seqnum;

View File

@@ -344,6 +344,8 @@ static constexpr auto kInMsgNelems = "In msg nelems";
static constexpr auto kOutMsgNelems = "Out msg nelems";
static constexpr auto kInSplit = "In split size";
static constexpr auto kOutSplit = "Out split size";
static constexpr auto kGlobalRankStart = "Global rank start";
static constexpr auto kGlobalRankStride = "Global rank stride";
static constexpr auto kGroupSize = "Group size";
static constexpr int32_t kTruncatLength = 30;
#endif // USE_C10D
@@ -395,6 +397,10 @@ std::unordered_map<std::string, std::string> saveNcclMeta(
outSplitSizes.begin() + kTruncatLength,
", ")));
}
map.emplace(
kGlobalRankStart, std::to_string(debugInfo->getGlobalRankStart()));
map.emplace(
kGlobalRankStride, std::to_string(debugInfo->getGlobalRankStride()));
map.emplace(kGroupSize, std::to_string(debugInfo->getWorldSize()));
#endif // USE_C10D
#endif // USE_DISTRIBUTED