[Kineto][NCCL][3/n] Get the NCCL communication info from PARAM_COMMS_INFO (#111846)

This diff enables the functionality to get the NCCL communication metadata from `c10::DebugInfoKind::PARAM_COMMS_INFO` available in `ThreadLocalDebugInfo`.

To keep the overhead lightweight and avoid comparing the function name on each op, we add the method `bool isNcclMeta()`, which is decided during initialization.

Differential Revision: [D50439211](https://our.internmc.facebook.com/intern/diff/D50439211/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111846
Approved by: https://github.com/aaronenyeshi
ghstack dependencies: #111842, #111843
This commit is contained in:
Yue Dong
2023-10-24 19:29:18 -07:00
committed by PyTorch MergeBot
parent 1623cc5815
commit ed15fa7cc2
10 changed files with 94 additions and 7 deletions

View File

@ -9,8 +9,6 @@
namespace torch {
extern TORCH_API const std::string kParamCommsCallName;
class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
public:
ParamCommsDebugInfo() = default;
@ -99,7 +97,7 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
outSplitSizes, \
worldSize}; \
c10::ArrayRef<const c10::IValue> paramInputs(paramList); \
RECORD_FUNCTION(torch::kParamCommsCallName, paramInputs);
RECORD_FUNCTION(at::kParamCommsCallName, paramInputs);
#define RECORD_PARAM_COMMS_DATA( \
seq, \
@ -135,7 +133,7 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
worldSize}; \
c10::ArrayRef<const c10::IValue> paramInputs(paramList); \
RECORD_FUNCTION_WITH_INPUTS_OUTPUTS( \
torch::kParamCommsCallName, \
at::kParamCommsCallName, \
paramInputs, \
std::vector<c10::IValue>(1, c10::IValue(OutputTensors)));
} // namespace torch