Files
pytorch/torch/csrc/distributed/c10d/ParamCommsUtils.hpp
Yuanyuan Chen 9fff8155c3 [2/N] Fix clang-tidy readability checks (#164652)
This PR applies clang-tidy readability checks to jit sources and all headers in the code base.
`readability-redundant-inline-specifier` is suppressed because it incurs too many changes. `readability-redundant-inline-specifier` is used to detect redundant inline specifiers on function and variable declarations. There are many in-class method definitions that are marked inline.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164652
Approved by: https://github.com/Skylion007
2025-10-06 01:06:01 +00:00

181 lines
8.5 KiB
C++

#pragma once
#include <ATen/core/ivalue.h>
#include <ATen/record_function.h>
#include <c10/macros/Macros.h>
#include <c10/util/ThreadLocalDebugInfo.h>
#include <string>
#include <vector>
namespace torch {
// c10::DebugInfoBase subclass that stores the arguments handed to it by the
// RECORD_PARAM_COMMS / RECORD_PARAM_COMMS_DATA macros defined in this file.
// Those macros create an instance with std::make_shared and install it via
// c10::DebugInfoGuard under c10::DebugInfoKind::PARAM_COMMS_INFO.
//
// The class is a plain value holder: every accessor is const and returns
// either a copy (scalars, strings) or a const reference (vectors) of the
// corresponding private member.
class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
 public:
  // Leaves every member at its in-class default: integers are
  // value-initialized to 0, strings/vectors are empty, dType_ is at::kByte.
  ParamCommsDebugInfo() = default;

  // Declared here, defined out of line (the definition is not part of this
  // header). Parameter roles below are read off the parameter/member names;
  // NOTE(review): confirm against the definition.
  //   pgName            - <group_name, group_desc> pair (see pgName_).
  //   rank              - presumably the caller's rank in the group.
  //   collName          - collective name; taken as an rvalue reference.
  //   inNelems/outNelems- presumably input/output element counts.
  //   dType             - scalar type associated with the message.
  //   inSplitSizes/outSplitSizes - per-split sizes for input/output.
  //   globalRankStart/globalRankStride/worldSize - group layout values.
  // There is no parameter for groupRanks_; how it is populated is not
  // visible from this header.
  ParamCommsDebugInfo(
      std::tuple<std::string, std::string> pgName,
      int rank,
      std::string&& collName,
      int64_t inNelems,
      int64_t outNelems,
      at::ScalarType dType,
      std::vector<int64_t> inSplitSizes,
      std::vector<int64_t> outSplitSizes,
      int globalRankStart,
      int globalRankStride,
      int worldSize);

  ~ParamCommsDebugInfo() override = default;

  // The string accessors return by value. The return type is deliberately
  // not const-qualified: a top-level const on a by-value return gives the
  // caller no protection and only prevents the temporary from being moved.

  // First element of pgName_ (the group name).
  std::string getProcessGroupName() const {
    return std::get<0>(pgName_);
  }

  // Second element of pgName_ (the group description).
  std::string getProcessGroupDesc() const {
    return std::get<1>(pgName_);
  }

  int getRank() const {
    return rank_;
  }

  int getWorldSize() const {
    return worldSize_;
  }

  int getGlobalRankStart() const {
    return globalRankStart_;
  }

  int getGlobalRankStride() const {
    return globalRankStride_;
  }

  std::string getCollectiveName() const {
    return collectiveName_;
  }

  int64_t getInMessageNelems() const {
    return inMessageNelems_;
  }

  int64_t getOutMessageNelems() const {
    return outMessageNelems_;
  }

  at::ScalarType getDType() const {
    return dType_;
  }

  // The vector accessors hand out references into this object; they are
  // valid only while the ParamCommsDebugInfo instance is alive.
  const std::vector<int64_t>& getInputSplitSizes() const {
    return inputSplitSizes_;
  }

  const std::vector<int64_t>& getOutputSplitSizes() const {
    return outputSplitSizes_;
  }

  const std::vector<int64_t>& getGroupRanks() const {
    return groupRanks_;
  }

 private:
  std::tuple<std::string, std::string> pgName_; // <group_name, group_desc>
  int rank_{};
  int worldSize_{};
  std::string collectiveName_;
  int64_t inMessageNelems_{};
  int64_t outMessageNelems_{};
  at::ScalarType dType_ = at::kByte;
  std::vector<int64_t> inputSplitSizes_;
  std::vector<int64_t> outputSplitSizes_;
  int globalRankStart_{};
  int globalRankStride_{};
  std::vector<int64_t> groupRanks_;
};
// RECORD_PARAM_COMMS(...)
//
// Expands, in the caller's scope, to the following statements in order:
//   1. `paramCommsInfo`: a std::shared_ptr<torch::ParamCommsDebugInfo> built
//      from pgName, rank, collName, inNelems, outNelems, dType, inSplitSizes,
//      outSplitSizes, globalRankStart, globalRankStride and worldSize.
//   2. `g`: a c10::DebugInfoGuard constructed with
//      c10::DebugInfoKind::PARAM_COMMS_INFO and `paramCommsInfo`.
//   3. `paramList` / `paramInputs`: an initializer_list of c10::IValue and an
//      ArrayRef over it, holding seq, pgName, rank, collName, inSplitSizes,
//      outSplitSizes, globalRankStart, globalRankStride, worldSize.
//      inNelems, outNelems and dType are NOT part of this list; they only
//      reach the debug-info object.
//   4. RECORD_FUNCTION(at::kParamCommsCallName, paramInputs);
//
// Usage notes that follow directly from the expansion:
//   - The macro is several statements, not one expression. The names
//     `paramCommsInfo`, `g`, `paramList` and `paramInputs` are declared in the
//     enclosing scope, so it cannot be the sole body of an unbraced `if`/`for`
//     and those names must not already be declared in that scope.
//   - Every argument except seq, inNelems, outNelems and dType is substituted
//     twice and is therefore evaluated twice; pass side-effect-free
//     expressions.
//   - collName is passed to a constructor parameter of type `std::string&&`.
//   - The expansion already ends in `;`.
#define RECORD_PARAM_COMMS( \
seq, \
pgName, \
rank, \
collName, \
inNelems, \
outNelems, \
dType, \
inSplitSizes, \
outSplitSizes, \
globalRankStart, \
globalRankStride, \
worldSize) \
auto paramCommsInfo = std::make_shared<torch::ParamCommsDebugInfo>( \
pgName, \
rank, \
collName, \
inNelems, \
outNelems, \
dType, \
inSplitSizes, \
outSplitSizes, \
globalRankStart, \
globalRankStride, \
worldSize); \
c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \
std::initializer_list<const c10::IValue> paramList = { \
seq, \
pgName, \
rank, \
collName, \
inSplitSizes, \
outSplitSizes, \
globalRankStart, \
globalRankStride, \
worldSize}; \
c10::ArrayRef<const c10::IValue> paramInputs(paramList); \
RECORD_FUNCTION(at::kParamCommsCallName, paramInputs);
// RECORD_PARAM_COMMS_DATA(...)
//
// Variant of the parameter-recording macro that also records tensors.
// Expands, in the caller's scope, to the following statements in order:
//   1. `paramCommsInfo`: a std::shared_ptr<torch::ParamCommsDebugInfo> built
//      from pgName, rank, collName, inNelems, outNelems, dType, inSplitSizes,
//      outSplitSizes, globalRankStart, globalRankStride and worldSize.
//      InputTensors / OutputTensors are not passed to it.
//   2. `g`: a c10::DebugInfoGuard constructed with
//      c10::DebugInfoKind::PARAM_COMMS_INFO and `paramCommsInfo`.
//   3. `paramList` / `paramInputs`: an initializer_list of c10::IValue and an
//      ArrayRef over it. The first element is c10::IValue(InputTensors),
//      followed by seq, pgName, rank, collName, inSplitSizes, outSplitSizes,
//      globalRankStart, globalRankStride, worldSize. inNelems, outNelems and
//      dType are NOT part of this list.
//   4. RECORD_FUNCTION_WITH_INPUTS_OUTPUTS with at::kParamCommsCallName,
//      `paramInputs`, and a one-element std::vector<c10::IValue> containing
//      c10::IValue(OutputTensors).
//
// Usage notes that follow directly from the expansion:
//   - The macro is several statements, not one expression. The names
//     `paramCommsInfo`, `g`, `paramList` and `paramInputs` are declared in the
//     enclosing scope, so it cannot be the sole body of an unbraced `if`/`for`
//     and those names must not already be declared in that scope.
//   - pgName, rank, collName, inSplitSizes, outSplitSizes, globalRankStart,
//     globalRankStride and worldSize are each substituted twice and are
//     therefore evaluated twice; pass side-effect-free expressions.
//   - collName is passed to a constructor parameter of type `std::string&&`.
//   - The expansion already ends in `;`.
#define RECORD_PARAM_COMMS_DATA( \
seq, \
pgName, \
InputTensors, \
OutputTensors, \
rank, \
collName, \
inNelems, \
outNelems, \
dType, \
inSplitSizes, \
outSplitSizes, \
globalRankStart, \
globalRankStride, \
worldSize) \
auto paramCommsInfo = std::make_shared<torch::ParamCommsDebugInfo>( \
pgName, \
rank, \
collName, \
inNelems, \
outNelems, \
dType, \
inSplitSizes, \
outSplitSizes, \
globalRankStart, \
globalRankStride, \
worldSize); \
c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \
std::initializer_list<const c10::IValue> paramList = { \
c10::IValue(InputTensors), \
seq, \
pgName, \
rank, \
collName, \
inSplitSizes, \
outSplitSizes, \
globalRankStart, \
globalRankStride, \
worldSize}; \
c10::ArrayRef<const c10::IValue> paramInputs(paramList); \
RECORD_FUNCTION_WITH_INPUTS_OUTPUTS( \
at::kParamCommsCallName, \
paramInputs, \
std::vector<c10::IValue>(1, c10::IValue(OutputTensors)));
} // namespace torch