mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This reverts commit 3c5ca685d6f5b6f3971c0cd20a054aa355610419. Reverted https://github.com/pytorch/pytorch/pull/164652 on behalf of https://github.com/izaitsevfb: the revert is needed because of a conflict with the revert of https://github.com/pytorch/pytorch/pull/162659 ([comment](https://github.com/pytorch/pytorch/pull/164652#issuecomment-3369346707))
181 lines
8.5 KiB
C++
181 lines
8.5 KiB
C++
#pragma once
|
|
|
|
#include <ATen/core/ivalue.h>
#include <ATen/record_function.h>
#include <c10/macros/Macros.h>
#include <c10/util/ThreadLocalDebugInfo.h>

#include <cstdint>
#include <initializer_list>
#include <memory>
#include <string>
#include <tuple>
#include <vector>
|
|
|
|
namespace torch {
|
|
|
|
// Value holder derived from c10::DebugInfoBase. The RECORD_PARAM_COMMS* macros
// in this header build one instance per call site and install it through
// c10::DebugInfoGuard under c10::DebugInfoKind::PARAM_COMMS_INFO.
//
// Every accessor is const and simply returns the corresponding stored member.
// Strings are returned by value (a copy); the split-size and group-rank
// vectors are returned by const reference, so those references are only valid
// while this object is alive.
class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
 public:
  // Leaves every member at its in-class default: empty strings/vectors, zero
  // for the integral members and at::kByte for the dtype.
  ParamCommsDebugInfo() = default;

  // Defined out of line (not visible in this header).
  // NOTE(review): presumably each argument is stored in the like-named member
  // below; how groupRanks_ is populated cannot be determined from this header
  // -- confirm against the definition.
  //
  // pgName is a <group_name, group_desc> pair. collName is taken by rvalue
  // reference, so callers must pass a temporary or an explicitly moved string.
  ParamCommsDebugInfo(
      std::tuple<std::string, std::string> pgName,
      int rank,
      std::string&& collName,
      int64_t inNelems,
      int64_t outNelems,
      at::ScalarType dType,
      std::vector<int64_t> inSplitSizes,
      std::vector<int64_t> outSplitSizes,
      int globalRankStart,
      int globalRankStride,
      int worldSize);

  ~ParamCommsDebugInfo() override = default;

  // First element of the <group_name, group_desc> pair.
  // Returned as a non-const value: a top-level const on a by-value return
  // only prevents callers from moving out of the temporary.
  std::string getProcessGroupName() const {
    return std::get<0>(pgName_);
  }

  // Second element of the <group_name, group_desc> pair.
  std::string getProcessGroupDesc() const {
    return std::get<1>(pgName_);
  }

  int getRank() const {
    return rank_;
  }

  int getWorldSize() const {
    return worldSize_;
  }

  int getGlobalRankStart() const {
    return globalRankStart_;
  }

  int getGlobalRankStride() const {
    return globalRankStride_;
  }

  std::string getCollectiveName() const {
    return collectiveName_;
  }

  int64_t getInMessageNelems() const {
    return inMessageNelems_;
  }

  int64_t getOutMessageNelems() const {
    return outMessageNelems_;
  }

  at::ScalarType getDType() const {
    return dType_;
  }

  const std::vector<int64_t>& getInputSplitSizes() const {
    return inputSplitSizes_;
  }

  const std::vector<int64_t>& getOutputSplitSizes() const {
    return outputSplitSizes_;
  }

  const std::vector<int64_t>& getGroupRanks() const {
    return groupRanks_;
  }

