[Kineto][NCCL][3/n] Get the NCCL communication info from PARAM_COMMS_INFO (#111846)

This diff adds the ability to retrieve NCCL communication metadata from `c10::DebugInfoKind::PARAM_COMMS_INFO`, which is available in `ThreadLocalDebugInfo`.

To keep the overhead lightweight and avoid comparing the function name on each op, we add the method `bool isNcclMeta()`, whose value is decided once during initialization.
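
For illustration only, a minimal sketch of how a profiler-side callback could consume the new flag and helper. The callback `onOpEnter` is hypothetical and not part of this diff; `isNcclMeta()` and `torch::profiler::impl::saveNcclMeta()` are the additions introduced here:

    #include <ATen/record_function.h>
    #include <c10/util/Logging.h>
    #include <torch/csrc/profiler/util.h>

    // Hypothetical observer hook, for illustration only.
    void onOpEnter(const at::RecordFunction& fn) {
      // is_nccl_meta_ is set once in RecordFunction::before(), so this is a
      // cheap bool read instead of a per-op string compare.
      if (fn.isNcclMeta()) {
        // Reads ParamCommsDebugInfo from ThreadLocalDebugInfo
        // (c10::DebugInfoKind::PARAM_COMMS_INFO) and returns string key/value
        // pairs such as "Collective name", "In msg size", "Group size".
        const auto meta = torch::profiler::impl::saveNcclMeta(fn);
        for (const auto& kv : meta) {
          LOG(INFO) << kv.first << " = " << kv.second;
        }
      }
    }

In the actual diff, this consumption happens in `ThreadLocalSubqueue::begin_op`, which stores the map in `extra_meta_` so it is later emitted by `AddGenericMetadata` into the Kineto trace.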

Differential Revision: [D50439211](https://our.internmc.facebook.com/intern/diff/D50439211/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111846
Approved by: https://github.com/aaronenyeshi
ghstack dependencies: #111842, #111843
Author: Yue Dong
Date: 2023-10-24 19:29:18 -07:00
Committed by: PyTorch MergeBot
Parent: 1623cc5815
Commit: ed15fa7cc2

10 changed files with 94 additions and 7 deletions

View File

@ -10,6 +10,8 @@
namespace at {
extern const std::string kParamCommsCallName = "record_param_comms";
namespace {
// Used to generate unique callback handles
@ -712,6 +714,7 @@ uint64_t RecordFunction::currentThreadId() {
void RecordFunction::before(const char* name, int64_t sequence_nr) {
fn_ = name;
sequence_nr_ = sequence_nr;
is_nccl_meta_ = (std::strcmp(name, kParamCommsCallName.c_str()) == 0);
#ifndef NDEBUG
inputs_valid_ = true;
@ -721,6 +724,7 @@ void RecordFunction::before(const char* name, int64_t sequence_nr) {
}
void RecordFunction::before(std::string name, int64_t sequence_nr) {
is_nccl_meta_ = (name == kParamCommsCallName);
fn_ = std::move(name);
sequence_nr_ = sequence_nr;
@ -736,6 +740,7 @@ void RecordFunction::before(
int64_t sequence_nr) {
sequence_nr_ = sequence_nr;
fn_ = schema;
is_nccl_meta_ = (schema.get().name() == kParamCommsCallName);
#ifndef NDEBUG
inputs_valid_ = true;

View File

@ -18,6 +18,9 @@ class TORCH_API OperatorHandle;
namespace at {
// Function name to record NCCL metadata
extern TORCH_API const std::string kParamCommsCallName;
// Kind of record function scope;
enum class C10_API_ENUM RecordScope : uint8_t {
// c10/ATen ops, autograd nodes
@ -392,9 +395,15 @@ struct TORCH_API RecordFunction {
// profiling.
void _setAsync();
// Returns whether this RecordFunction corresponds to an async event orn ot.
// Returns whether this RecordFunction corresponds to an async event or not.
bool isAsync() const;
// Returns whether this RecordFunction corresponds to NCCL metadata collection
// or not.
bool isNcclMeta() const {
return is_nccl_meta_;
}
// Internal-only, used to denote out variant used for Static Runtime execution
void _setStaticRuntimeOutVariant();
bool isStaticRuntimeOutVariant() const;
@ -483,6 +492,9 @@ struct TORCH_API RecordFunction {
// Whether this RecordFunction is used for an out variant run with
// Static Runtime
bool is_static_runtime_out_variant_{false};
// Whether this RecordFunction is used for NCCL metadata collection
bool is_nccl_meta_{false};
};
TORCH_API StepCallbacks getStepCallbacks(RecordScope scope);

View File

@ -244,6 +244,11 @@ struct AddGenericMetadata : public MetadataBase {
}
}
// Add extra metadata if any
for (const auto& [key, val] : op_event.extra_meta_) {
addMetadata(key, val);
}
if (config_ && !config_->experimental_config.performance_events.empty()) {
auto& event_names = config_->experimental_config.performance_events;
for (const auto i : c10::irange(op_event.perf_event_counters_->size())) {
@ -873,6 +878,7 @@ TYPED_ATTR(TorchOp, fwdThreadId, e.sequence_number_ >= 0 ? e.forward_tid_ : 0)
TYPED_ATTR(TorchOp, scope, static_cast<uint8_t>(e.scope_))
TYPED_ATTR(TorchOp, hasModuleHierarchy, !e.jit_modules_.empty())
TYPED_ATTR(TorchOp, isAsync, e.is_async_)
TYPED_ATTR(TorchOp, extraMeta, e.extra_meta_)
TYPED_ATTR(TorchOp, fallbackStart, e.device_fallback_.device_event_start_)
TYPED_ATTR(TorchOp, fallbackEnd, e.device_fallback_.device_event_end_)
TYPED_ATTR(

View File

@ -20,6 +20,7 @@ struct ActivityTraceWrapper;
namespace autograd {
namespace profiler {
using experimental_event_t = std::shared_ptr<torch::profiler::impl::Result>;
using extra_meta_t = std::unordered_map<std::string, std::string>;
struct TORCH_API KinetoEvent {
KinetoEvent(
@ -59,6 +60,7 @@ struct TORCH_API KinetoEvent {
int64_t cudaElapsedUs() const;
int64_t privateuse1ElapsedUs() const;
void getPerfEventCounters(torch::profiler::perf_counters_t&) const;
extra_meta_t extraMeta() const;
private:
torch::profiler::impl::ProfilerVoidEventStub fallbackStart() const;

View File

@ -7,8 +7,6 @@
namespace torch {
extern const std::string kParamCommsCallName = "record_param_comms";
ParamCommsDebugInfo::ParamCommsDebugInfo(
int rank,
std::string&& colName,

View File

@ -9,8 +9,6 @@
namespace torch {
extern TORCH_API const std::string kParamCommsCallName;
class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
public:
ParamCommsDebugInfo() = default;
@ -99,7 +97,7 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
outSplitSizes, \
worldSize}; \
c10::ArrayRef<const c10::IValue> paramInputs(paramList); \
RECORD_FUNCTION(torch::kParamCommsCallName, paramInputs);
RECORD_FUNCTION(at::kParamCommsCallName, paramInputs);
#define RECORD_PARAM_COMMS_DATA( \
seq, \
@ -135,7 +133,7 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
worldSize}; \
c10::ArrayRef<const c10::IValue> paramInputs(paramList); \
RECORD_FUNCTION_WITH_INPUTS_OUTPUTS( \
torch::kParamCommsCallName, \
at::kParamCommsCallName, \
paramInputs, \
std::vector<c10::IValue>(1, c10::IValue(OutputTensors)));
} // namespace torch

View File

@ -341,6 +341,11 @@ std::unique_ptr<KinetoObserverContext> ThreadLocalSubqueue::begin_op(
torch::profiler::impl::saveExtraArgs(fn));
}
// Record NCCL metadata for specific CPU ops
fn.isNcclMeta() ? torch_ops_.extra_meta_.emplace_back(
torch::profiler::impl::saveNcclMeta(fn))
: torch_ops_.extra_meta_.emplace_back();
auto out = std::make_unique<KinetoObserverContext>(event);
if (config_.state == ProfilerState::KINETO_GPU_FALLBACK) {
@ -434,6 +439,7 @@ void ThreadLocalSubqueue::TorchOpStorage::materialize(
auto jit_stack = StealOrDefault<decltype(jit_stack_)>(jit_stack_);
auto jit_module = StealOrDefault<decltype(jit_modules_)>(jit_modules_);
auto extra_args = StealOrDefault<decltype(extra_args_)>(extra_args_);
auto extra_meta = StealOrDefault<decltype(extra_meta_)>(extra_meta_);
auto gpu_fallback =
StealOrDefault<decltype(device_fallback_)>(device_fallback_);
@ -447,6 +453,7 @@ void ThreadLocalSubqueue::TorchOpStorage::materialize(
jit_stack(),
jit_module(),
extra_args(),
extra_meta(),
gpu_fallback(),
event->allow_tf32_cublas_,
std::move(event->counters_)};

View File

@ -114,6 +114,7 @@ struct TorchOpBasicFields {
using jit_stack_t = std::vector<std::string>;
using jit_modules_t = std::vector<std::string>;
using extra_args_t = std::unordered_map<std::string, c10::IValue>;
using extra_meta_t = std::unordered_map<std::string, std::string>;
struct FallbackPair {
ProfilerVoidEventStub device_event_start_ = nullptr;
@ -131,6 +132,7 @@ struct ExtraFields<EventType::TorchOp> : TorchOpBasicFields {
jit_stack_t&& jit_stack,
jit_modules_t&& jit_modules,
extra_args_t&& extra_args,
extra_meta_t&& extra_meta,
FallbackPair&& device_fallback,
bool allow_tf32_cublas,
std::unique_ptr<perf_counters_t>&& perf_event_counters)
@ -142,6 +144,7 @@ struct ExtraFields<EventType::TorchOp> : TorchOpBasicFields {
jit_stack_{std::move(jit_stack)},
jit_modules_{std::move(jit_modules)},
extra_args_{std::move(extra_args)},
extra_meta_{std::move(extra_meta)},
device_fallback_{std::move(device_fallback)},
allow_tf32_cublas_{allow_tf32_cublas},
perf_event_counters_{std::move(perf_event_counters)} {}
@ -152,6 +155,7 @@ struct ExtraFields<EventType::TorchOp> : TorchOpBasicFields {
jit_stack_t jit_stack_;
jit_modules_t jit_modules_;
extra_args_t extra_args_;
extra_meta_t extra_meta_;
FallbackPair device_fallback_;
bool allow_tf32_cublas_;
std::unique_ptr<perf_counters_t> perf_event_counters_;
@ -579,6 +583,9 @@ class TORCH_API ThreadLocalSubqueue {
// with_flops
AppendOnlyList<extra_args_t, BlockSize> extra_args_;
// report extra metadata, i.e. collective communication meta
AppendOnlyList<extra_meta_t, BlockSize> extra_meta_;
// ProfilerState::KINETO_GPU_FALLBACK or
// ProfilerState::KINETO_PRIVATEUSE1_FALLBACK
AppendOnlyList<FallbackPair, BlockSize> device_fallback_;

View File

@ -9,6 +9,11 @@
#ifdef USE_KINETO
#include <libkineto.h>
#endif
#ifdef USE_DISTRIBUTED
#ifdef USE_C10D
#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
#endif // USE_C10D
#endif // USE_DISTRIBUTED
namespace torch {
namespace profiler {
@ -398,6 +403,51 @@ std::vector<std::string> inputTypes(const at::RecordFunction& fn) {
return types;
}
// ----------------------------------------------------------------------------
// -- NCCL Metadata -----------------------------------------------------------
// ----------------------------------------------------------------------------
#ifdef USE_DISTRIBUTED
#ifdef USE_C10D
static constexpr auto kCommuName = "Collective name";
static constexpr auto kDtype = "dtype";
static constexpr auto kInMsgSize = "In msg size";
static constexpr auto kOutMsgSize = "Out msg size";
static constexpr auto kInSplit = "In split size";
static constexpr auto kOutSplit = "Out split size";
static constexpr auto kGroupSize = "Group size";
#endif // USE_C10D
#endif // USE_DISTRIBUTED
std::unordered_map<std::string, std::string> saveNcclMeta(
const at::RecordFunction& fn) {
std::unordered_map<std::string, std::string> map;
#ifdef USE_DISTRIBUTED
#ifdef USE_C10D
auto debugInfo = dynamic_cast<ParamCommsDebugInfo*>(
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PARAM_COMMS_INFO));
if (debugInfo == nullptr) {
LOG(WARNING) << "ParamCommsDebugInfo not available for function: "
<< fn.name();
return map;
}
map.emplace(kCommuName, fmt::format("\"{}\"", debugInfo->getColumnName()));
map.emplace(
kDtype, fmt::format("\"{}\"", c10::toString(debugInfo->getDType())));
map.emplace(kInMsgSize, std::to_string(debugInfo->getInMessageSize()));
map.emplace(kOutMsgSize, std::to_string(debugInfo->getOutMessageSize()));
map.emplace(
kInSplit,
fmt::format("[{}]", fmt::join(debugInfo->getInputSplitSizes(), ", ")));
map.emplace(
kOutSplit,
fmt::format("[{}]", fmt::join(debugInfo->getOutputSplitSizes(), ", ")));
map.emplace(kGroupSize, std::to_string(debugInfo->getWorldSize()));
#endif // USE_C10D
#endif // USE_DISTRIBUTED
return map;
}
// ----------------------------------------------------------------------------
// -- FLOPS -------------------------------------------------------------------
// ----------------------------------------------------------------------------

View File

@ -201,6 +201,8 @@ TORCH_API std::vector<std::string> inputTypes(const at::RecordFunction& fn);
std::unordered_map<std::string, c10::IValue> TORCH_API
saveExtraArgs(const at::RecordFunction& fn);
std::unordered_map<std::string, std::string> TORCH_API
saveNcclMeta(const at::RecordFunction& fn);
uint64_t TORCH_API computeFlops(
const std::string& op_name,