mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[Kineto][NCCL][3/n] Get the NCCL communication info from PARAM_COMMS_INFO (#111846)
This diff enables the functionality to get the NCCL communication metadata from `c10::DebugInfoKind::PARAM_COMMS_INFO` available in `ThreadLocalDebugInfo`. To make the overhead lightweight and avoid comparing the function name on each op, we add the method `bool isNcclMeta()`, which is decided during initialization. Differential Revision: [D50439211](https://our.internmc.facebook.com/intern/diff/D50439211/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/111846 Approved by: https://github.com/aaronenyeshi ghstack dependencies: #111842, #111843
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							1623cc5815
						
					
				
				
					commit
					ed15fa7cc2
				
			| @ -10,6 +10,8 @@ | ||||
|  | ||||
| namespace at { | ||||
|  | ||||
| extern const std::string kParamCommsCallName = "record_param_comms"; | ||||
|  | ||||
| namespace { | ||||
|  | ||||
| // Used to generate unique callback handles | ||||
| @ -712,6 +714,7 @@ uint64_t RecordFunction::currentThreadId() { | ||||
| void RecordFunction::before(const char* name, int64_t sequence_nr) { | ||||
|   fn_ = name; | ||||
|   sequence_nr_ = sequence_nr; | ||||
|   is_nccl_meta_ = (std::strcmp(name, kParamCommsCallName.c_str()) == 0); | ||||
|  | ||||
| #ifndef NDEBUG | ||||
|     inputs_valid_ = true; | ||||
| @ -721,6 +724,7 @@ void RecordFunction::before(const char* name, int64_t sequence_nr) { | ||||
| } | ||||
|  | ||||
| void RecordFunction::before(std::string name, int64_t sequence_nr) { | ||||
|   is_nccl_meta_ = (name == kParamCommsCallName); | ||||
|   fn_ = std::move(name); | ||||
|   sequence_nr_ = sequence_nr; | ||||
|  | ||||
| @ -736,6 +740,7 @@ void RecordFunction::before( | ||||
|     int64_t sequence_nr) { | ||||
|   sequence_nr_ = sequence_nr; | ||||
|   fn_ = schema; | ||||
|   is_nccl_meta_ = (schema.get().name() == kParamCommsCallName); | ||||
|  | ||||
| #ifndef NDEBUG | ||||
|     inputs_valid_ = true; | ||||
|  | ||||
| @ -18,6 +18,9 @@ class TORCH_API OperatorHandle; | ||||
|  | ||||
| namespace at { | ||||
|  | ||||
| // Function name to record NCCL metadata | ||||
| extern TORCH_API const std::string kParamCommsCallName; | ||||
|  | ||||
| // Kind of record function scope; | ||||
| enum class C10_API_ENUM RecordScope : uint8_t { | ||||
|   // c10/ATen ops, autograd nodes | ||||
| @ -392,9 +395,15 @@ struct TORCH_API RecordFunction { | ||||
|   // profiling. | ||||
|   void _setAsync(); | ||||
|  | ||||
|   // Returns whether this RecordFunction corresponds to an async event orn ot. | ||||
|   // Returns whether this RecordFunction corresponds to an async event or not. | ||||
|   bool isAsync() const; | ||||
|  | ||||
|   // True when this RecordFunction collects NCCL metadata; the flag is set in | ||||
|   // before() by comparing the recorded name with at::kParamCommsCallName. | ||||
|   bool isNcclMeta() const { | ||||
|     return is_nccl_meta_; | ||||
|   } | ||||
|  | ||||
|   // Internal-only, used to denote out variant used for Static Runtime execution | ||||
|   void _setStaticRuntimeOutVariant(); | ||||
|   bool isStaticRuntimeOutVariant() const; | ||||
| @ -483,6 +492,9 @@ struct TORCH_API RecordFunction { | ||||
|   // Whether this RecordFunction is used for an out variant run with | ||||
|   // Static Runtime | ||||
|   bool is_static_runtime_out_variant_{false}; | ||||
|  | ||||
|   // Whether this RecordFunction is used for NCCL metadata collection | ||||
|   bool is_nccl_meta_{false}; | ||||
| }; | ||||
|  | ||||
| TORCH_API StepCallbacks getStepCallbacks(RecordScope scope); | ||||
|  | ||||
| @ -244,6 +244,11 @@ struct AddGenericMetadata : public MetadataBase { | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     // Add extra metadata if any | ||||
|     for (const auto& [key, val] : op_event.extra_meta_) { | ||||
|       addMetadata(key, val); | ||||
|     } | ||||
|  | ||||
|     if (config_ && !config_->experimental_config.performance_events.empty()) { | ||||
|       auto& event_names = config_->experimental_config.performance_events; | ||||
|       for (const auto i : c10::irange(op_event.perf_event_counters_->size())) { | ||||
| @ -873,6 +878,7 @@ TYPED_ATTR(TorchOp, fwdThreadId, e.sequence_number_ >= 0 ? e.forward_tid_ : 0) | ||||
| TYPED_ATTR(TorchOp, scope, static_cast<uint8_t>(e.scope_)) | ||||
| TYPED_ATTR(TorchOp, hasModuleHierarchy, !e.jit_modules_.empty()) | ||||
| TYPED_ATTR(TorchOp, isAsync, e.is_async_) | ||||
| TYPED_ATTR(TorchOp, extraMeta, e.extra_meta_) | ||||
| TYPED_ATTR(TorchOp, fallbackStart, e.device_fallback_.device_event_start_) | ||||
| TYPED_ATTR(TorchOp, fallbackEnd, e.device_fallback_.device_event_end_) | ||||
| TYPED_ATTR( | ||||
|  | ||||
| @ -20,6 +20,7 @@ struct ActivityTraceWrapper; | ||||
| namespace autograd { | ||||
| namespace profiler { | ||||
| using experimental_event_t = std::shared_ptr<torch::profiler::impl::Result>; | ||||
| using extra_meta_t = std::unordered_map<std::string, std::string>; | ||||
|  | ||||
| struct TORCH_API KinetoEvent { | ||||
|   KinetoEvent( | ||||
| @ -59,6 +60,7 @@ struct TORCH_API KinetoEvent { | ||||
|   int64_t cudaElapsedUs() const; | ||||
|   int64_t privateuse1ElapsedUs() const; | ||||
|   void getPerfEventCounters(torch::profiler::perf_counters_t&) const; | ||||
|   extra_meta_t extraMeta() const; | ||||
|  | ||||
|  private: | ||||
|   torch::profiler::impl::ProfilerVoidEventStub fallbackStart() const; | ||||
|  | ||||
| @ -7,8 +7,6 @@ | ||||
|  | ||||
| namespace torch { | ||||
|  | ||||
| extern const std::string kParamCommsCallName = "record_param_comms"; | ||||
|  | ||||
| ParamCommsDebugInfo::ParamCommsDebugInfo( | ||||
|     int rank, | ||||
|     std::string&& colName, | ||||
|  | ||||
| @ -9,8 +9,6 @@ | ||||
|  | ||||
| namespace torch { | ||||
|  | ||||
| extern TORCH_API const std::string kParamCommsCallName; | ||||
|  | ||||
| class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase { | ||||
|  public: | ||||
|   ParamCommsDebugInfo() = default; | ||||
| @ -99,7 +97,7 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase { | ||||
|       outSplitSizes,                                                           \ | ||||
|       worldSize};                                                              \ | ||||
|   c10::ArrayRef<const c10::IValue> paramInputs(paramList);                     \ | ||||
|   RECORD_FUNCTION(torch::kParamCommsCallName, paramInputs); | ||||
|   RECORD_FUNCTION(at::kParamCommsCallName, paramInputs); | ||||
|  | ||||
| #define RECORD_PARAM_COMMS_DATA(                                               \ | ||||
|     seq,                                                                       \ | ||||
| @ -135,7 +133,7 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase { | ||||
|       worldSize};                                                              \ | ||||
|   c10::ArrayRef<const c10::IValue> paramInputs(paramList);                     \ | ||||
|   RECORD_FUNCTION_WITH_INPUTS_OUTPUTS(                                         \ | ||||
|       torch::kParamCommsCallName,                                              \ | ||||
|       at::kParamCommsCallName,                                                 \ | ||||
|       paramInputs,                                                             \ | ||||
|       std::vector<c10::IValue>(1, c10::IValue(OutputTensors))); | ||||
| } // namespace torch | ||||
|  | ||||
| @ -341,6 +341,11 @@ std::unique_ptr<KinetoObserverContext> ThreadLocalSubqueue::begin_op( | ||||
|         torch::profiler::impl::saveExtraArgs(fn)); | ||||
|   } | ||||
|  | ||||
|   // Record NCCL metadata for specific CPU ops | ||||
|   fn.isNcclMeta() ? torch_ops_.extra_meta_.emplace_back( | ||||
|                         torch::profiler::impl::saveNcclMeta(fn)) | ||||
|                   : torch_ops_.extra_meta_.emplace_back(); | ||||
|  | ||||
|   auto out = std::make_unique<KinetoObserverContext>(event); | ||||
|  | ||||
|   if (config_.state == ProfilerState::KINETO_GPU_FALLBACK) { | ||||
| @ -434,6 +439,7 @@ void ThreadLocalSubqueue::TorchOpStorage::materialize( | ||||
|   auto jit_stack = StealOrDefault<decltype(jit_stack_)>(jit_stack_); | ||||
|   auto jit_module = StealOrDefault<decltype(jit_modules_)>(jit_modules_); | ||||
|   auto extra_args = StealOrDefault<decltype(extra_args_)>(extra_args_); | ||||
|   auto extra_meta = StealOrDefault<decltype(extra_meta_)>(extra_meta_); | ||||
|   auto gpu_fallback = | ||||
|       StealOrDefault<decltype(device_fallback_)>(device_fallback_); | ||||
|  | ||||
| @ -447,6 +453,7 @@ void ThreadLocalSubqueue::TorchOpStorage::materialize( | ||||
|         jit_stack(), | ||||
|         jit_module(), | ||||
|         extra_args(), | ||||
|         extra_meta(), | ||||
|         gpu_fallback(), | ||||
|         event->allow_tf32_cublas_, | ||||
|         std::move(event->counters_)}; | ||||
|  | ||||
| @ -114,6 +114,7 @@ struct TorchOpBasicFields { | ||||
| using jit_stack_t = std::vector<std::string>; | ||||
| using jit_modules_t = std::vector<std::string>; | ||||
| using extra_args_t = std::unordered_map<std::string, c10::IValue>; | ||||
| using extra_meta_t = std::unordered_map<std::string, std::string>; | ||||
|  | ||||
| struct FallbackPair { | ||||
|   ProfilerVoidEventStub device_event_start_ = nullptr; | ||||
| @ -131,6 +132,7 @@ struct ExtraFields<EventType::TorchOp> : TorchOpBasicFields { | ||||
|       jit_stack_t&& jit_stack, | ||||
|       jit_modules_t&& jit_modules, | ||||
|       extra_args_t&& extra_args, | ||||
|       extra_meta_t&& extra_meta, | ||||
|       FallbackPair&& device_fallback, | ||||
|       bool allow_tf32_cublas, | ||||
|       std::unique_ptr<perf_counters_t>&& perf_event_counters) | ||||
| @ -142,6 +144,7 @@ struct ExtraFields<EventType::TorchOp> : TorchOpBasicFields { | ||||
|         jit_stack_{std::move(jit_stack)}, | ||||
|         jit_modules_{std::move(jit_modules)}, | ||||
|         extra_args_{std::move(extra_args)}, | ||||
|         extra_meta_{std::move(extra_meta)}, | ||||
|         device_fallback_{std::move(device_fallback)}, | ||||
|         allow_tf32_cublas_{allow_tf32_cublas}, | ||||
|         perf_event_counters_{std::move(perf_event_counters)} {} | ||||
| @ -152,6 +155,7 @@ struct ExtraFields<EventType::TorchOp> : TorchOpBasicFields { | ||||
|   jit_stack_t jit_stack_; | ||||
|   jit_modules_t jit_modules_; | ||||
|   extra_args_t extra_args_; | ||||
|   extra_meta_t extra_meta_; | ||||
|   FallbackPair device_fallback_; | ||||
|   bool allow_tf32_cublas_; | ||||
|   std::unique_ptr<perf_counters_t> perf_event_counters_; | ||||
| @ -579,6 +583,9 @@ class TORCH_API ThreadLocalSubqueue { | ||||
|     // with_flops | ||||
|     AppendOnlyList<extra_args_t, BlockSize> extra_args_; | ||||
|  | ||||
|     // report extra metadata, i.e. collective communication meta | ||||
|     AppendOnlyList<extra_meta_t, BlockSize> extra_meta_; | ||||
|  | ||||
|     // ProfilerState::KINETO_GPU_FALLBACK or | ||||
|     // ProfilerState::KINETO_PRIVATEUSE1_FALLBACK | ||||
|     AppendOnlyList<FallbackPair, BlockSize> device_fallback_; | ||||
|  | ||||
| @ -9,6 +9,11 @@ | ||||
| #ifdef USE_KINETO | ||||
| #include <libkineto.h> | ||||
| #endif | ||||
| #ifdef USE_DISTRIBUTED | ||||
| #ifdef USE_C10D | ||||
| #include <torch/csrc/distributed/c10d/ParamCommsUtils.hpp> | ||||
| #endif // USE_C10D | ||||
| #endif // USE_DISTRIBUTED | ||||
|  | ||||
| namespace torch { | ||||
| namespace profiler { | ||||
| @ -398,6 +403,51 @@ std::vector<std::string> inputTypes(const at::RecordFunction& fn) { | ||||
|   return types; | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
| // -- NCCL Metadata ----------------------------------------------------------- | ||||
| // ---------------------------------------------------------------------------- | ||||
| #if defined(USE_DISTRIBUTED) && defined(USE_C10D) | ||||
| // Keys of the metadata entries produced by saveNcclMeta(). | ||||
| static constexpr auto kCommuName = "Collective name"; | ||||
| static constexpr auto kDtype = "dtype"; | ||||
| static constexpr auto kInMsgSize = "In msg size"; | ||||
| static constexpr auto kOutMsgSize = "Out msg size"; | ||||
| static constexpr auto kInSplit = "In split size"; | ||||
| static constexpr auto kOutSplit = "Out split size"; | ||||
| static constexpr auto kGroupSize = "Group size"; | ||||
| #endif // USE_DISTRIBUTED && USE_C10D | ||||
|  | ||||
| // Builds the NCCL/collective metadata map for the op described by `fn`. | ||||
| // | ||||
| // The data is read from the ParamCommsDebugInfo stored in the thread-local | ||||
| // debug info under c10::DebugInfoKind::PARAM_COMMS_INFO. | ||||
| // | ||||
| // Returns an empty map when the build has neither USE_DISTRIBUTED nor | ||||
| // USE_C10D defined together, or when no ParamCommsDebugInfo is registered | ||||
| // for the current thread (a warning naming `fn` is logged in that case). | ||||
| // | ||||
| // `fn` is only used for the warning message, hence [[maybe_unused]] so that | ||||
| // non-distributed builds do not see an unused parameter. | ||||
| std::unordered_map<std::string, std::string> saveNcclMeta( | ||||
|     [[maybe_unused]] const at::RecordFunction& fn) { | ||||
|   std::unordered_map<std::string, std::string> meta; | ||||
| #if defined(USE_DISTRIBUTED) && defined(USE_C10D) | ||||
|   auto* debugInfo = dynamic_cast<ParamCommsDebugInfo*>( | ||||
|       c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::PARAM_COMMS_INFO)); | ||||
|   if (debugInfo == nullptr) { | ||||
|     LOG(WARNING) << "ParamCommsDebugInfo not available for function: " | ||||
|                  << fn.name(); | ||||
|     return meta; | ||||
|   } | ||||
|  | ||||
|   // Textual values (collective name, dtype) are wrapped in literal double | ||||
|   // quotes; presumably so they can be emitted verbatim into the trace JSON | ||||
|   // -- TODO confirm against the consumer of extra_meta_. | ||||
|   meta.emplace(kCommuName, fmt::format("\"{}\"", debugInfo->getColumnName())); | ||||
|   meta.emplace( | ||||
|       kDtype, fmt::format("\"{}\"", c10::toString(debugInfo->getDType()))); | ||||
|  | ||||
|   // Scalar sizes are stored as plain decimal strings. | ||||
|   meta.emplace(kInMsgSize, std::to_string(debugInfo->getInMessageSize())); | ||||
|   meta.emplace(kOutMsgSize, std::to_string(debugInfo->getOutMessageSize())); | ||||
|  | ||||
|   // Split sizes are rendered as a bracketed, comma-separated list. | ||||
|   meta.emplace( | ||||
|       kInSplit, | ||||
|       fmt::format("[{}]", fmt::join(debugInfo->getInputSplitSizes(), ", "))); | ||||
|   meta.emplace( | ||||
|       kOutSplit, | ||||
|       fmt::format("[{}]", fmt::join(debugInfo->getOutputSplitSizes(), ", "))); | ||||
|  | ||||
|   meta.emplace(kGroupSize, std::to_string(debugInfo->getWorldSize())); | ||||
| #endif // USE_DISTRIBUTED && USE_C10D | ||||
|   return meta; | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
| // -- FLOPS ------------------------------------------------------------------- | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| @ -201,6 +201,8 @@ TORCH_API std::vector<std::string> inputTypes(const at::RecordFunction& fn); | ||||
|  | ||||
| std::unordered_map<std::string, c10::IValue> TORCH_API | ||||
| saveExtraArgs(const at::RecordFunction& fn); | ||||
| std::unordered_map<std::string, std::string> TORCH_API | ||||
| saveNcclMeta(const at::RecordFunction& fn); | ||||
|  | ||||
| uint64_t TORCH_API computeFlops( | ||||
|     const std::string& op_name, | ||||
|  | ||||
		Reference in New Issue
	
	Block a user