Support split_group on process groups with mixed backends
(e.g. "cpu:gloo,cuda:nccl").

  * ProcessGroup keeps deviceTypeToBackendType_ in an ordered std::map, as
    splitGroup must call split on the underlying backends in a consistent
    order.
  * ProcessGroupGloo::split and ProcessGroupGloo::merge no longer fail when
    given options that are not ProcessGroupGloo::Options; they warn once and
    fall back to ProcessGroupGloo::Options::create_default().
  * Construction of the default Gloo options (GLOO_SOCKET_IFNAME handling,
    timeout, thread count) moves out of the Python binding in init.cpp and
    into ProcessGroupGloo::Options::create_default().
  * split_group derives subgroup names with use_hashed_name=True so that the
    names, which Gloo uses as a PrefixStore prefix, are unique.

Review notes:
  * Template arguments were missing from the extracted patch text and have
    been restored by inference from the surrounding declarations; confirm
    them (and context-line indentation) against the upstream files before
    applying.
  * The "index" lines carry the blob hashes of the original patch.

diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index b234c907a665..0d55845228da 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -1087,6 +1087,61 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
 
         dist.destroy_process_group()
 
+    @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
+    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+    def test_comm_split_group_mixed_backend(self):
+        # Test `ncclCommSplit` for smaller subgroups of the world when
+        # we've passed a specific device_id to init_process_group.
+        store = c10d.FileStore(self.file_name, self.world_size)
+        device = torch.device(f"cuda:{self.rank}")
+        # Create the default group with mixed backends: gloo (CPU) + nccl (CUDA).
+        c10d.init_process_group(
+            "cpu:gloo,cuda:nccl",
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+            pg_options=self.opts(),
+            device_id=device,
+        )
+        pg = c10d.distributed_c10d._get_default_group()
+        backend = pg._get_backend(torch.device(device))
+
+        cuda_tensor = torch.full((1,), self.rank).cuda(device)
+        cpu_tensor = torch.full((1,), self.rank)
+        # Create subgroup between ranks 0, 1
+        subg_ranks = [0, 1]
+        ng1 = c10d.split_group(pg, [subg_ranks])
+        backend1 = ng1._get_backend(torch.device(device))
+
+        # check basic options are the same between parent and child
+        self.assertEqual(backend.options._timeout, backend1.options._timeout)
+        self.assertEqual(
+            backend.options.is_high_priority_stream,
+            backend1.options.is_high_priority_stream,
+        )
+        self.assertEqual(ng1.group_desc, "default_pg:split:0")
+
+        # comm split happens eagerly since device_id is passed to init_process_group.
+        self.assertEqual(backend.comm_split_count(), 1)
+        # dist.get_process_group_ranks returns the global ranks in the subgroup.
+        self.assertEqual(
+            dist.get_process_group_ranks(ng1),
+            subg_ranks if self.rank in subg_ranks else [],
+        )
+
+        # dist.get_rank returns the group rank if part of ng1; otherwise, -1
+        if dist.get_rank(ng1) >= 0:
+            dist.broadcast(cuda_tensor, dist.get_global_rank(ng1, 0), group=ng1)
+            self.assertEqual(cuda_tensor, torch.full((1,), 0))
+            dist.broadcast(cpu_tensor, dist.get_global_rank(ng1, 0), group=ng1)
+            self.assertEqual(cpu_tensor, torch.full((1,), 0))
+
+        ng2 = c10d.split_group(pg, [subg_ranks])
+        self.assertEqual(ng2.group_desc, "default_pg:split:1")
+        self.assertEqual(backend.comm_split_count(), 2)
+
+        dist.destroy_process_group()
+
     @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit")
     @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
     def test_non_blocking_init(self):
diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp
index 4fb2d566e9a7..5a06a386d5ca 100644
--- a/torch/csrc/distributed/c10d/ProcessGroup.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp
@@ -1015,7 +1015,9 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
 
   // Backend classes for this ProcessGroup
   std::unordered_set<c10::DeviceType> deviceTypes_;
-  std::unordered_map<c10::DeviceType, BackendType> deviceTypeToBackendType_;
+  // This mapping is ordered, as splitGroup must call split on the underlying
+  // backends in a consistent order.
+  std::map<c10::DeviceType, BackendType> deviceTypeToBackendType_;
   std::unordered_map<c10::DeviceType, c10::intrusive_ptr<Backend>>
       deviceTypeToBackend_;
   std::unordered_map<BackendType, c10::intrusive_ptr<Backend>>
diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
index fbd8a403b97d..74063ff579e8 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
@@ -551,6 +551,32 @@ std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
 
 static std::atomic<size_t> process_group_id = 0;
 
