#pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace c10d { // // ProcessGroupTest implements dummy bindings for c10d. // class ProcessGroupTest : public ProcessGroup { public: class WorkTest : public Work { public: WorkTest() {} virtual ~WorkTest(); bool isCompleted() override; bool isSuccess() const override; bool wait(std::chrono::milliseconds timeout) override; protected: friend class ProcessGroupTest; }; explicit ProcessGroupTest(int rank = -1, int size = -1); virtual ~ProcessGroupTest(); c10::intrusive_ptr broadcast( std::vector& data, const BroadcastOptions& opts = BroadcastOptions()) override; c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override; c10::intrusive_ptr reduce( std::vector& tensors, const ReduceOptions& opts = ReduceOptions()) override; c10::intrusive_ptr allgather( std::vector>& outputTensors, std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; c10::intrusive_ptr _allgather_base( at::Tensor& outputBuffer, at::Tensor& inputBuffer, const AllgatherOptions& opts = AllgatherOptions()) override; c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; c10::intrusive_ptr gather( std::vector>& outputTensors, std::vector& inputTensors, const GatherOptions& opts = GatherOptions()) override; c10::intrusive_ptr scatter( std::vector& outputTensors, std::vector>& inputTensors, const ScatterOptions& opts = ScatterOptions()) override; c10::intrusive_ptr reduce_scatter( std::vector& outputTensors, std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; c10::intrusive_ptr send( std::vector& tensors, int dstRank, int tag) override; c10::intrusive_ptr recv( std::vector& tensors, int srcRank, int tag) override; c10::intrusive_ptr recvAnysource( std::vector& tensor, int tag) override; // Create a new ProcessGroupTest instance static c10::intrusive_ptr createProcessGroupTest( const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, const std::chrono::duration& timeout); static void ProcessGroupTestConstructor() __attribute__((constructor)) { py::object module = py::module::import("torch.distributed"); py::object register_backend = module.attr("Backend").attr("register_backend"); // The first parameter is the backend name used by user in invoking // torch.distributed.init_process_group(). // Note it could be different with module name. For example, the module // name is "torch_test" but the backend name is "test". // The second parameter is the instantiation function. register_backend("test", py::cpp_function(createProcessGroupTest)); } }; } // namespace c10d