mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Kineto][NCCL][3/n] Get the NCCL communication info from PARAM_COMMS_INFO (#111846)
This diff enables the functionality to get the NCCL communication metadata from `c10::DebugInfoKind::PARAM_COMMS_INFO` available in `ThreadLocalDebugInfo`. To make the overhead lighweight and avoid comparing the function name on each op, we add the method `bool isNcclMeta()`, which decided during initialization. Differential Revision: [D50439211](https://our.internmc.facebook.com/intern/diff/D50439211/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/111846 Approved by: https://github.com/aaronenyeshi ghstack dependencies: #111842, #111843
This commit is contained in:
committed by
PyTorch MergeBot
parent
1623cc5815
commit
ed15fa7cc2
@ -9,8 +9,6 @@
|
||||
|
||||
namespace torch {
|
||||
|
||||
extern TORCH_API const std::string kParamCommsCallName;
|
||||
|
||||
class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
|
||||
public:
|
||||
ParamCommsDebugInfo() = default;
|
||||
@ -99,7 +97,7 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
|
||||
outSplitSizes, \
|
||||
worldSize}; \
|
||||
c10::ArrayRef<const c10::IValue> paramInputs(paramList); \
|
||||
RECORD_FUNCTION(torch::kParamCommsCallName, paramInputs);
|
||||
RECORD_FUNCTION(at::kParamCommsCallName, paramInputs);
|
||||
|
||||
#define RECORD_PARAM_COMMS_DATA( \
|
||||
seq, \
|
||||
@ -135,7 +133,7 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
|
||||
worldSize}; \
|
||||
c10::ArrayRef<const c10::IValue> paramInputs(paramList); \
|
||||
RECORD_FUNCTION_WITH_INPUTS_OUTPUTS( \
|
||||
torch::kParamCommsCallName, \
|
||||
at::kParamCommsCallName, \
|
||||
paramInputs, \
|
||||
std::vector<c10::IValue>(1, c10::IValue(OutputTensors)));
|
||||
} // namespace torch
|
||||
|
Reference in New Issue
Block a user