Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[nccl-pg] Store PG global rank information in tracing logs (#115730)
Storing the list of global ranks associated with each PG allows us to correlate traces across different ranks.

Test Plan: OSS CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115730
Approved by: https://github.com/fduwjj
Committed by: PyTorch MergeBot
Parent: b38e14c12a
Commit: ffc826bf10
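The commit summarizes each process group's membership as a (globalRankStart, globalRankStride) pair plus the group size, rather than logging the full rank list. Below is a minimal sketch, not part of this commit, of how a trace consumer could expand that pair back into the group's global ranks; the helper name is illustrative. A stride of -1 marks a group whose ranks are not evenly strided, and some call sites in this patch record -1/-1 when the layout is not known at that point.

#include <vector>

// Illustrative helper (hypothetical): recover the global ranks of a PG from
// the (start, stride, size) triple stored in the tracing logs.
std::vector<int> expandGlobalRanks(int start, int stride, int worldSize) {
  std::vector<int> ranks;
  if (stride < 0) {
    return ranks; // -1: ranks are not an arithmetic progression, list unknown
  }
  ranks.reserve(worldSize);
  for (int i = 0; i < worldSize; ++i) {
    ranks.push_back(start + i * stride); // local rank i -> global rank
  }
  return ranks;
}

// Example: a subgroup over global ranks {8, 10, 12, 14} would be recorded as
// start = 8, stride = 2, worldSize = 4, and expands back to the same list.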
@@ -15,6 +15,8 @@ ParamCommsDebugInfo::ParamCommsDebugInfo(
     at::ScalarType dType,
     std::vector<int64_t> inSplitSizes,
     std::vector<int64_t> outSplitSizes,
+    int globalRankStart,
+    int globalRankStride,
     int worldSize)
     : rank_(rank),
       worldSize_(worldSize),
@@ -23,6 +25,8 @@ ParamCommsDebugInfo::ParamCommsDebugInfo(
       outMessageNelems_(outNelems),
       dType_(dType),
       inputSplitSizes_(std::move(inSplitSizes)),
-      outputSplitSizes_(std::move(outSplitSizes)) {}
+      outputSplitSizes_(std::move(outSplitSizes)),
+      globalRankStart_(globalRankStart),
+      globalRankStride_(globalRankStride) {}
 
 } // namespace torch
@@ -20,6 +20,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
       at::ScalarType dType,
       std::vector<int64_t> inSplitSizes,
       std::vector<int64_t> outSplitSizes,
+      int globalRankStart,
+      int globalRankStride,
       int worldSize);
 
   ~ParamCommsDebugInfo() override = default;
@@ -32,6 +34,14 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
     return worldSize_;
   }
 
+  int getGlobalRankStart() const {
+    return globalRankStart_;
+  }
+
+  int getGlobalRankStride() const {
+    return globalRankStride_;
+  }
+
   const std::string getColumnName() const {
     return columnName_;
   }
@@ -65,6 +75,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
   at::ScalarType dType_ = at::kByte;
   std::vector<int64_t> inputSplitSizes_;
   std::vector<int64_t> outputSplitSizes_;
+  int globalRankStart_;
+  int globalRankStride_;
 };
 
 #define RECORD_PARAM_COMMS( \
@@ -77,6 +89,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
       dType, \
       inSplitSizes, \
       outSplitSizes, \
+      globalRankStart, \
+      globalRankStride, \
       worldSize) \
     auto paramCommsInfo = std::make_shared<torch::ParamCommsDebugInfo>( \
         rank, \
@@ -86,6 +100,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
         dType, \
         inSplitSizes, \
         outSplitSizes, \
+        globalRankStart, \
+        globalRankStride, \
         worldSize); \
     c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \
     std::initializer_list<const c10::IValue> paramList = { \
@@ -95,6 +111,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
         colName, \
         inSplitSizes, \
         outSplitSizes, \
+        globalRankStart, \
+        globalRankStride, \
         worldSize}; \
     c10::ArrayRef<const c10::IValue> paramInputs(paramList); \
     RECORD_FUNCTION(at::kParamCommsCallName, paramInputs);
@@ -111,6 +129,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
       dType, \
       inSplitSizes, \
       outSplitSizes, \
+      globalRankStart, \
+      globalRankStride, \
       worldSize) \
     auto paramCommsInfo = std::make_shared<torch::ParamCommsDebugInfo>( \
         rank, \
@@ -120,6 +140,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
         dType, \
         inSplitSizes, \
         outSplitSizes, \
+        globalRankStart, \
+        globalRankStride, \
         worldSize); \
     c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \
     std::initializer_list<const c10::IValue> paramList = { \
@@ -130,6 +152,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
         colName, \
         inSplitSizes, \
         outSplitSizes, \
+        globalRankStart, \
+        globalRankStride, \
         worldSize}; \
     c10::ArrayRef<const c10::IValue> paramInputs(paramList); \
     RECORD_FUNCTION_WITH_INPUTS_OUTPUTS( \
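For context, the values recorded by this macro travel to profiler callbacks through the c10 thread-local debug-info stack. The following is a minimal sketch, not part of this commit, of how a callback could read the new fields back; it assumes ParamCommsUtils.h is included, and saveNcclMeta() further down in this patch does essentially the same thing.

#include <c10/util/ThreadLocalDebugInfo.h>

void readPgLayoutFromDebugInfo() {
  // Returns the ParamCommsDebugInfo installed by the DebugInfoGuard above,
  // or nullptr if no collective is currently being recorded on this thread.
  auto* info = static_cast<torch::ParamCommsDebugInfo*>(
      c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PARAM_COMMS_INFO));
  if (info == nullptr) {
    return;
  }
  int globalRankStart = info->getGlobalRankStart();
  int globalRankStride = info->getGlobalRankStride();
  int groupSize = info->getWorldSize();
  // With stride >= 0, the PG's i-th member is global rank
  // globalRankStart + i * globalRankStride, for i in [0, groupSize).
  (void)globalRankStart;
  (void)globalRankStride;
  (void)groupSize;
}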
@@ -641,6 +641,8 @@ bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) {
       at::kByte, // dType
       std::vector<int64_t>(), // inSplitSizes
       std::vector<int64_t>(), // outSplitSizes
+      -1,
+      -1,
       static_cast<int>(devices_.size())); // worldSize
   synchronizeInternal(timeout);
   // Always return true, because abort API is not implemented.
@@ -806,6 +808,38 @@ ProcessGroupNCCL::ProcessGroupNCCL(
             << ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << ncclTraceBufferSize_
             << ", NCCL_DEBUG: " << nccl_debug << ", ID=" << this->getID();
 
+  if (options_->global_ranks_in_group.empty()) {
+    this->globalRankStart = 0;
+  } else {
+    this->globalRankStart = options_->global_ranks_in_group[0];
+  }
+
+  if (options_->global_ranks_in_group.empty()) {
+    this->globalRankStride = 1;
+  } else if (options_->global_ranks_in_group.size() == 1) {
+    this->globalRankStride = 0;
+  } else {
+    bool ranksAreStrided = true;
+    int startRank = options_->global_ranks_in_group[0];
+    int stride =
+        options_->global_ranks_in_group[1] - options_->global_ranks_in_group[0];
+    for (std::vector<uint64_t>::size_type i = 0;
+         i < options_->global_ranks_in_group.size();
+         i++) {
+      if (options_->global_ranks_in_group[i] != startRank + i * stride) {
+        ranksAreStrided = false;
+        break;
+      }
+    }
+
+    if (ranksAreStrided) {
+      this->globalRankStride = options_->global_ranks_in_group[1] -
+          options_->global_ranks_in_group[0];
+    } else {
+      this->globalRankStride = -1;
+    }
+  }
+
   RECORD_PARAM_COMMS(
       0, // seq
       this->getID(),
@@ -816,6 +850,8 @@ ProcessGroupNCCL::ProcessGroupNCCL(
       at::kByte, // dType
       std::vector<int64_t>(), // inSplitSizes
       std::vector<int64_t>(), // outSplitSizes
+      globalRankStart, // globalRankStart
+      globalRankStride, // globalRankStride
       size_); // worldSize
 
   // Attach hooks to cache allocator to trigger the hooks whenever a traced
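The stride detection added in the constructor above boils down to: a group is "strided" exactly when its global ranks form an arithmetic progression. Here is a standalone sketch of that logic, written only for illustration (the helper name and free-function form are not part of this commit); the sentinel values mirror the ones used above.

#include <cstdint>
#include <vector>

// Illustrative re-statement of the globalRankStride computation above.
int computeGlobalRankStride(const std::vector<uint64_t>& globalRanksInGroup) {
  if (globalRanksInGroup.empty()) {
    return 1; // default/world PG: local rank i maps to global rank i
  }
  if (globalRanksInGroup.size() == 1) {
    return 0; // single-member group
  }
  const int64_t start = static_cast<int64_t>(globalRanksInGroup[0]);
  const int64_t stride = static_cast<int64_t>(globalRanksInGroup[1]) - start;
  for (size_t i = 0; i < globalRanksInGroup.size(); ++i) {
    if (static_cast<int64_t>(globalRanksInGroup[i]) !=
        start + static_cast<int64_t>(i) * stride) {
      return -1; // not evenly strided: record the "unknown layout" sentinel
    }
  }
  return static_cast<int>(stride);
}

// Examples: {8, 9, 10, 11} -> 1; {0, 2, 4, 6} -> 2; {0, 1, 4} -> -1; {5} -> 0.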
@@ -2774,6 +2810,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce(
       tensor.scalar_type(), // dType
       std::vector<int64_t>(), // inSplitSizes
       std::vector<int64_t>(), // outSplitSizes
+      globalRankStart, // globalRankStart
+      globalRankStride, // globalRankStride
       this->getSize()); // worldSize
 
   // avoidRecordStreams_ note: collective() will stash tensors.
@@ -2800,6 +2838,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_coalesced(
       // I'm not sure what in,outSplitSizes mean here.
       std::vector<int64_t>(), // inSplitSizes
       std::vector<int64_t>(), // outSplitSizes
+      globalRankStart, // globalRankStart
+      globalRankStride, // globalRankStride
       this->getSize()); // worldSize
 
   // avoidRecordStreams_ note: collective() will stash tensors.
@@ -2826,6 +2866,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::broadcast(
       tensor.scalar_type(), // dType
       std::vector<int64_t>(), // inSplitSizes
       std::vector<int64_t>(), // outSplitSizes
+      globalRankStart, // globalRankStart
+      globalRankStride, // globalRankStride
       this->getSize()); // worldSize
 
   // avoidRecordStreams_ note: collective() will stash tensors.
@@ -2889,6 +2931,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_broadcast_oop(
       tensor.scalar_type(), // dType
       std::vector<int64_t>(), // inSplitSizes
       std::vector<int64_t>(), // outSplitSizes
+      globalRankStart, // globalRankStart
+      globalRankStride, // globalRankStride
       this->getSize()); // worldSize
 
   return collective(
@@ -2931,6 +2975,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce(
       tensor.scalar_type(), // dType
       std::vector<int64_t>(), // inSplitSizes
       std::vector<int64_t>(), // outSplitSizes
+      globalRankStart, // globalRankStart
+      globalRankStride, // globalRankStride
       this->getSize()); // worldSize
 
   int dev_in_group = 0;
@@ -2995,6 +3041,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_oop(
       tensor.scalar_type(), // dType
       std::vector<int64_t>(), // inSplitSizes
       std::vector<int64_t>(), // outSplitSizes
+      globalRankStart, // globalRankStart
+      globalRankStride, // globalRankStride
       this->getSize()); // worldSize
 
   int dev_in_group{0};
@@ -3053,6 +3101,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather(
       tensor.scalar_type(), // dType
       std::vector<int64_t>(), // inSplitSizes
       std::vector<int64_t>(), // outSplitSize
+      globalRankStart, // globalRankStart
+      globalRankStride, // globalRankStride
       this->getSize()); // worldSize
 
   return collective(
@@ -3196,6 +3246,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
       tensor.scalar_type(), // dType
       std::vector<int64_t>(), // inSplitSizes
       std::vector<int64_t>(), // outSplitSizes
+      globalRankStart, // globalRankStart
+      globalRankStride, // globalRankStride
       this->getSize()); // worldSize
 
   return collective(
@@ -3315,6 +3367,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(
       tensor.scalar_type(), // dtype
       std::vector<int64_t>(), // inSplitSizes
       std::vector<int64_t>(), // outSplitSizes
+      globalRankStart, // globalRankStart
+      globalRankStride, // globalRankStride
       this->getSize()); // worldSize
 
   auto inputs = std::vector<at::Tensor>{inputTensor};
@@ -3403,6 +3457,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::barrier(const BarrierOptions& opts) {
       at::kByte, // dType
       std::vector<int64_t>(), // inSplitSizes
       std::vector<int64_t>(), // outSplitSizes
+      globalRankStart, // globalRankStart
+      globalRankStride, // globalRankStride
       this->getSize()); // worldSize
 
   std::vector<at::Device> devices;
@@ -3484,6 +3540,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
       inputTensor.scalar_type(), // dType
       std::vector<int64_t>(), // inSplitSizes
       std::vector<int64_t>(), // outSplitSizes
+      globalRankStart, // globalRankStart
+      globalRankStride, // globalRankStride
       this->getSize()); // worldSize
 
   // avoidRecordStreams_ note: collective() will stash inputTensors and
@@ -3526,6 +3584,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base(
       inputTensor.scalar_type(), // dType
       inputSplitSizes, // inSplitSizes
       outputSplitSizes, // outSplitSizes
+      globalRankStart, // globalRankStart
+      globalRankStride, // globalRankStride
       this->getSize()); // worldSize
 
   // avoidRecordStreams_ note: collective() will stash inputTensors and
@@ -3602,6 +3662,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall(
       inputTensors.front().scalar_type(), // dType
       inSplitSizes, // inSplitSizes
       outSplitSizes, // outSplitSizes
+      globalRankStart, // globalRankStart
+      globalRankStride, // globalRankStride
       this->getSize()); // worldSize
 
   std::vector<at::Tensor> inputTensor0 = {inputTensors[0]};
@@ -3653,6 +3715,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::send(
       tensor.scalar_type(), // dType
       std::vector<int64_t>(), // inSplitSizes
       std::vector<int64_t>(), // outSplitSizes
+      globalRankStart, // globalRankStart
+      globalRankStride, // globalRankStride
       this->getSize()); // worldSize
 
   auto ret = pointToPoint(
@@ -3691,6 +3755,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::recv(
       tensor.scalar_type(), // dType
       std::vector<int64_t>(), // inSplitSizes
       std::vector<int64_t>(), // outSplitSizes
+      globalRankStart, // globalRankStart
+      globalRankStride, // globalRankStride
       this->getSize()); // worldSize
 
   auto ret = pointToPoint(
@@ -3851,6 +3917,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::gather(
       tensor.scalar_type(), // dType
       std::vector<int64_t>(), // inSplitSizes
       std::vector<int64_t>(), // outSplitSize
+      globalRankStart, // globalRankStart
+      globalRankStride, // globalRankStride
       this->getSize()); // worldSize
 
   // avoidRecordStreams_ note: collective() will stash inputTensors and
@@ -3938,6 +4006,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::scatter(
       tensor.scalar_type(), // dType
       std::vector<int64_t>(), // inSplitSizes
       std::vector<int64_t>(), // outSplitSize
+      globalRankStart, // globalRankStart
+      globalRankStride, // globalRankStride
       this->getSize()); // worldSize
 
   // avoidRecordStreams_ note: collective() will stash outputTensors and
@@ -4009,6 +4079,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base(
       tensor.scalar_type(), // dType
       std::vector<int64_t>(), // inSplitSizes
       std::vector<int64_t>(), // outSplitSize
+      globalRankStart, // globalRankStart
+      globalRankStride, // globalRankStride
       this->getSize()); // worldSize
 
   // just a wrapper to fit the collective interface
@@ -590,6 +590,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
       OpType opType);
 
  private:
+  int globalRankStart;
+  int globalRankStride;
+
   // Helper that encapsulates work shared across all collective communication
   // primitives. The callbacks have the following signatures:
   //
@@ -158,6 +158,8 @@ void CommTraceLogger::recordComms(
       dtype,
       curInSplitSizes_,
       curOutSplitSizes_,
+      -1,
+      -1,
       world_size);
 
   ++seqnum;
@@ -344,6 +344,8 @@ static constexpr auto kInMsgNelems = "In msg nelems";
 static constexpr auto kOutMsgNelems = "Out msg nelems";
 static constexpr auto kInSplit = "In split size";
 static constexpr auto kOutSplit = "Out split size";
+static constexpr auto kGlobalRankStart = "Global rank start";
+static constexpr auto kGlobalRankStride = "Global rank stride";
 static constexpr auto kGroupSize = "Group size";
 static constexpr int32_t kTruncatLength = 30;
 #endif // USE_C10D
@@ -395,6 +397,10 @@ std::unordered_map<std::string, std::string> saveNcclMeta(
             outSplitSizes.begin() + kTruncatLength,
             ", ")));
   }
+  map.emplace(
+      kGlobalRankStart, std::to_string(debugInfo->getGlobalRankStart()));
+  map.emplace(
+      kGlobalRankStride, std::to_string(debugInfo->getGlobalRankStride()));
   map.emplace(kGroupSize, std::to_string(debugInfo->getWorldSize()));
 #endif // USE_C10D
 #endif // USE_DISTRIBUTED
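As a result, the per-collective metadata produced by saveNcclMeta() gains two entries next to the existing group size. Illustratively (hypothetical values for a 4-rank subgroup over global ranks 8 through 11), a recorded collective would carry:

  "Group size": "4"
  "Global rank start": "8"
  "Global rank stride": "1"

Because every member of the group records the same (start, stride, size) triple, a trace post-processing tool can regenerate the group's global ranks on each rank and line up the corresponding collective events across trace files, which is the cross-rank correlation described in the commit message.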