mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[nccl-pg] Store PG global rank information in tracing logs (#115730)
Storing the list of global ranks associated with each PG allows us to correlate traces across different ranks.

Test Plan: OSS CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115730
Approved by: https://github.com/fduwjj
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							b38e14c12a
						
					
				
				
					commit
					ffc826bf10
				
			| @ -15,6 +15,8 @@ ParamCommsDebugInfo::ParamCommsDebugInfo( | ||||
|     at::ScalarType dType, | ||||
|     std::vector<int64_t> inSplitSizes, | ||||
|     std::vector<int64_t> outSplitSizes, | ||||
|     int globalRankStart, | ||||
|     int globalRankStride, | ||||
|     int worldSize) | ||||
|     : rank_(rank), | ||||
|       worldSize_(worldSize), | ||||
| @ -23,6 +25,8 @@ ParamCommsDebugInfo::ParamCommsDebugInfo( | ||||
|       outMessageNelems_(outNelems), | ||||
|       dType_(dType), | ||||
|       inputSplitSizes_(std::move(inSplitSizes)), | ||||
|       outputSplitSizes_(std::move(outSplitSizes)) {} | ||||
|       outputSplitSizes_(std::move(outSplitSizes)), | ||||
|       globalRankStart_(globalRankStart), | ||||
|       globalRankStride_(globalRankStride) {} | ||||
|  | ||||
| } // namespace torch | ||||
|  | ||||
| @ -20,6 +20,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase { | ||||
|       at::ScalarType dType, | ||||
|       std::vector<int64_t> inSplitSizes, | ||||
|       std::vector<int64_t> outSplitSizes, | ||||
|       int globalRankStart, | ||||
|       int globalRankStride, | ||||
|       int worldSize); | ||||
|  | ||||
|   ~ParamCommsDebugInfo() override = default; | ||||
| @ -32,6 +34,14 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase { | ||||
|     return worldSize_; | ||||
|   } | ||||
|  | ||||
|   int getGlobalRankStart() const { | ||||
|     return globalRankStart_; | ||||
|   } | ||||
|  | ||||
|   int getGlobalRankStride() const { | ||||
|     return globalRankStride_; | ||||
|   } | ||||
|  | ||||
|   const std::string getColumnName() const { | ||||
|     return columnName_; | ||||
|   } | ||||
| @ -65,6 +75,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase { | ||||
|   at::ScalarType dType_ = at::kByte; | ||||
|   std::vector<int64_t> inputSplitSizes_; | ||||
|   std::vector<int64_t> outputSplitSizes_; | ||||
|   int globalRankStart_; | ||||
|   int globalRankStride_; | ||||
| }; | ||||
|  | ||||
| #define RECORD_PARAM_COMMS(                                                    \ | ||||
| @ -77,6 +89,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase { | ||||
|     dType,                                                                     \ | ||||
|     inSplitSizes,                                                              \ | ||||
|     outSplitSizes,                                                             \ | ||||
|     globalRankStart,                                                           \ | ||||
|     globalRankStride,                                                          \ | ||||
|     worldSize)                                                                 \ | ||||
|   auto paramCommsInfo = std::make_shared<torch::ParamCommsDebugInfo>(          \ | ||||
|       rank,                                                                    \ | ||||
| @ -86,6 +100,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase { | ||||
|       dType,                                                                   \ | ||||
|       inSplitSizes,                                                            \ | ||||
|       outSplitSizes,                                                           \ | ||||
|       globalRankStart,                                                         \ | ||||
|       globalRankStride,                                                        \ | ||||
|       worldSize);                                                              \ | ||||
|   c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \ | ||||
|   std::initializer_list<const c10::IValue> paramList = {                       \ | ||||
| @ -95,6 +111,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase { | ||||
|       colName,                                                                 \ | ||||
|       inSplitSizes,                                                            \ | ||||
|       outSplitSizes,                                                           \ | ||||
|       globalRankStart,                                                         \ | ||||
|       globalRankStride,                                                        \ | ||||
|       worldSize};                                                              \ | ||||
|   c10::ArrayRef<const c10::IValue> paramInputs(paramList);                     \ | ||||
|   RECORD_FUNCTION(at::kParamCommsCallName, paramInputs); | ||||
| @ -111,6 +129,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase { | ||||
|     dType,                                                                     \ | ||||
|     inSplitSizes,                                                              \ | ||||
|     outSplitSizes,                                                             \ | ||||
|     globalRankStart,                                                           \ | ||||
|     globalRankStride,                                                          \ | ||||
|     worldSize)                                                                 \ | ||||
|   auto paramCommsInfo = std::make_shared<torch::ParamCommsDebugInfo>(          \ | ||||
|       rank,                                                                    \ | ||||
| @ -120,6 +140,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase { | ||||
|       dType,                                                                   \ | ||||
|       inSplitSizes,                                                            \ | ||||
|       outSplitSizes,                                                           \ | ||||
|       globalRankStart,                                                         \ | ||||
|       globalRankStride,                                                        \ | ||||
|       worldSize);                                                              \ | ||||
|   c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \ | ||||
|   std::initializer_list<const c10::IValue> paramList = {                       \ | ||||
| @ -130,6 +152,8 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase { | ||||
|       colName,                                                                 \ | ||||
|       inSplitSizes,                                                            \ | ||||
|       outSplitSizes,                                                           \ | ||||
|       globalRankStart,                                                         \ | ||||
|       globalRankStride,                                                        \ | ||||
|       worldSize};                                                              \ | ||||
|   c10::ArrayRef<const c10::IValue> paramInputs(paramList);                     \ | ||||
|   RECORD_FUNCTION_WITH_INPUTS_OUTPUTS(                                         \ | ||||
|  | ||||
| @ -641,6 +641,8 @@ bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) { | ||||
|       at::kByte, // dType | ||||
|       std::vector<int64_t>(), // inSplitSizes | ||||
|       std::vector<int64_t>(), // outSplitSizes | ||||
|       -1, | ||||
|       -1, | ||||
|       static_cast<int>(devices_.size())); // worldSize | ||||
|   synchronizeInternal(timeout); | ||||
|   // Always return true, because abort API is not implemented. | ||||
| @ -806,6 +808,38 @@ ProcessGroupNCCL::ProcessGroupNCCL( | ||||
|             << ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << ncclTraceBufferSize_ | ||||
|             << ", NCCL_DEBUG: " << nccl_debug << ", ID=" << this->getID(); | ||||
|  | ||||
|   if (options_->global_ranks_in_group.empty()) { | ||||
|     this->globalRankStart = 0; | ||||
|   } else { | ||||
|     this->globalRankStart = options_->global_ranks_in_group[0]; | ||||
|   } | ||||
|  | ||||
|   if (options_->global_ranks_in_group.empty()) { | ||||
|     this->globalRankStride = 1; | ||||
|   } else if (options_->global_ranks_in_group.size() == 1) { | ||||
|     this->globalRankStride = 0; | ||||
|   } else { | ||||
|     bool ranksAreStrided = true; | ||||
|     int startRank = options_->global_ranks_in_group[0]; | ||||
|     int stride = | ||||
|         options_->global_ranks_in_group[1] - options_->global_ranks_in_group[0]; | ||||
|     for (std::vector<uint64_t>::size_type i = 0; | ||||
|          i < options_->global_ranks_in_group.size(); | ||||
|          i++) { | ||||
|       if (options_->global_ranks_in_group[i] != startRank + i * stride) { | ||||
|         ranksAreStrided = false; | ||||
|         break; | ||||
|       } | ||||
|     } | ||||
|  | ||||
|     if (ranksAreStrided) { | ||||
|       this->globalRankStride = options_->global_ranks_in_group[1] - | ||||
|           options_->global_ranks_in_group[0]; | ||||
|     } else { | ||||
|       this->globalRankStride = -1; | ||||
|     } | ||||
|   } | ||||
|  | ||||
|   RECORD_PARAM_COMMS( | ||||
|       0, // seq | ||||
|       this->getID(), | ||||
| @ -816,6 +850,8 @@ ProcessGroupNCCL::ProcessGroupNCCL( | ||||
|       at::kByte, // dType | ||||
|       std::vector<int64_t>(), // inSplitSizes | ||||
|       std::vector<int64_t>(), // outSplitSizes | ||||
|       globalRankStart, // globalRankStart | ||||
|       globalRankStride, // globalRankStride | ||||
|       size_); // worldSize | ||||
|  | ||||
|   // Attach hooks to cache allocator to trigger the hooks whenever a traced | ||||
| @ -2774,6 +2810,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce( | ||||
|       tensor.scalar_type(), // dType | ||||
|       std::vector<int64_t>(), // inSplitSizes | ||||
|       std::vector<int64_t>(), // outSplitSizes | ||||
|       globalRankStart, // globalRankStart | ||||
|       globalRankStride, // globalRankStride | ||||
|       this->getSize()); // worldSize | ||||
|  | ||||
|   // avoidRecordStreams_ note: collective() will stash tensors. | ||||
| @ -2800,6 +2838,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_coalesced( | ||||
|       // I'm not sure what in,outSplitSizes mean here. | ||||
|       std::vector<int64_t>(), // inSplitSizes | ||||
|       std::vector<int64_t>(), // outSplitSizes | ||||
|       globalRankStart, // globalRankStart | ||||
|       globalRankStride, // globalRankStride | ||||
|       this->getSize()); // worldSize | ||||
|  | ||||
|   // avoidRecordStreams_ note: collective() will stash tensors. | ||||
| @ -2826,6 +2866,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::broadcast( | ||||
|       tensor.scalar_type(), // dType | ||||
|       std::vector<int64_t>(), // inSplitSizes | ||||
|       std::vector<int64_t>(), // outSplitSizes | ||||
|       globalRankStart, // globalRankStart | ||||
|       globalRankStride, // globalRankStride | ||||
|       this->getSize()); // worldSize | ||||
|  | ||||
|   // avoidRecordStreams_ note: collective() will stash tensors. | ||||
| @ -2889,6 +2931,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_broadcast_oop( | ||||
|       tensor.scalar_type(), // dType | ||||
|       std::vector<int64_t>(), // inSplitSizes | ||||
|       std::vector<int64_t>(), // outSplitSizes | ||||
|       globalRankStart, // globalRankStart | ||||
|       globalRankStride, // globalRankStride | ||||
|       this->getSize()); // worldSize | ||||
|  | ||||
|   return collective( | ||||
| @ -2931,6 +2975,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce( | ||||
|       tensor.scalar_type(), // dType | ||||
|       std::vector<int64_t>(), // inSplitSizes | ||||
|       std::vector<int64_t>(), // outSplitSizes | ||||
|       globalRankStart, // globalRankStart | ||||
|       globalRankStride, // globalRankStride | ||||
|       this->getSize()); // worldSize | ||||
|  | ||||
|   int dev_in_group = 0; | ||||
| @ -2995,6 +3041,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_oop( | ||||
|       tensor.scalar_type(), // dType | ||||
|       std::vector<int64_t>(), // inSplitSizes | ||||
|       std::vector<int64_t>(), // outSplitSizes | ||||
|       globalRankStart, // globalRankStart | ||||
|       globalRankStride, // globalRankStride | ||||
|       this->getSize()); // worldSize | ||||
|  | ||||
|   int dev_in_group{0}; | ||||
| @ -3053,6 +3101,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allgather( | ||||
|         tensor.scalar_type(), // dType | ||||
|         std::vector<int64_t>(), // inSplitSizes | ||||
|         std::vector<int64_t>(), // outSplitSize | ||||
|         globalRankStart, // globalRankStart | ||||
|         globalRankStride, // globalRankStride | ||||
|         this->getSize()); // worldSize | ||||
|  | ||||
|     return collective( | ||||
| @ -3196,6 +3246,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter( | ||||
|         tensor.scalar_type(), // dType | ||||
|         std::vector<int64_t>(), // inSplitSizes | ||||
|         std::vector<int64_t>(), // outSplitSizes | ||||
|         globalRankStart, // globalRankStart | ||||
|         globalRankStride, // globalRankStride | ||||
|         this->getSize()); // worldSize | ||||
|  | ||||
|     return collective( | ||||
| @ -3315,6 +3367,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base( | ||||
|       tensor.scalar_type(), // dtype | ||||
|       std::vector<int64_t>(), // inSplitSizes | ||||
|       std::vector<int64_t>(), // outSplitSizes | ||||
|       globalRankStart, // globalRankStart | ||||
|       globalRankStride, // globalRankStride | ||||
|       this->getSize()); // worldSize | ||||
|  | ||||
|   auto inputs = std::vector<at::Tensor>{inputTensor}; | ||||
| @ -3403,6 +3457,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::barrier(const BarrierOptions& opts) { | ||||
|       at::kByte, // dType | ||||
|       std::vector<int64_t>(), // inSplitSizes | ||||
|       std::vector<int64_t>(), // outSplitSizes | ||||
|       globalRankStart, // globalRankStart | ||||
|       globalRankStride, // globalRankStride | ||||
|       this->getSize()); // worldSize | ||||
|  | ||||
|   std::vector<at::Device> devices; | ||||
| @ -3484,6 +3540,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base( | ||||
|         inputTensor.scalar_type(), // dType | ||||
|         std::vector<int64_t>(), // inSplitSizes | ||||
|         std::vector<int64_t>(), // outSplitSizes | ||||
|         globalRankStart, // globalRankStart | ||||
|         globalRankStride, // globalRankStride | ||||
|         this->getSize()); // worldSize | ||||
|  | ||||
|     // avoidRecordStreams_ note: collective() will stash inputTensors and | ||||
| @ -3526,6 +3584,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall_base( | ||||
|         inputTensor.scalar_type(), // dType | ||||
|         inputSplitSizes, // inSplitSizes | ||||
|         outputSplitSizes, // outSplitSizes | ||||
|         globalRankStart, // globalRankStart | ||||
|         globalRankStride, // globalRankStride | ||||
|         this->getSize()); // worldSize | ||||
|  | ||||
|     // avoidRecordStreams_ note: collective() will stash inputTensors and | ||||
| @ -3602,6 +3662,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::alltoall( | ||||
|       inputTensors.front().scalar_type(), // dType | ||||
|       inSplitSizes, // inSplitSizes | ||||
|       outSplitSizes, // outSplitSizes | ||||
|       globalRankStart, // globalRankStart | ||||
|       globalRankStride, // globalRankStride | ||||
|       this->getSize()); // worldSize | ||||
|  | ||||
|   std::vector<at::Tensor> inputTensor0 = {inputTensors[0]}; | ||||
| @ -3653,6 +3715,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::send( | ||||
|       tensor.scalar_type(), // dType | ||||
|       std::vector<int64_t>(), // inSplitSizes | ||||
|       std::vector<int64_t>(), // outSplitSizes | ||||
|       globalRankStart, // globalRankStart | ||||
|       globalRankStride, // globalRankStride | ||||
|       this->getSize()); // worldSize | ||||
|  | ||||
|   auto ret = pointToPoint( | ||||
| @ -3691,6 +3755,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::recv( | ||||
|       tensor.scalar_type(), // dType | ||||
|       std::vector<int64_t>(), // inSplitSizes | ||||
|       std::vector<int64_t>(), // outSplitSizes | ||||
|       globalRankStart, // globalRankStart | ||||
|       globalRankStride, // globalRankStride | ||||
|       this->getSize()); // worldSize | ||||
|  | ||||
|   auto ret = pointToPoint( | ||||
| @ -3851,6 +3917,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::gather( | ||||
|       tensor.scalar_type(), // dType | ||||
|       std::vector<int64_t>(), // inSplitSizes | ||||
|       std::vector<int64_t>(), // outSplitSize | ||||
|       globalRankStart, // globalRankStart | ||||
|       globalRankStride, // globalRankStride | ||||
|       this->getSize()); // worldSize | ||||
|  | ||||
|   // avoidRecordStreams_ note: collective() will stash inputTensors and | ||||
| @ -3938,6 +4006,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::scatter( | ||||
|       tensor.scalar_type(), // dType | ||||
|       std::vector<int64_t>(), // inSplitSizes | ||||
|       std::vector<int64_t>(), // outSplitSize | ||||
|       globalRankStart, // globalRankStart | ||||
|       globalRankStride, // globalRankStride | ||||
|       this->getSize()); // worldSize | ||||
|  | ||||
|   // avoidRecordStreams_ note: collective() will stash outputTensors and | ||||
| @ -4009,6 +4079,8 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_allgather_base( | ||||
|       tensor.scalar_type(), // dType | ||||
|       std::vector<int64_t>(), // inSplitSizes | ||||
|       std::vector<int64_t>(), // outSplitSize | ||||
|       globalRankStart, // globalRankStart | ||||
|       globalRankStride, // globalRankStride | ||||
|       this->getSize()); // worldSize | ||||
|  | ||||
|   // just a wrapper to fit the collective interface | ||||
|  | ||||
| @ -590,6 +590,9 @@ class TORCH_API ProcessGroupNCCL : public Backend { | ||||
|       OpType opType); | ||||
|  | ||||
|  private: | ||||
|   int globalRankStart; | ||||
|   int globalRankStride; | ||||
|  | ||||
|   // Helper that encapsulates work shared across all collective communication | ||||
|   // primitives.  The callbacks have the following signatures: | ||||
|   // | ||||
|  | ||||
| @ -158,6 +158,8 @@ void CommTraceLogger::recordComms( | ||||
|       dtype, | ||||
|       curInSplitSizes_, | ||||
|       curOutSplitSizes_, | ||||
|       -1, | ||||
|       -1, | ||||
|       world_size); | ||||
|  | ||||
|   ++seqnum; | ||||
|  | ||||
| @ -344,6 +344,8 @@ static constexpr auto kInMsgNelems = "In msg nelems"; | ||||
| static constexpr auto kOutMsgNelems = "Out msg nelems"; | ||||
| static constexpr auto kInSplit = "In split size"; | ||||
| static constexpr auto kOutSplit = "Out split size"; | ||||
| static constexpr auto kGlobalRankStart = "Global rank start"; | ||||
| static constexpr auto kGlobalRankStride = "Global rank stride"; | ||||
| static constexpr auto kGroupSize = "Group size"; | ||||
| static constexpr int32_t kTruncatLength = 30; | ||||
| #endif // USE_C10D | ||||
| @ -395,6 +397,10 @@ std::unordered_map<std::string, std::string> saveNcclMeta( | ||||
|                 outSplitSizes.begin() + kTruncatLength, | ||||
|                 ", "))); | ||||
|   } | ||||
|   map.emplace( | ||||
|       kGlobalRankStart, std::to_string(debugInfo->getGlobalRankStart())); | ||||
|   map.emplace( | ||||
|       kGlobalRankStride, std::to_string(debugInfo->getGlobalRankStride())); | ||||
|   map.emplace(kGroupSize, std::to_string(debugInfo->getWorldSize())); | ||||
| #endif // USE_C10D | ||||
| #endif // USE_DISTRIBUTED | ||||
|  | ||||
		Reference in New Issue
	
	Block a user