pytorch/torch/csrc/distributed/c10d/ParamCommsUtils.hpp
Yue Dong ed15fa7cc2 [Kineto][NCCL][3/n] Get the NCCL communication info from PARAM_COMMS_INFO (#111846)
This diff makes it possible to get the NCCL communication metadata from `c10::DebugInfoKind::PARAM_COMMS_INFO`, which is available in `ThreadLocalDebugInfo`.
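
For illustration, a minimal sketch (not part of this header) of how a profiler observer could read that metadata back out of `ThreadLocalDebugInfo`; `readParamCommsDebugInfo` is a hypothetical helper name, and the real consumer lives in the Kineto/profiler integration:

#include <c10/util/ThreadLocalDebugInfo.h>
#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>

// Hypothetical helper: fetch the debug info that the RECORD_PARAM_COMMS*
// macros attach while a collective op is running. Returns nullptr when the
// current thread is not inside such an op.
inline torch::ParamCommsDebugInfo* readParamCommsDebugInfo() {
  return dynamic_cast<torch::ParamCommsDebugInfo*>(
      c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PARAM_COMMS_INFO));
}

// Usage inside an observer callback (illustrative):
//   if (auto* info = readParamCommsDebugInfo()) {
//     // e.g. annotate the event with info->getColumnName(), info->getRank(),
//     // info->getWorldSize(), info->getInMessageSize(), ...
//   }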

To keep the overhead lightweight and avoid comparing the function name on each op, we add the method `bool isNcclMeta()`, whose value is decided during initialization.
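
The `isNcclMeta()` method itself lives in the profiler-side changes of this stack, not in this header. Conceptually, it caches at construction time the result of comparing the op name against `at::kParamCommsCallName`; below is a hedged sketch of that pattern with an illustrative type name (`ObservedEvent` is not an actual profiler class):

#include <ATen/record_function.h>
#include <string>

// Illustrative only: decide "is this a param_comms (NCCL metadata) event?"
// once, when the event record is created, so later queries are a cheap
// boolean read instead of a per-op string comparison.
struct ObservedEvent {
  explicit ObservedEvent(std::string name)
      : name_(std::move(name)),
        isNcclMeta_(name_ == at::kParamCommsCallName) {}

  bool isNcclMeta() const {
    return isNcclMeta_;
  }

 private:
  std::string name_;
  bool isNcclMeta_;
};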

Differential Revision: [D50439211](https://our.internmc.facebook.com/intern/diff/D50439211/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111846
Approved by: https://github.com/aaronenyeshi
ghstack dependencies: #111842, #111843
2023-10-25 20:35:06 +00:00


#pragma once

#include <ATen/core/ivalue.h>
#include <ATen/record_function.h>
#include <c10/macros/Macros.h>
#include <c10/util/ThreadLocalDebugInfo.h>
#include <string>
#include <vector>

namespace torch {

// Debug info describing a single collective communication call. The
// RECORD_PARAM_COMMS* macros below install an instance of this class into
// ThreadLocalDebugInfo under DebugInfoKind::PARAM_COMMS_INFO so that
// profiler/Kineto callbacks can read the communication metadata.
class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
 public:
  ParamCommsDebugInfo() = default;
  ParamCommsDebugInfo(
      int rank,
      std::string&& colName,
      int inSize,
      int outSize,
      at::ScalarType dType,
      std::vector<int64_t> inSplitSizes,
      std::vector<int64_t> outSplitSizes,
      int worldSize);

  ~ParamCommsDebugInfo() override = default;

  int getRank() const {
    return rank_;
  }

  int getWorldSize() const {
    return worldSize_;
  }

  const std::string getColumnName() const {
    return columnName_;
  }

  int getInMessageSize() const {
    return inMessageSize_;
  }

  int getOutMessageSize() const {
    return outMessageSize_;
  }

  at::ScalarType getDType() const {
    return dType_;
  }

  const std::vector<int64_t>& getInputSplitSizes() const {
    return inputSplitSizes_;
  }

  const std::vector<int64_t>& getOutputSplitSizes() const {
    return outputSplitSizes_;
  }

 private:
  int rank_{};
  int worldSize_{};
  std::string columnName_;
  int inMessageSize_{};
  int outMessageSize_{};
  at::ScalarType dType_ = at::kByte;
  std::vector<int64_t> inputSplitSizes_;
  std::vector<int64_t> outputSplitSizes_;
};

// Records a collective communication call as a profiler event named
// at::kParamCommsCallName and, for the duration of the enclosing scope,
// installs a ParamCommsDebugInfo into ThreadLocalDebugInfo (via
// DebugInfoGuard) so profiler observers can read the communication metadata.
#define RECORD_PARAM_COMMS( \
    seq, \
    pg_ptr, \
    rank, \
    colName, \
    inSize, \
    outSize, \
    dType, \
    inSplitSizes, \
    outSplitSizes, \
    worldSize) \
  auto paramCommsInfo = std::make_shared<torch::ParamCommsDebugInfo>( \
      rank, \
      colName, \
      inSize, \
      outSize, \
      dType, \
      inSplitSizes, \
      outSplitSizes, \
      worldSize); \
  c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \
  std::initializer_list<const c10::IValue> paramList = { \
      c10::IValue(seq), \
      c10::IValue(pg_ptr), \
      rank, \
      colName, \
      inSplitSizes, \
      outSplitSizes, \
      worldSize}; \
  c10::ArrayRef<const c10::IValue> paramInputs(paramList); \
  RECORD_FUNCTION(at::kParamCommsCallName, paramInputs);

// Same as RECORD_PARAM_COMMS, but additionally records the collective's
// input and output tensors with the profiler event.
#define RECORD_PARAM_COMMS_DATA( \
    seq, \
    pg_ptr, \
    InputTensors, \
    OutputTensors, \
    rank, \
    colName, \
    inSize, \
    outSize, \
    dType, \
    inSplitSizes, \
    outSplitSizes, \
    worldSize) \
  auto paramCommsInfo = std::make_shared<torch::ParamCommsDebugInfo>( \
      rank, \
      colName, \
      inSize, \
      outSize, \
      dType, \
      inSplitSizes, \
      outSplitSizes, \
      worldSize); \
  c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \
  std::initializer_list<const c10::IValue> paramList = { \
      c10::IValue(InputTensors), \
      c10::IValue(seq), \
      c10::IValue(pg_ptr), \
      rank, \
      colName, \
      inSplitSizes, \
      outSplitSizes, \
      worldSize}; \
  c10::ArrayRef<const c10::IValue> paramInputs(paramList); \
  RECORD_FUNCTION_WITH_INPUTS_OUTPUTS( \
      at::kParamCommsCallName, \
      paramInputs, \
      std::vector<c10::IValue>(1, c10::IValue(OutputTensors)));

} // namespace torch
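
For completeness, a hedged sketch of a call site, modeled loosely on how a process group implementation (e.g. ProcessGroupNCCL) invokes the macro for an all-reduce; the surrounding variables (`seq_`, `rank_`, `size_`, `tensor`, `tensors`) are assumed to exist in the enclosing method, and the exact arguments may differ in the real code:

// Illustrative call site, not copied verbatim from ProcessGroupNCCL.
RECORD_PARAM_COMMS_DATA(
    static_cast<int>(seq_), // seq: collective sequence number
    reinterpret_cast<std::intptr_t>(this), // pg_ptr: process group identity
    tensors, // InputTensors
    tensors, // OutputTensors
    rank_, // rank
    "allreduce", // colName
    tensor.numel(), // inSize
    tensor.numel(), // outSize
    tensor.scalar_type(), // dType
    std::vector<int64_t>(), // inSplitSizes
    std::vector<int64_t>(), // outSplitSizes
    size_); // worldSize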