#pragma once

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
#include <torch/csrc/distributed/c10d/debug.h>

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/macros/Macros.h>

#include <chrono>
#include <map>
#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// *************************************************************************
// PROCESS GROUP collective communication API IS BEING CHANGED BETWEEN
// versions 1.7 and 1.8.
// PLEASE DO NOT ADD ANY DEPENDENCIES.
// SEE RFC: https://github.com/pytorch/pytorch/issues/39662
// *************************************************************************

constexpr auto kProcessGroupDefaultTimeout =
    std::chrono::milliseconds(30 * 60 * 1000);

namespace c10d {

// We only call `register_work()` in two cases:
// 1. If the work object is created from a functional collective call.
// 2. If the work object is created from a non-functional collective call
//    within the `with allow_inflight_collective_as_graph_input_ctx()`
//    context manager.
C10_EXPORT void register_work(
    const at::Tensor& tensor,
    const c10::intrusive_ptr<c10d::Work>& work);

C10_EXPORT at::Tensor wait_tensor(const at::Tensor& tensor);

// We only call `unregister_work()` in one case:
// 1. If the work object is created from a non-functional collective call
//    within the `with allow_inflight_collective_as_graph_input_ctx()`
//    context manager.
//
// Q: What about the functional collective case?
// A: The unregistration of the work object for a functional collective is
//    done in the required user-side explicit call to `wait_tensor()`.
C10_EXPORT void unregister_work(const c10::intrusive_ptr<c10d::Work>& work);

C10_EXPORT size_t get_work_registry_size();

C10_EXPORT void set_allow_inflight_collective_as_graph_input(bool value);

C10_EXPORT bool allow_inflight_collective_as_graph_input();
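// A minimal usage sketch of the in-flight collective registry above
// (illustration only, not part of the API; `pg` and `tensors` are assumed to
// exist). Collectives issued while the flag is set register their work object
// for each output tensor, and `wait_tensor()` later blocks on the pending
// work for a given tensor:
//
//   c10d::set_allow_inflight_collective_as_graph_input(true);
//   auto work = pg->allreduce(tensors);   // registers `work` per tensor
//   // ... tensors may now be captured as graph inputs ...
//   for (auto& t : tensors) {
//     c10d::wait_tensor(t);               // waits on the registered work
//   }
//   c10d::set_allow_inflight_collective_as_graph_input(false);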
// ProcessGroup is a base class that captures collective and point-to-point
// communication in a fixed set of processes.
//
// The functions specified in the class below describe the API alone;
// implementations are provided in subclasses.
//
// Every function that performs I/O is executed asynchronously by a
// thread pool owned by the ProcessGroup (by default). They return an
// object that can be used to wait for completion or error.
//
// The ProcessGroup can instantiate subgroups with fewer or an equal
// number of members. Implementations must take care that multiple
// process groups can be used in parallel and synchronize accordingly.
//
// The ProcessGroup assumes a fixed set of processes. If the set
// changes, existing instances must be destructed and instantiation
// and initialization must start from scratch. For members of the
// process group to find each other (referred to as rendezvous from
// here on), a shared Store is used (see the constructor below).
//
class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 public:
  struct TORCH_API MergeOptions : torch::CustomClassHolder {
    explicit MergeOptions(
        const std::chrono::milliseconds timeout = kProcessGroupDefaultTimeout,
        const std::optional<std::string> group_name = std::nullopt,
        const std::optional<std::string> group_desc = std::nullopt)
        : timeout(timeout), group_name(group_name), group_desc(group_desc) {}
    ~MergeOptions() override = default;
    MergeOptions(const MergeOptions&) = delete;
    MergeOptions& operator=(const MergeOptions&) = delete;

    std::chrono::milliseconds timeout;
    std::optional<std::string> group_name;
    std::optional<std::string> group_desc;
  };

  enum BackendType : uint8_t {
    UNDEFINED = 0,
    GLOO = 1,
    NCCL = 2,
    UCC = 3,
    MPI = 4,
    XCCL = 5,
    CUSTOM = 6,
  };

  static std::string backendTypeToString(const BackendType& type) {
    switch (type) {
      case BackendType::GLOO:
        return "gloo";
      case BackendType::NCCL:
        return "nccl";
      case BackendType::XCCL:
        return "xccl";
      case BackendType::UCC:
        return "ucc";
      case BackendType::MPI:
        return "mpi";
      case BackendType::UNDEFINED:
        return "undefined";
      case BackendType::CUSTOM:
        return "custom";
      default:
        TORCH_CHECK(false, "This should never happen!");
    }
  }

  static BackendType strToBackendType(const std::string& backend) {
    if (backend == "undefined") {
      return BackendType::UNDEFINED;
    } else if (backend == "gloo") {
      return BackendType::GLOO;
    } else if (backend == "nccl") {
      return BackendType::NCCL;
    } else if (backend == "xccl") {
      return BackendType::XCCL;
    } else if (backend == "ucc") {
      return BackendType::UCC;
    } else if (backend == "mpi") {
      return BackendType::MPI;
    } else {
      return BackendType::CUSTOM;
    }
  }

  // Not used; kept for backwards compatibility and only used for TypeDef in
  // Ops.cpp.
  explicit ProcessGroup(int rank, int size);

  explicit ProcessGroup(
      c10::intrusive_ptr<::c10d::Store> store,
      int rank,
      int size);
  ~ProcessGroup() override;

  virtual int getRank() const {
    return rank_;
  }

  virtual int getSize() const {
    return size_;
  }

  // Returns a unique opaque ID of this process group object.
  int64_t getID() const {
    return reinterpret_cast<std::intptr_t>(this);
  }

  // Returns a unique opaque ID of a backend for the specific backend type
  // that can correlate with this process group's collectives.
  int64_t getBackendID(BackendType backend_type) const {
    return reinterpret_cast<std::intptr_t>(getBackend(backend_type).get());
  }

  virtual const std::string getBackendName() const {
    return backendTypeToString(backendType_);
  }

  BackendType getBackendType() const {
    return backendType_;
  }

  inline bool backendSupportsSequenceNumbers(BackendType backendType) {
    if (backendType == BackendType::GLOO || backendType == BackendType::NCCL ||
        backendType == BackendType::XCCL || backendType == BackendType::UCC)
      return true;
    return false;
  }

  virtual void setTimeout(std::chrono::milliseconds timeout) {
    for (auto& backend : backendTypeToBackend_) {
      backend.second->setTimeout(timeout);
    }
  }

  int64_t incrementSplitCount() {
    return splitCounter_++;
  }

  virtual void startCoalescing(c10::DeviceType deviceType) {
    // only nccl has implemented startCoalescing so only execute for nccl
    // backends
    auto backend = getBackend(deviceType);
    backend->startCoalescing();
  }

  virtual c10::intrusive_ptr<Work> endCoalescing(c10::DeviceType deviceType) {
    // only nccl has implemented endCoalescing so only execute for nccl
    // backends
    auto backend = getBackend(deviceType);
    auto work = backend->endCoalescing();
    return work;
  }
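  // A minimal coalescing sketch (illustration only; `pg` is assumed, and `t0`
  // and `t1` are assumed to be std::vector<at::Tensor> holding CUDA tensors).
  // Collectives issued between startCoalescing() and endCoalescing() may be
  // batched by backends that support it (currently NCCL, per the comments
  // above); the returned Work covers the whole batch:
  //
  //   pg->startCoalescing(c10::DeviceType::CUDA);
  //   pg->allreduce(t0);
  //   pg->allreduce(t1);
  //   auto work = pg->endCoalescing(c10::DeviceType::CUDA);
  //   work->wait();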
  virtual c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& tensors,
      const BroadcastOptions& opts = BroadcastOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::broadcast_", "")
            .typed<std::tuple<
                std::vector<at::Tensor>,
                c10::intrusive_ptr<Work>>(
                at::TensorList,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                int64_t,
                int64_t,
                bool,
                int64_t)>();
    // It's awkward to unbox the opts here and box them again in the custom
    // C++ op. But it's also complicated to make opts a CustomClassHolder.
    // Leave it as it is for now.
    auto work = std::get<1>(op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.rootRank,
        opts.rootTensor,
        opts.asyncOp,
        opts.timeout.count()));
    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : tensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::allreduce_", "")
            .typed<std::tuple<
                std::vector<at::Tensor>,
                c10::intrusive_ptr<Work>>(
                at::TensorList,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                const c10::intrusive_ptr<::c10d::ReduceOp>&,
                const std::optional<at::Tensor>& sparse_indices,
                bool,
                int64_t)>();
    auto work = std::get<1>(op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<ReduceOp>(opts.reduceOp),
        opts.sparseIndices,
        opts.asyncOp,
        opts.timeout.count()));
    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : tensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::allreduce_coalesced_", "")
            .typed<c10::intrusive_ptr<Work>(
                at::TensorList,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                const c10::intrusive_ptr<::c10d::ReduceOp>&,
                bool,
                int64_t)>();
    auto work = op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<ReduceOp>(opts.reduceOp),
        opts.asyncOp,
        opts.timeout.count());
    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : tensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::reduce_", "")
            .typed<c10::intrusive_ptr<Work>(
                at::TensorList,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                const c10::intrusive_ptr<::c10d::ReduceOp>&,
                int64_t,
                int64_t,
                bool,
                int64_t)>();
    auto work = op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<ReduceOp>(opts.reduceOp),
        opts.rootRank,
        opts.rootTensor,
        opts.asyncOp,
        opts.timeout.count());
    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : tensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::allgather_", "")
            .typed<std::tuple<
                std::vector<std::vector<at::Tensor>>,
                c10::intrusive_ptr<Work>>(
                const std::vector<std::vector<at::Tensor>>&,
                at::TensorList,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                bool,
                int64_t)>();
    auto work = std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.asyncOp,
        opts.timeout.count()));
    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor_list : outputTensors) {
        for (const auto& tensor : tensor_list) {
          c10d::register_work(tensor, work);
        }
      }
    }
    return work;
  }

  // Gathers a single tensor inputBuffer into a single buffer outputBuffer that
  // is interpreted as a contiguous collection of size inputBuffer * WORLD_SIZE.
  // For implementers of the ProcessGroup API and advanced users only.
  // Note: this function will be deprecated in the near future.
  virtual c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const AllgatherOptions& opts = AllgatherOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::_allgather_base_", "")
            .typed<std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(
                at::Tensor&,
                at::Tensor&,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                bool,
                int64_t)>();
    auto work = std::get<1>(op.call(
        outputBuffer,
        inputBuffer,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.asyncOp,
        opts.timeout.count()));
    if (c10d::allow_inflight_collective_as_graph_input()) {
      c10d::register_work(outputBuffer, work);
    }
    return work;
  }

  // This function is deprecated and will be moved out of ProcessGroup to comms:
  // * do not add dependencies on this function,
  // * do not implement it in your ProcessGroup, implement _allgather_base
  //   instead.
  virtual c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& outputTensorLists,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::allgather_coalesced_", "")
            .typed<c10::intrusive_ptr<Work>(
                const std::vector<std::vector<at::Tensor>>&,
                const at::TensorList&,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                bool)>();
    auto work = op.call(
        outputTensorLists,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.asyncOp);
    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor_list : outputTensorLists) {
        for (const auto& tensor : tensor_list) {
          c10d::register_work(tensor, work);
        }
      }
    }
    return work;
  }
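  // A minimal caller-side sketch (illustration only; `pg` and `tensors` are
  // assumed to exist). Each wrapper above routes through the dispatcher op of
  // the same name and returns a Work handle the caller waits on:
  //
  //   BroadcastOptions opts;
  //   opts.rootRank = 0;                         // rank holding the source data
  //   auto work = pg->broadcast(tensors, opts);
  //   work->wait();                              // block until completion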
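  // A minimal shape sketch for _allgather_base (illustration only; `pg` and
  // the tensors are assumed). The output buffer is one flat tensor holding
  // WORLD_SIZE copies of the input laid out contiguously, as described in the
  // comment above:
  //
  //   at::Tensor input = at::ones({1024});
  //   at::Tensor output = at::empty({pg->getSize() * 1024});
  //   auto work = pg->_allgather_base(output, input);
  //   work->wait();  // output.narrow(0, r * 1024, 1024) now holds rank r's input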
  // This function is a coalesced version of `allgather_into_tensor` (currently
  // still named as `_allgather_base`). Each tensor in the vector corresponds
  // to an input/output of one `allgather_into_tensor` operation.
  virtual c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::allgather_into_tensor_coalesced_", "")
            .typed<c10::intrusive_ptr<Work>(
                const at::TensorList,
                const at::TensorList,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                bool)>();
    auto work = op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.asyncOp);
    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : outputTensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const GatherOptions& opts = GatherOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::gather_", "")
            .typed<c10::intrusive_ptr<Work>(
                const std::vector<std::vector<at::Tensor>>&,
                const at::TensorList&,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                int64_t,
                bool,
                int64_t)>();
    auto work = op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.rootRank,
        opts.asyncOp,
        opts.timeout.count());
    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor_list : outputTensors) {
        for (const auto& tensor : tensor_list) {
          c10d::register_work(tensor, work);
        }
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ScatterOptions& opts = ScatterOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::scatter_", "")
            .typed<std::tuple<
                std::vector<at::Tensor>,
                c10::intrusive_ptr<Work>>(
                const at::TensorList&,
                const std::vector<std::vector<at::Tensor>>&,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                int64_t,
                bool,
                int64_t)>();
    auto work = std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.rootRank,
        opts.asyncOp,
        opts.timeout.count()));
    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : outputTensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::reduce_scatter_", "")
            .typed<std::tuple<
                std::vector<at::Tensor>,
                c10::intrusive_ptr<Work>>(
                const at::TensorList&,
                const std::vector<std::vector<at::Tensor>>&,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                const c10::intrusive_ptr<::c10d::ReduceOp>&,
                bool,
                int64_t)>();
    auto work = std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
        opts.asyncOp,
        opts.timeout.count()));
    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : outputTensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> _reduce_scatter_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::_reduce_scatter_base_", "")
            .typed<std::tuple<at::Tensor, c10::intrusive_ptr<Work>>(
                at::Tensor&,
                at::Tensor&,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                const c10::intrusive_ptr<::c10d::ReduceOp>&,
                bool,
                int64_t)>();
    auto work = std::get<1>(op.call(
        outputBuffer,
        inputBuffer,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
        opts.asyncOp,
        opts.timeout.count()));
    if (c10d::allow_inflight_collective_as_graph_input()) {
      c10d::register_work(outputBuffer, work);
    }
    return work;
  }
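  // A minimal shape sketch for _reduce_scatter_base (illustration only; `pg`
  // and the tensors are assumed; the default ReduceOp is SUM). The input is
  // the concatenation of WORLD_SIZE shards; each rank receives the reduction
  // of its own shard:
  //
  //   at::Tensor input = at::ones({pg->getSize() * 1024});
  //   at::Tensor output = at::empty({1024});
  //   auto work = pg->_reduce_scatter_base(output, input);
  //   work->wait();  // output holds the reduced shard owned by this rank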
  // This function is a coalesced version of `reduce_scatter_tensor` (currently
  // still named as `_reduce_scatter_base`). Each tensor in the vector
  // corresponds to an input/output of one `reduce_scatter_tensor` operation.
  virtual c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::reduce_scatter_tensor_coalesced_", "")
            .typed<c10::intrusive_ptr<Work>(
                const at::TensorList,
                const at::TensorList,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                const c10::intrusive_ptr<::c10d::ReduceOp>&,
                bool,
                int64_t)>();
    auto work = op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp),
        opts.asyncOp,
        opts.timeout.count());
    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : outputTensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      std::vector<int64_t>& outputSplitSizes,
      std::vector<int64_t>& inputSplitSizes,
      const AllToAllOptions& opts = AllToAllOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::alltoall_base_", "")
            .typed<c10::intrusive_ptr<Work>(
                at::Tensor&,
                at::Tensor&,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                std::vector<int64_t>,
                std::vector<int64_t>,
                bool,
                int64_t)>();
    auto work = op.call(
        outputBuffer,
        inputBuffer,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        outputSplitSizes,
        inputSplitSizes,
        opts.asyncOp,
        opts.timeout.count());
    if (c10d::allow_inflight_collective_as_graph_input()) {
      c10d::register_work(outputBuffer, work);
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> alltoall(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllToAllOptions& opts = AllToAllOptions()) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::alltoall_", "")
            .typed<std::tuple<
                std::vector<at::Tensor>,
                c10::intrusive_ptr<Work>>(
                const at::TensorList&,
                const at::TensorList&,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                bool,
                int64_t)>();
    auto work = std::get<1>(op.call(
        outputTensors,
        inputTensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.asyncOp,
        opts.timeout.count()));
    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : outputTensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual void monitoredBarrier(
      const BarrierOptions& opts,
      bool wait_all_ranks = false) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::monitored_barrier_", "")
            .typed<void(
                at::Tensor,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                const std::vector<int64_t>&,
                int64_t,
                bool)>();
    // Default to using cpu implementation, monitored barrier is only for GLOO
    at::Tensor tensor = at::empty({0}, at::TensorOptions().device(at::kCPU));
    op.call(
        tensor,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.device_ids,
        opts.timeout.count(),
        wait_all_ranks);
  }
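  // A minimal alltoall_base sketch (illustration only; `pg` and the buffers
  // are assumed). Empty split-size vectors request an even split of the first
  // dimension across ranks, mirroring the Python all_to_all_single API:
  //
  //   at::Tensor input = at::arange(pg->getSize());  // one element per peer
  //   at::Tensor output = at::empty_like(input);
  //   std::vector<int64_t> inSplit, outSplit;         // empty => even split
  //   auto work = pg->alltoall_base(output, input, outSplit, inSplit);
  //   work->wait();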
  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to other ranks using the store. Only
  // implemented for GLOO and NCCL backends currently.
  virtual void setSequenceNumberForGroup() {
    auto backendType = getBackendType();
    // TODO: HACK for backend name to get sequence number for that backend.
    if (backendSupportsSequenceNumbers(backendType)) {
      getDefaultBackend()->setSequenceNumberForGroup();
    } else {
      TORCH_CHECK(
          false,
          c10::str(
              "ProcessGroup ",
              getBackendName(),
              " does not yet support sequence numbers."));
    }
  }

  // Retrieves the current sequence number for the whole group, which should be
  // in sync. If the returned number is not consistent across the group, it
  // may indicate that there is some sort of collective desynchronization.
  virtual uint64_t getSequenceNumberForGroup() {
    auto backendType = getBackendType();
    // TODO: HACK for backend name to get sequence number for that backend.
    if (backendSupportsSequenceNumbers(backendType)) {
      return getDefaultBackend()->getSequenceNumberForGroup();
    } else {
      TORCH_CHECK(
          false,
          c10::str(
              "ProcessGroup ",
              getBackendName(),
              " does not yet support sequence numbers."));
    }
  }

  virtual c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::send", "")
            .typed<c10::intrusive_ptr<Work>(
                at::TensorList,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                int64_t,
                int64_t)>();
    auto work = op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        dstRank,
        tag);
    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : tensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::recv_", "")
            .typed<c10::intrusive_ptr<Work>(
                at::TensorList,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                int64_t,
                int64_t)>();
    auto work = op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        srcRank,
        tag);
    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : tensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& tensors,
      int tag) {
    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::recv_any_source_", "")
            .typed<c10::intrusive_ptr<Work>(
                at::TensorList,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                int64_t)>();
    auto work = op.call(
        tensors,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        tag);
    if (c10d::allow_inflight_collective_as_graph_input()) {
      for (const auto& tensor : tensors) {
        c10d::register_work(tensor, work);
      }
    }
    return work;
  }

  virtual c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) {
    static at::Tensor tensor;
    // TODO: if nccl was specified then use it
    auto device = opts.device;
    if (device.has_value()) {
      // set device tensor from argument
      tensor = at::empty(
          {1}, at::TensorOptions().device(device.value()).dtype(at::kByte));
    } else if (backendType_ == c10d::ProcessGroup::BackendType::NCCL) {
      // set cuda tensor
      tensor = at::empty(
          {1},
          at::TensorOptions().device(at::DeviceType::CUDA).dtype(at::kByte));
    } else if (backendType_ == c10d::ProcessGroup::BackendType::XCCL) {
      // set xpu tensor to override the cpu dispatch
      tensor = at::empty(
          {1},
          at::TensorOptions().device(at::DeviceType::XPU).dtype(at::kByte));
    } else {
      // Default to using cpu implementation
      tensor = at::empty(
          {1},
          at::TensorOptions().device(at::DeviceType::CPU).dtype(at::kByte));
    }

    static auto op =
        c10::Dispatcher::singleton()
            .findSchemaOrThrow("c10d::barrier", "")
            .typed<c10::intrusive_ptr<Work>(
                at::Tensor,
                const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                const std::vector<int64_t>&,
                bool,
                int64_t)>();
    auto work = op.call(
        tensor,
        c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning(this),
        opts.device_ids,
        opts.asyncOp,
        opts.timeout.count());
    if (c10d::allow_inflight_collective_as_graph_input()) {
      c10d::register_work(tensor, work);
    }
    return work;
  }

  bool hasBackends() {
    return !deviceTypeToBackendType_.empty();
  }

  void setBackend(
      c10::DeviceType deviceType,
      BackendType backendType,
      const std::optional<c10::intrusive_ptr<Backend>>& backend) {
    // TODO: should we add these entries after the backend setting succeeds?
    deviceTypeToBackendType_[deviceType] = backendType;
    deviceTypes_.insert(deviceType);
    // if the backendType is already set then reuse it for this device
    if (backendTypeToBackend_.find(backendType) !=
        backendTypeToBackend_.end()) {
      auto existingBackend = backendTypeToBackend_.at(backendType);
      deviceTypeToBackend_[deviceType] = existingBackend;
      TORCH_CHECK(
          existingBackend->getBoundDeviceId() ==
          (*backend)->getBoundDeviceId());
    } else {
      // check if backend has value
      if (backend.has_value()) {
        deviceTypeToBackend_[deviceType] = backend.value();
        backendTypeToBackend_[backendType] = backend.value();
        (*backend)->setBoundDeviceId(bound_device_id_);
      }
    }
  }

  c10::intrusive_ptr<Backend> getDefaultBackend() const {
    auto backend_iter = backendTypeToBackend_.find(backendType_);
    TORCH_CHECK(
        backend_iter != backendTypeToBackend_.end(),
        "Could not find the default backend type ",
        uint16_t(backendType_),
        " for Process Group with name ",
        getBackendName(),
        ".");
    return backend_iter->second;
  }

  void setDefaultBackend(const BackendType& backendType) {
    backendType_ = backendType;
  }

  void setDefaultBackend(const std::string& backend) {
    backendType_ = strToBackendType(backend);
  }

  c10::intrusive_ptr<Backend> getBackend(c10::DeviceType deviceType);

  c10::intrusive_ptr<Backend> getBackend(BackendType backendType) const {
    TORCH_CHECK(
        backendTypeToBackend_.find(backendType) != backendTypeToBackend_.end(),
        "Could not find backend type ",
        uint16_t(backendType),
        " for Process Group with name ",
        backendTypeToString(backendType),
        ".");
    return backendTypeToBackend_.at(backendType);
  }

  // Return device types supported by this ProcessGroup.
  // Note: the return type is `Device` rather than `DeviceType` for the purpose
  // of easy comparison at the Python level. The `Device` will have a default
  // index (-1).
  std::vector<c10::Device> getDeviceTypes() const {
    std::vector<c10::Device> devices;
    devices.reserve(deviceTypes_.size());
    for (auto& dt : deviceTypes_) {
      devices.emplace_back(dt);
    }
    return devices;
  }

  void registerOnCompletionHook(
      std::function<void(std::shared_ptr<WorkInfo>)>&& hook) {
    getDefaultBackend()->registerOnCompletionHook(std::move(hook));
  }

  void waitForPendingWorks() {
    getDefaultBackend()->waitForPendingWorks();
  }

  virtual void shutdown() {
    for (auto& backend : backendTypeToBackend_) {
      backend.second->shutdown();
    }
  }

  virtual void abort() {
    for (auto& backend : backendTypeToBackend_) {
      backend.second->abort();
    }
  }

  bool hasHooks() const {
    auto backend_iter = backendTypeToBackend_.find(backendType_);
    if (backend_iter == backendTypeToBackend_.end()) {
      TORCH_WARN(
          "No backend of type ",
          uint16_t(backendType_),
          " found for Process Group with name ",
          getBackendName(),
          ". Assuming no hooks are registered.");
      return false;
    }
    return backend_iter->second->hasHooks();
  }

  virtual const std::string& getGroupName() const;
  virtual void setGroupName(const std::string& name);
  virtual const std::string& getGroupDesc() const;
  virtual void setGroupDesc(const std::string& name);
  void enableCollectivesTiming();

  void release_resources() override;
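  // A minimal point-to-point sketch (illustration only; `pg` and `tensors`
  // are assumed to exist, and the two ranks must use matching tags):
  //
  //   if (pg->getRank() == 0) {
  //     pg->send(tensors, /*dstRank=*/1, /*tag=*/0)->wait();
  //   } else if (pg->getRank() == 1) {
  //     pg->recv(tensors, /*srcRank=*/0, /*tag=*/0)->wait();
  //   }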
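  // A minimal backend-wiring sketch for implementers (illustration only; `pg`
  // and an already constructed `glooBackend` intrusive_ptr are assumed). A
  // device type maps to a backend type, which maps to the Backend instance
  // that the collectives above dispatch to:
  //
  //   pg->setBackend(
  //       c10::DeviceType::CPU, c10d::ProcessGroup::BackendType::GLOO, glooBackend);
  //   pg->setDefaultBackend(c10d::ProcessGroup::BackendType::GLOO);
  //   auto backend = pg->getBackend(c10::DeviceType::CPU);  // == glooBackend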
  // ProcessGroups optionally can be "bound" to a specific device.
  // Currently this is only for nccl and allows for some opt-in
  // optimizations such as automatic use of ncclCommSplit. The device
  // is specified in `init_process_group` and eventually makes it
  // here and then down into the actual backend instances.
  std::optional<at::Device> getBoundDeviceId() const {
    return bound_device_id_;
  }

  c10::intrusive_ptr<c10d::Store> getStore() const {
    return store_;
  }

  void setBoundDeviceId(std::optional<at::Device> device) {
    if (device) {
      TORCH_CHECK(device->has_index(), "setBoundDeviceId must have an index");
    }
    bound_device_id_ = device;
  }

  // This creates a new subgroup using the specified ranks.
  // The current rank must be included in the list of ranks.
  virtual c10::intrusive_ptr<ProcessGroup> splitGroup(
      const std::vector<int>& ranks,
      const std::optional<std::chrono::milliseconds>& timeout,
      const std::optional<c10::intrusive_ptr<Backend::Options>>& opts,
      const std::optional<std::string>& name,
      const std::optional<std::string>& groupDesc);

  // This merges the current group with a remote group reachable through the
  // given store, creating a new group of the given size.
  virtual c10::intrusive_ptr<ProcessGroup> mergeRemoteGroup(
      const c10::intrusive_ptr<Store>& store,
      const MergeOptions& opts,
      const int& size);

 protected:
  // Implementations of this interface need to call this to setup
  // appropriate logging etc.
  void init();

  c10::intrusive_ptr<c10d::Store> store_;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const int rank_;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const int size_;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  BackendType backendType_;
  std::string pg_desc_;
  int64_t splitCounter_;

  // Debug level setting. It is parsed once when ProcessGroup is constructed
  // and remains the same across the use of this process group.
  DebugLevel dist_debug_level_{DebugLevel::Off};

  // Backend classes for this ProcessGroup
  std::unordered_set<c10::DeviceType> deviceTypes_;
  // This mapping is ordered, as splitGroup must call split on the underlying
  // backends in a consistent order.
  std::map<c10::DeviceType, BackendType> deviceTypeToBackendType_;

  std::unordered_map<c10::DeviceType, c10::intrusive_ptr<Backend>>
      deviceTypeToBackend_;
  std::unordered_map<BackendType, c10::intrusive_ptr<Backend>>
      backendTypeToBackend_;

  std::optional<at::Device> bound_device_id_;
};

// Thread local functions for managing the currently active process group.
TORCH_API c10::intrusive_ptr<ProcessGroup>& currentProcessGroup();
TORCH_API void setProcessGroup(c10::intrusive_ptr<ProcessGroup> processGroup);

} // namespace c10d