 private:
  std::tuple<std::string, std::string> pgName_; // <group_name, group_desc>
  int rank_{};
  int worldSize_{};
  std::string collectiveName_;
  int64_t inMessageNelems_{};
  int64_t outMessageNelems_{};
  at::ScalarType dType_ = at::kByte;
  std::vector<int64_t> inputSplitSizes_;
  std::vector<int64_t> outputSplitSizes_;
  int globalRankStart_{};
  int globalRankStride_{};
  std::vector<int64_t> groupRanks_{};
};
|
|
|
|
// RECORD_PARAM_COMMS(seq, pgName, rank, collName, inNelems, outNelems, dType,
//                    inSplitSizes, outSplitSizes, globalRankStart,
//                    globalRankStride, worldSize)
//
// Expands to a sequence of statements (deliberately NOT wrapped in a
// do { } while (0) block) that:
//   1. build a std::shared_ptr<torch::ParamCommsDebugInfo> from every argument
//      except `seq`;
//   2. declare a c10::DebugInfoGuard for
//      c10::DebugInfoKind::PARAM_COMMS_INFO holding that object;
//   3. pack {seq, pgName, rank, collName, inSplitSizes, outSplitSizes,
//      globalRankStart, globalRankStride, worldSize} into c10::IValues and
//      hand them to RECORD_FUNCTION under at::kParamCommsCallName.
//
// Usage notes that follow directly from the expansion:
//   * The locals `paramCommsInfo`, `g`, `paramList` and `paramInputs` are
//     declared in the caller's scope, so the macro can appear at most once per
//     scope, must not be the sole body of an un-braced `if`/`for`, and the
//     guard `g` lives until the end of the enclosing scope.
//   * pgName, rank, collName, inSplitSizes, outSplitSizes, globalRankStart,
//     globalRankStride and worldSize are each expanded twice; avoid passing
//     expressions with side effects (including std::move of a named object).
//   * inNelems, outNelems and dType are stored only in the debug-info object;
//     they are not part of the inputs passed to RECORD_FUNCTION.
//   * The expansion already ends with a semicolon.
#define RECORD_PARAM_COMMS( \
    seq, \
    pgName, \
    rank, \
    collName, \
    inNelems, \
    outNelems, \
    dType, \
    inSplitSizes, \
    outSplitSizes, \
    globalRankStart, \
    globalRankStride, \
    worldSize) \
  auto paramCommsInfo = std::make_shared<torch::ParamCommsDebugInfo>( \
      pgName, \
      rank, \
      collName, \
      inNelems, \
      outNelems, \
      dType, \
      inSplitSizes, \
      outSplitSizes, \
      globalRankStart, \
      globalRankStride, \
      worldSize); \
  c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \
  std::initializer_list<const c10::IValue> paramList = { \
      seq, \
      pgName, \
      rank, \
      collName, \
      inSplitSizes, \
      outSplitSizes, \
      globalRankStart, \
      globalRankStride, \
      worldSize}; \
  c10::ArrayRef<const c10::IValue> paramInputs(paramList); \
  RECORD_FUNCTION(at::kParamCommsCallName, paramInputs);
|
|
|
|
// RECORD_PARAM_COMMS_DATA(seq, pgName, InputTensors, OutputTensors, rank,
//                         collName, inNelems, outNelems, dType, inSplitSizes,
//                         outSplitSizes, globalRankStart, globalRankStride,
//                         worldSize)
//
// Same expansion as RECORD_PARAM_COMMS, with two differences:
//   * c10::IValue(InputTensors) is placed first in the recorded input list,
//     ahead of seq, pgName, rank, collName, inSplitSizes, outSplitSizes,
//     globalRankStart, globalRankStride and worldSize;
//   * the record is emitted with RECORD_FUNCTION_WITH_INPUTS_OUTPUTS, passing
//     a one-element std::vector<c10::IValue> that wraps OutputTensors as the
//     outputs.
// InputTensors and OutputTensors are not passed to the
// torch::ParamCommsDebugInfo constructor.
//
// Usage notes that follow directly from the expansion:
//   * The locals `paramCommsInfo`, `g`, `paramList` and `paramInputs` are
//     declared in the caller's scope, so the macro can appear at most once per
//     scope (and cannot share a scope with RECORD_PARAM_COMMS), must not be
//     the sole body of an un-braced `if`/`for`, and the guard `g` lives until
//     the end of the enclosing scope.
//   * pgName, rank, collName, inSplitSizes, outSplitSizes, globalRankStart,
//     globalRankStride and worldSize are each expanded twice; avoid passing
//     expressions with side effects (including std::move of a named object).
//   * inNelems, outNelems and dType are stored only in the debug-info object;
//     they are not part of the recorded inputs.
//   * The expansion already ends with a semicolon.
#define RECORD_PARAM_COMMS_DATA( \
    seq, \
    pgName, \
    InputTensors, \
    OutputTensors, \
    rank, \
    collName, \
    inNelems, \
    outNelems, \
    dType, \
    inSplitSizes, \
    outSplitSizes, \
    globalRankStart, \
    globalRankStride, \
    worldSize) \
  auto paramCommsInfo = std::make_shared<torch::ParamCommsDebugInfo>( \
      pgName, \
      rank, \
      collName, \
      inNelems, \
      outNelems, \
      dType, \
      inSplitSizes, \
      outSplitSizes, \
      globalRankStart, \
      globalRankStride, \
      worldSize); \
  c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \
  std::initializer_list<const c10::IValue> paramList = { \
      c10::IValue(InputTensors), \
      seq, \
      pgName, \
      rank, \
      collName, \
      inSplitSizes, \
      outSplitSizes, \
      globalRankStart, \
      globalRankStride, \
      worldSize}; \
  c10::ArrayRef<const c10::IValue> paramInputs(paramList); \
  RECORD_FUNCTION_WITH_INPUTS_OUTPUTS( \
      at::kParamCommsCallName, \
      paramInputs, \
      std::vector<c10::IValue>(1, c10::IValue(OutputTensors)));
|
|
} // namespace torch
|