mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Kineto][NCCL][3/n] Get the NCCL communication info from PARAM_COMMS_INFO (#111846)
This diff enables the functionality to get the NCCL communication metadata from `c10::DebugInfoKind::PARAM_COMMS_INFO` available in `ThreadLocalDebugInfo`. To keep the overhead lightweight and avoid comparing the function name on each op, we add the method `bool isNcclMeta()`, whose result is decided during initialization. Differential Revision: [D50439211](https://our.internmc.facebook.com/intern/diff/D50439211/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/111846 Approved by: https://github.com/aaronenyeshi ghstack dependencies: #111842, #111843
This commit is contained in:
committed by
PyTorch MergeBot
parent
1623cc5815
commit
ed15fa7cc2
@ -10,6 +10,8 @@
|
||||
|
||||
namespace at {
|
||||
|
||||
extern const std::string kParamCommsCallName = "record_param_comms";
|
||||
|
||||
namespace {
|
||||
|
||||
// Used to generate unique callback handles
|
||||
@ -712,6 +714,7 @@ uint64_t RecordFunction::currentThreadId() {
|
||||
void RecordFunction::before(const char* name, int64_t sequence_nr) {
|
||||
fn_ = name;
|
||||
sequence_nr_ = sequence_nr;
|
||||
is_nccl_meta_ = (std::strcmp(name, kParamCommsCallName.c_str()) == 0);
|
||||
|
||||
#ifndef NDEBUG
|
||||
inputs_valid_ = true;
|
||||
@ -721,6 +724,7 @@ void RecordFunction::before(const char* name, int64_t sequence_nr) {
|
||||
}
|
||||
|
||||
void RecordFunction::before(std::string name, int64_t sequence_nr) {
|
||||
is_nccl_meta_ = (name == kParamCommsCallName);
|
||||
fn_ = std::move(name);
|
||||
sequence_nr_ = sequence_nr;
|
||||
|
||||
@ -736,6 +740,7 @@ void RecordFunction::before(
|
||||
int64_t sequence_nr) {
|
||||
sequence_nr_ = sequence_nr;
|
||||
fn_ = schema;
|
||||
is_nccl_meta_ = (schema.get().name() == kParamCommsCallName);
|
||||
|
||||
#ifndef NDEBUG
|
||||
inputs_valid_ = true;
|
||||
|
||||
@ -18,6 +18,9 @@ class TORCH_API OperatorHandle;
|
||||
|
||||
namespace at {
|
||||
|
||||
// Function name to record NCCL metadata
|
||||
extern TORCH_API const std::string kParamCommsCallName;
|
||||
|
||||
// Kind of record function scope;
|
||||
enum class C10_API_ENUM RecordScope : uint8_t {
|
||||
// c10/ATen ops, autograd nodes
|
||||
@ -392,9 +395,15 @@ struct TORCH_API RecordFunction {
|
||||
// profiling.
|
||||
void _setAsync();
|
||||
|
||||
// Returns whether this RecordFunction corresponds to an async event orn ot.
|
||||
// Returns whether this RecordFunction corresponds to an async event or not.
|
||||
bool isAsync() const;
|
||||
|
||||
// Returns whether this RecordFunction corresponds to NCCL metadata collection
|
||||
// or not.
|
||||
bool isNcclMeta() const {
|
||||
return is_nccl_meta_;
|
||||
}
|
||||
|
||||
// Internal-only, used to denote out variant used for Static Runtime execution
|
||||
void _setStaticRuntimeOutVariant();
|
||||
bool isStaticRuntimeOutVariant() const;
|
||||
@ -483,6 +492,9 @@ struct TORCH_API RecordFunction {
|
||||
// Whether this RecordFunction is used for an out variant run with
|
||||
// Static Runtime
|
||||
bool is_static_runtime_out_variant_{false};
|
||||
|
||||
// Whether this RecordFunction is used for NCCL metadata collection
|
||||
bool is_nccl_meta_{false};
|
||||
};
|
||||
|
||||
TORCH_API StepCallbacks getStepCallbacks(RecordScope scope);
|
||||
|
||||
@ -244,6 +244,11 @@ struct AddGenericMetadata : public MetadataBase {
|
||||
}
|
||||
}
|
||||
|
||||
// Add extra metadata if any
|
||||
for (const auto& [key, val] : op_event.extra_meta_) {
|
||||
addMetadata(key, val);
|
||||
}
|
||||
|
||||
if (config_ && !config_->experimental_config.performance_events.empty()) {
|
||||
auto& event_names = config_->experimental_config.performance_events;
|
||||
for (const auto i : c10::irange(op_event.perf_event_counters_->size())) {
|
||||
@ -873,6 +878,7 @@ TYPED_ATTR(TorchOp, fwdThreadId, e.sequence_number_ >= 0 ? e.forward_tid_ : 0)
|
||||
TYPED_ATTR(TorchOp, scope, static_cast<uint8_t>(e.scope_))
|
||||
TYPED_ATTR(TorchOp, hasModuleHierarchy, !e.jit_modules_.empty())
|
||||
TYPED_ATTR(TorchOp, isAsync, e.is_async_)
|
||||
TYPED_ATTR(TorchOp, extraMeta, e.extra_meta_)
|
||||
TYPED_ATTR(TorchOp, fallbackStart, e.device_fallback_.device_event_start_)
|
||||
TYPED_ATTR(TorchOp, fallbackEnd, e.device_fallback_.device_event_end_)
|
||||
TYPED_ATTR(
|
||||
|
||||
@ -20,6 +20,7 @@ struct ActivityTraceWrapper;
|
||||
namespace autograd {
|
||||
namespace profiler {
|
||||
using experimental_event_t = std::shared_ptr<torch::profiler::impl::Result>;
|
||||
using extra_meta_t = std::unordered_map<std::string, std::string>;
|
||||
|
||||
struct TORCH_API KinetoEvent {
|
||||
KinetoEvent(
|
||||
@ -59,6 +60,7 @@ struct TORCH_API KinetoEvent {
|
||||
int64_t cudaElapsedUs() const;
|
||||
int64_t privateuse1ElapsedUs() const;
|
||||
void getPerfEventCounters(torch::profiler::perf_counters_t&) const;
|
||||
extra_meta_t extraMeta() const;
|
||||
|
||||
private:
|
||||
torch::profiler::impl::ProfilerVoidEventStub fallbackStart() const;
|
||||
|
||||
@ -7,8 +7,6 @@
|
||||
|
||||
namespace torch {
|
||||
|
||||
extern const std::string kParamCommsCallName = "record_param_comms";
|
||||
|
||||
ParamCommsDebugInfo::ParamCommsDebugInfo(
|
||||
int rank,
|
||||
std::string&& colName,
|
||||
|
||||
@ -9,8 +9,6 @@
|
||||
|
||||
namespace torch {
|
||||
|
||||
extern TORCH_API const std::string kParamCommsCallName;
|
||||
|
||||
class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
|
||||
public:
|
||||
ParamCommsDebugInfo() = default;
|
||||
@ -99,7 +97,7 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
|
||||
outSplitSizes, \
|
||||
worldSize}; \
|
||||
c10::ArrayRef<const c10::IValue> paramInputs(paramList); \
|
||||
RECORD_FUNCTION(torch::kParamCommsCallName, paramInputs);
|
||||
RECORD_FUNCTION(at::kParamCommsCallName, paramInputs);
|
||||
|
||||
#define RECORD_PARAM_COMMS_DATA( \
|
||||
seq, \
|
||||
@ -135,7 +133,7 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase {
|
||||
worldSize}; \
|
||||
c10::ArrayRef<const c10::IValue> paramInputs(paramList); \
|
||||
RECORD_FUNCTION_WITH_INPUTS_OUTPUTS( \
|
||||
torch::kParamCommsCallName, \
|
||||
at::kParamCommsCallName, \
|
||||
paramInputs, \
|
||||
std::vector<c10::IValue>(1, c10::IValue(OutputTensors)));
|
||||
} // namespace torch
|
||||
|
||||
@ -341,6 +341,11 @@ std::unique_ptr<KinetoObserverContext> ThreadLocalSubqueue::begin_op(
|
||||
torch::profiler::impl::saveExtraArgs(fn));
|
||||
}
|
||||
|
||||
// Record NCCL metadata for specific CPU ops
|
||||
fn.isNcclMeta() ? torch_ops_.extra_meta_.emplace_back(
|
||||
torch::profiler::impl::saveNcclMeta(fn))
|
||||
: torch_ops_.extra_meta_.emplace_back();
|
||||
|
||||
auto out = std::make_unique<KinetoObserverContext>(event);
|
||||
|
||||
if (config_.state == ProfilerState::KINETO_GPU_FALLBACK) {
|
||||
@ -434,6 +439,7 @@ void ThreadLocalSubqueue::TorchOpStorage::materialize(
|
||||
auto jit_stack = StealOrDefault<decltype(jit_stack_)>(jit_stack_);
|
||||
auto jit_module = StealOrDefault<decltype(jit_modules_)>(jit_modules_);
|
||||
auto extra_args = StealOrDefault<decltype(extra_args_)>(extra_args_);
|
||||
auto extra_meta = StealOrDefault<decltype(extra_meta_)>(extra_meta_);
|
||||
auto gpu_fallback =
|
||||
StealOrDefault<decltype(device_fallback_)>(device_fallback_);
|
||||
|
||||
@ -447,6 +453,7 @@ void ThreadLocalSubqueue::TorchOpStorage::materialize(
|
||||
jit_stack(),
|
||||
jit_module(),
|
||||
extra_args(),
|
||||
extra_meta(),
|
||||
gpu_fallback(),
|
||||
event->allow_tf32_cublas_,
|
||||
std::move(event->counters_)};
|
||||
|
||||
@ -114,6 +114,7 @@ struct TorchOpBasicFields {
|
||||
using jit_stack_t = std::vector<std::string>;
|
||||
using jit_modules_t = std::vector<std::string>;
|
||||
using extra_args_t = std::unordered_map<std::string, c10::IValue>;
|
||||
using extra_meta_t = std::unordered_map<std::string, std::string>;
|
||||
|
||||
struct FallbackPair {
|
||||
ProfilerVoidEventStub device_event_start_ = nullptr;
|
||||
@ -131,6 +132,7 @@ struct ExtraFields<EventType::TorchOp> : TorchOpBasicFields {
|
||||
jit_stack_t&& jit_stack,
|
||||
jit_modules_t&& jit_modules,
|
||||
extra_args_t&& extra_args,
|
||||
extra_meta_t&& extra_meta,
|
||||
FallbackPair&& device_fallback,
|
||||
bool allow_tf32_cublas,
|
||||
std::unique_ptr<perf_counters_t>&& perf_event_counters)
|
||||
@ -142,6 +144,7 @@ struct ExtraFields<EventType::TorchOp> : TorchOpBasicFields {
|
||||
jit_stack_{std::move(jit_stack)},
|
||||
jit_modules_{std::move(jit_modules)},
|
||||
extra_args_{std::move(extra_args)},
|
||||
extra_meta_{std::move(extra_meta)},
|
||||
device_fallback_{std::move(device_fallback)},
|
||||
allow_tf32_cublas_{allow_tf32_cublas},
|
||||
perf_event_counters_{std::move(perf_event_counters)} {}
|
||||
@ -152,6 +155,7 @@ struct ExtraFields<EventType::TorchOp> : TorchOpBasicFields {
|
||||
jit_stack_t jit_stack_;
|
||||
jit_modules_t jit_modules_;
|
||||
extra_args_t extra_args_;
|
||||
extra_meta_t extra_meta_;
|
||||
FallbackPair device_fallback_;
|
||||
bool allow_tf32_cublas_;
|
||||
std::unique_ptr<perf_counters_t> perf_event_counters_;
|
||||
@ -579,6 +583,9 @@ class TORCH_API ThreadLocalSubqueue {
|
||||
// with_flops
|
||||
AppendOnlyList<extra_args_t, BlockSize> extra_args_;
|
||||
|
||||
// report extra metadata, i.e. collective communication meta
|
||||
AppendOnlyList<extra_meta_t, BlockSize> extra_meta_;
|
||||
|
||||
// ProfilerState::KINETO_GPU_FALLBACK or
|
||||
// ProfilerState::KINETO_PRIVATEUSE1_FALLBACK
|
||||
AppendOnlyList<FallbackPair, BlockSize> device_fallback_;
|
||||
|
||||
@ -9,6 +9,11 @@
|
||||
#ifdef USE_KINETO
|
||||
#include <libkineto.h>
|
||||
#endif
|
||||
#ifdef USE_DISTRIBUTED
|
||||
#ifdef USE_C10D
|
||||
#include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp>
|
||||
#endif // USE_C10D
|
||||
#endif // USE_DISTRIBUTED
|
||||
|
||||
namespace torch {
|
||||
namespace profiler {
|
||||
@ -398,6 +403,51 @@ std::vector<std::string> inputTypes(const at::RecordFunction& fn) {
|
||||
return types;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// -- NCCL Metadata -----------------------------------------------------------
|
||||
// ----------------------------------------------------------------------------
|
||||
#ifdef USE_DISTRIBUTED
|
||||
#ifdef USE_C10D
|
||||
// Keys under which saveNcclMeta() stores each NCCL metadata field.
static constexpr const char* kCommuName = "Collective name";
static constexpr const char* kDtype = "dtype";
static constexpr const char* kInMsgSize = "In msg size";
static constexpr const char* kOutMsgSize = "Out msg size";
static constexpr const char* kInSplit = "In split size";
static constexpr const char* kOutSplit = "Out split size";
static constexpr const char* kGroupSize = "Group size";
|
||||
#endif // USE_C10D
|
||||
#endif // USE_DISTRIBUTED
|
||||
|
||||
std::unordered_map<std::string, std::string> saveNcclMeta(
|
||||
const at::RecordFunction& fn) {
|
||||
std::unordered_map<std::string, std::string> map;
|
||||
#ifdef USE_DISTRIBUTED
|
||||
#ifdef USE_C10D
|
||||
auto debugInfo = dynamic_cast<ParamCommsDebugInfo*>(
|
||||
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PARAM_COMMS_INFO));
|
||||
if (debugInfo == nullptr) {
|
||||
LOG(WARNING) << "ParamCommsDebugInfo not available for function: "
|
||||
<< fn.name();
|
||||
return map;
|
||||
}
|
||||
|
||||
map.emplace(kCommuName, fmt::format("\"{}\"", debugInfo->getColumnName()));
|
||||
map.emplace(
|
||||
kDtype, fmt::format("\"{}\"", c10::toString(debugInfo->getDType())));
|
||||
map.emplace(kInMsgSize, std::to_string(debugInfo->getInMessageSize()));
|
||||
map.emplace(kOutMsgSize, std::to_string(debugInfo->getOutMessageSize()));
|
||||
map.emplace(
|
||||
kInSplit,
|
||||
fmt::format("[{}]", fmt::join(debugInfo->getInputSplitSizes(), ", ")));
|
||||
map.emplace(
|
||||
kOutSplit,
|
||||
fmt::format("[{}]", fmt::join(debugInfo->getOutputSplitSizes(), ", ")));
|
||||
map.emplace(kGroupSize, std::to_string(debugInfo->getWorldSize()));
|
||||
#endif // USE_C10D
|
||||
#endif // USE_DISTRIBUTED
|
||||
return map;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// -- FLOPS -------------------------------------------------------------------
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
@ -201,6 +201,8 @@ TORCH_API std::vector<std::string> inputTypes(const at::RecordFunction& fn);
|
||||
|
||||
std::unordered_map<std::string, c10::IValue> TORCH_API
|
||||
saveExtraArgs(const at::RecordFunction& fn);
|
||||
std::unordered_map<std::string, std::string> TORCH_API
|
||||
saveNcclMeta(const at::RecordFunction& fn);
|
||||
|
||||
uint64_t TORCH_API computeFlops(
|
||||
const std::string& op_name,
|
||||
|
||||
Reference in New Issue
Block a user