+c10::intrusive_ptr<ProcessGroupGloo::Options> ProcessGroupGloo::Options::
+    create_default(std::chrono::milliseconds timeout) {
+  auto options = ::c10d::ProcessGroupGloo::Options::create();
+  bool lazyInit = ::c10d::getDefaultGlooLazyInit();
+
+  // Use interfaces listed in "GLOO_SOCKET_IFNAME", if set.
+  auto ifnameEnv = c10::utils::get_env("GLOO_SOCKET_IFNAME");
+  if (ifnameEnv && ifnameEnv->size() > 1) {
+    for (const auto& iface : ::c10d::split(',', ifnameEnv->c_str())) {
+      options->devices.push_back(
+          ::c10d::ProcessGroupGloo::createDeviceForInterface(iface, lazyInit));
+    }
+  } else {
+    // If no hostname is specified, this function looks up
+    // the machine's hostname and returns a device instance
+    // associated with the address that the hostname resolves to.
+    options->devices.push_back(
+        ::c10d::ProcessGroupGloo::createDefaultDevice(lazyInit));
+  }
+
+  options->timeout = timeout;
+  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
+  options->threads = options->devices.size() * 2;
+  return options;
+}
+
 ProcessGroupGloo::ProcessGroupGloo(
     const c10::intrusive_ptr<Store>& store,
     int rank,
@@ -710,7 +736,12 @@ c10::intrusive_ptr<Backend> ProcessGroupGloo::split(
   }
 
   auto glooOpts = c10::dynamic_intrusive_pointer_cast<Options>(opts);
-  TORCH_CHECK(glooOpts != nullptr, "opts not a ProcessGroupGloo::Options.");
+  if (glooOpts == nullptr) {
+    TORCH_WARN_ONCE(
+        "Tried to pass options to ProcessGroupGloo::split that are not ProcessGroupGloo::Options. "
+        "Falling back to default options.");
+    glooOpts = ProcessGroupGloo::Options::create_default();
+  }
 
   // TODO: we need to get rid of globalRanksInGroup eventually.
   std::vector<uint64_t> globalRanksInGroup;
@@ -729,7 +760,12 @@ c10::intrusive_ptr<Backend> ProcessGroupGloo::merge(
     const int& rank,
     const int& size) {
   auto glooOpts = c10::dynamic_intrusive_pointer_cast<Options>(opts);
-  TORCH_CHECK(glooOpts != nullptr, "opts not a ProcessGroupGloo::Options.");
+  if (glooOpts == nullptr) {
+    TORCH_WARN_ONCE(
+        "Tried to pass options to ProcessGroupGloo::merge that are not ProcessGroupGloo::Options. "
+        "Falling back to default options.");
+    glooOpts = ProcessGroupGloo::Options::create_default();
+  }
   auto pg = c10::make_intrusive<ProcessGroupGloo>(
       store->clone(), rank, size, glooOpts);
   return c10::static_intrusive_pointer_cast<Backend>(pg);
diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp
index 4297807f2e8b..b2cc6993528b 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp
@@ -255,6 +255,11 @@ class TORCH_API ProcessGroupGloo : public Backend {
       return c10::make_intrusive<Options>(timeout);
     }
 
+    // Builds the default options: one device per interface listed in
+    // GLOO_SOCKET_IFNAME (or the default device) and two threads per device.
+    static c10::intrusive_ptr<Options> create_default(
+        std::chrono::milliseconds timeout = kBackendDefaultTimeout);
+
     std::vector<std::shared_ptr<::gloo::transport::Device>> devices;
     int threads;
   };
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index a5270354cf61..7e79fef8392f 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -3106,8 +3106,6 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
       .def_readwrite("group_name", &::c10d::Backend::Options::group_name);
 
 #ifdef USE_C10D_GLOO
-  static const std::string GLOO_SOCKET_IFNAME_ENV = "GLOO_SOCKET_IFNAME";
-
   auto processGroupGloo =
       intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupGloo>(
           module, "ProcessGroupGloo", backend);
@@ -3184,31 +3182,11 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
             // https://github.com/pybind/pybind11/issues/5473
             py::gil_scoped_release nogil{};
 
-            auto options = ::c10d::ProcessGroupGloo::Options::create();
-            bool lazyInit = ::c10d::getDefaultGlooLazyInit();
-
-            // Use interfaces listed in "GLOO_SOCKET_IFNAME", if set.
-            auto ifnameEnv =
-                c10::utils::get_env(GLOO_SOCKET_IFNAME_ENV.c_str());
-            if (ifnameEnv && ifnameEnv->size() > 1) {
-              for (const auto& iface : ::c10d::split(',', ifnameEnv->c_str())) {
-                options->devices.push_back(
-                    ::c10d::ProcessGroupGloo::createDeviceForInterface(
-                        iface, lazyInit));
-              }
-            } else {
-              // If no hostname is specified, this function looks up
-              // the machine's hostname and returns a device instance
-              // associated with the address that the hostname resolves to.
-              options->devices.push_back(
-                  ::c10d::ProcessGroupGloo::createDefaultDevice(lazyInit));
-            }
-
-            options->timeout = timeout;
-            // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
-            options->threads = options->devices.size() * 2;
             return c10::make_intrusive<::c10d::ProcessGroupGloo>(
-                store, rank, size, options);
+                store,
+                rank,
+                size,
+                ::c10d::ProcessGroupGloo::Options::create_default(timeout));
           }),
           py::arg("store"),
           py::arg("rank"),
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 2f60cbe13abc..498cc50eb9cf 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -5160,7 +5160,11 @@ def split_group(
             my_group = split_group
             break
 
-    group_name = _process_group_name(my_group, use_hashed_name=False)
+    # use_hashed_name is True to ensure that subgroups have unique names.
+    # This is needed as some backends (e.g. Gloo) use the group name as a
+    # PrefixStore prefix for initialization of splits. Thus, names have to be
+    # unique to avoid key collisions.
+    group_name = _process_group_name(my_group, use_hashed_name=True)
     split_pg = parent_pg.split_group(
         my_group,
         timeout=timeout,