mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
shrink_group implementation to expose ncclCommShrink API (#164518)
Closes #164529 To expose the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink) API to PyTorch. This is useful when you need to exclude certain GPUs or nodes from a collective operation, for example in fault tolerance scenarios or when dynamically adjusting resource utilization. For more info: [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator) Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518 Approved by: https://github.com/kwen2501
This commit is contained in:
committed by
PyTorch MergeBot
parent
15ff1cd28b
commit
fa0db212e7
@ -79,6 +79,23 @@ class TORCH_API Backend : public torch::CustomClassHolder {
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual bool supportsShrinking() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Shrink the backend by excluding specified ranks. Backends that support
|
||||
// communicator shrinking should override this and return a new backend
|
||||
// instance representing the shrunken group. Backends may use opts_override
|
||||
// to supply backend-specific options for the new group.
|
||||
virtual c10::intrusive_ptr<Backend> shrink(
|
||||
const std::vector<int64_t>& /*ranks_to_exclude*/,
|
||||
int /*shrink_flags*/ = 0,
|
||||
const c10::intrusive_ptr<Options>& /*opts_override*/ = nullptr) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
c10::str("Backend ", getBackendName(), " does not support shrink"));
|
||||
}
|
||||
|
||||
virtual void setTimeout(std::chrono::milliseconds timeout) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
|
@ -259,6 +259,65 @@ std::shared_ptr<NCCLComm> NCCLComm::split(
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef NCCL_HAS_COMM_SHRINK
|
||||
std::shared_ptr<NCCLComm> NCCLComm::shrink(
|
||||
NCCLComm* source,
|
||||
std::vector<int>& ranks_to_exclude,
|
||||
ncclConfig_t* config,
|
||||
int shrinkFlags) {
|
||||
// Preconditions are validated in ProcessGroupNCCL::shrink
|
||||
|
||||
LOG(INFO) << "Rank " << source->rank_ << ": shrinking comm " << source->repr()
|
||||
<< " excluding " << ranks_to_exclude.size() << " ranks";
|
||||
|
||||
at::cuda::OptionalCUDAGuard gpuGuard(source->deviceIndex_);
|
||||
auto comm = std::make_shared<NCCLComm>();
|
||||
|
||||
// This call will block until the source communicator is initialized
|
||||
auto sourceComm = source->getNcclComm();
|
||||
|
||||
C10D_NCCL_CHECK_NONBLOCKING(
|
||||
ncclCommShrink(
|
||||
sourceComm,
|
||||
ranks_to_exclude.data(),
|
||||
ranks_to_exclude.size(),
|
||||
reinterpret_cast<ncclComm_t*>(&(comm->ncclComm_)),
|
||||
config,
|
||||
shrinkFlags),
|
||||
source->getNcclCommFailureReason());
|
||||
|
||||
// Wait for the child communicator to be ready
|
||||
source->waitReady(true);
|
||||
comm->initialized_ = true;
|
||||
|
||||
// NCCL automatically assigns rank during shrink - query it efficiently
|
||||
int assigned_rank;
|
||||
try {
|
||||
C10D_NCCL_CHECK(
|
||||
ncclCommUserRank(comm->ncclComm_, &assigned_rank), std::nullopt);
|
||||
comm->rank_ = assigned_rank;
|
||||
} catch (const std::exception& e) {
|
||||
// Fallback: if ncclCommUserRank fails, we can't determine the rank
|
||||
LOG(ERROR) << "Failed to query NCCL-assigned rank: " << e.what();
|
||||
throw;
|
||||
}
|
||||
|
||||
// Child comm should be on the same device as parent comm
|
||||
comm->deviceIndex_ = source->deviceIndex_;
|
||||
if (config != nullptr) {
|
||||
comm->nonBlocking_ = config->blocking == 0;
|
||||
} else {
|
||||
// Inherit parent behavior if no config provided
|
||||
comm->nonBlocking_ = source->nonBlocking_;
|
||||
}
|
||||
|
||||
LOG(INFO) << "Rank " << source->rank_ << ": created shrunken comm "
|
||||
<< comm->repr() << " with NCCL-assigned rank " << assigned_rank;
|
||||
|
||||
return comm;
|
||||
}
|
||||
#endif
|
||||
|
||||
void NCCLComm::finalize() {
|
||||
LockType lock(mutex_);
|
||||
if (aborted_) {
|
||||
|
@ -90,6 +90,10 @@ static_assert(
|
||||
#define NCCL_HAS_NVLS_CTAS
|
||||
#endif
|
||||
|
||||
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0)
|
||||
#define NCCL_HAS_COMM_SHRINK
|
||||
#endif
|
||||
|
||||
// Macro to throw on a non-successful NCCL return value.
|
||||
#define C10D_NCCL_CHECK(cmd, failureReason) \
|
||||
do { \
|
||||
@ -294,6 +298,14 @@ class NCCLComm {
|
||||
ncclConfig_t& config);
|
||||
#endif // NCCL_HAS_COMM_SPLIT
|
||||
|
||||
#ifdef NCCL_HAS_COMM_SHRINK
|
||||
static std::shared_ptr<NCCLComm> shrink(
|
||||
NCCLComm* source,
|
||||
std::vector<int>& ranks_to_exclude,
|
||||
ncclConfig_t* config,
|
||||
int shrinkFlags = 0);
|
||||
#endif // NCCL_HAS_COMM_SHRINK
|
||||
|
||||
#if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP)
|
||||
std::unordered_map<std::string, std::string> ncclCommDump();
|
||||
#endif
|
||||
|
@ -165,7 +165,7 @@ ncclRedOpRAII getNcclReduceOp(
|
||||
}
|
||||
|
||||
// Get a key string from device
|
||||
inline std::string getKeyFromDevice(at::Device& device) {
|
||||
inline std::string getKeyFromDevice(const at::Device& device) {
|
||||
return std::to_string(device.index());
|
||||
}
|
||||
|
||||
@ -5838,6 +5838,139 @@ at::Tensor ProcessGroupNCCL::allocateTensor(
|
||||
return tensor;
|
||||
}
|
||||
|
||||
#ifdef NCCL_HAS_COMM_SHRINK
|
||||
c10::intrusive_ptr<Backend> ProcessGroupNCCL::shrink(
|
||||
const std::vector<int64_t>& ranks_to_exclude,
|
||||
int shrink_flags,
|
||||
const c10::intrusive_ptr<Backend::Options>& opts_override) {
|
||||
// Runtime version check with better error message
|
||||
auto runtime_version = torch::cuda::nccl::version();
|
||||
TORCH_CHECK(
|
||||
runtime_version >= NCCL_VERSION(2, 27, 0),
|
||||
"ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later. "
|
||||
"Found version: ",
|
||||
runtime_version);
|
||||
|
||||
// Early validation with detailed error messages
|
||||
TORCH_CHECK_VALUE(
|
||||
!ranks_to_exclude.empty(), "ranks_to_exclude cannot be empty");
|
||||
TORCH_CHECK_VALUE(
|
||||
static_cast<int>(ranks_to_exclude.size()) < size_,
|
||||
"Cannot exclude all ranks (",
|
||||
ranks_to_exclude.size(),
|
||||
" >= ",
|
||||
size_,
|
||||
")");
|
||||
|
||||
// Validate ranks and convert to int efficiently
|
||||
std::vector<int> int_ranks_to_exclude;
|
||||
int_ranks_to_exclude.reserve(ranks_to_exclude.size());
|
||||
for (int64_t rank : ranks_to_exclude) {
|
||||
TORCH_CHECK_VALUE(
|
||||
rank >= 0 && rank < size_,
|
||||
"Invalid rank ",
|
||||
rank,
|
||||
" for group size ",
|
||||
size_);
|
||||
int_ranks_to_exclude.push_back(static_cast<int>(rank));
|
||||
}
|
||||
|
||||
// Get primary communicator with better error context
|
||||
auto primary_device_index = guessDeviceId();
|
||||
auto primary_device = at::Device(at::kCUDA, primary_device_index);
|
||||
const auto primary_key = getKeyFromDevice(primary_device);
|
||||
|
||||
std::shared_ptr<NCCLComm> primary_comm = getNCCLComm(primary_key);
|
||||
TORCH_CHECK(
|
||||
primary_comm,
|
||||
"Primary NCCL communicator for device ",
|
||||
primary_device,
|
||||
" (key: ",
|
||||
primary_key,
|
||||
") is not initialized");
|
||||
|
||||
// Cache device index before shrink operation
|
||||
at::DeviceIndex parent_device_index = primary_comm->getDeviceIndex();
|
||||
|
||||
ncclConfig_t* config = nullptr;
|
||||
// Default to inheriting from parent options
|
||||
bool high_priority_stream = options_->is_high_priority_stream;
|
||||
if (opts_override) {
|
||||
auto nccl_opts =
|
||||
c10::static_intrusive_pointer_cast<ProcessGroupNCCL::Options>(
|
||||
opts_override);
|
||||
config = &nccl_opts->config;
|
||||
// If user provided override options, honor is_high_priority_stream as well
|
||||
high_priority_stream = nccl_opts->is_high_priority_stream;
|
||||
}
|
||||
|
||||
std::shared_ptr<NCCLComm> shrunk_comm = NCCLComm::shrink(
|
||||
primary_comm.get(),
|
||||
int_ranks_to_exclude,
|
||||
(config != nullptr ? config : &options_->config),
|
||||
shrink_flags);
|
||||
|
||||
// Calculate new size and get NCCL-assigned rank
|
||||
int new_size = size_ - static_cast<int>(ranks_to_exclude.size());
|
||||
int new_rank = shrunk_comm->rank_;
|
||||
|
||||
// Create new ProcessGroupNCCL with optimized options cloning
|
||||
auto new_store = store_->clone();
|
||||
auto new_opts = ProcessGroupNCCL::Options::create(high_priority_stream);
|
||||
new_opts->timeout = options_->timeout;
|
||||
if (config != nullptr) {
|
||||
new_opts->config = *config;
|
||||
} else {
|
||||
new_opts->config = options_->config;
|
||||
}
|
||||
|
||||
auto new_pg = c10::make_intrusive<ProcessGroupNCCL>(
|
||||
new_store, new_rank, new_size, new_opts);
|
||||
|
||||
// Set up the new process group with optimized device setup
|
||||
new_pg->initializeDeviceStateForComm(
|
||||
at::Device(at::kCUDA, parent_device_index), shrunk_comm);
|
||||
|
||||
return c10::static_intrusive_pointer_cast<Backend>(new_pg);
|
||||
}
|
||||
|
||||
#else // !NCCL_HAS_COMM_SHRINK
|
||||
// Backend interface override: raise consistent error when shrink is
|
||||
// unsupported.
|
||||
c10::intrusive_ptr<Backend> ProcessGroupNCCL::shrink(
|
||||
const std::vector<int64_t>& /*ranks_to_exclude*/,
|
||||
int /*shrink_flags*/,
|
||||
const c10::intrusive_ptr<Backend::Options>& /*opts_override*/) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later, "
|
||||
"but PyTorch was built with an older version or without NCCL shrink support.");
|
||||
}
|
||||
|
||||
#endif // NCCL_HAS_COMM_SHRINK
|
||||
|
||||
void ProcessGroupNCCL::initializeDeviceStateForComm(
|
||||
const at::Device& device,
|
||||
std::shared_ptr<NCCLComm> comm) {
|
||||
const auto key = getKeyFromDevice(device);
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
at::cuda::OptionalCUDAGuard gpuGuard(device);
|
||||
|
||||
bool force_high = getCvarBool(TORCH_NCCL_HIGH_PRIORITY, false);
|
||||
auto stream = at::cuda::getStreamFromPool(
|
||||
options_->is_high_priority_stream || force_high);
|
||||
|
||||
devNCCLCommMap_[key] = comm;
|
||||
ncclStreams_.emplace(key, stream);
|
||||
ncclEvents_.emplace(key, at::cuda::CUDAEvent(cudaEventDisableTiming));
|
||||
usedDeviceIdxs_.insert(device.index());
|
||||
|
||||
if (shouldAllCommunicatorsRegisterAllTensors()) {
|
||||
std::lock_guard<std::mutex> map_lock(ncclCommMemPoolMapMutex);
|
||||
ncclCommMemPoolMap.emplace(std::move(comm), MemPoolSet{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace c10d
|
||||
|
||||
#endif // USE_C10D_NCCL
|
||||
|
@ -997,6 +997,21 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
|
||||
ErrorType getError() override;
|
||||
|
||||
bool supportsShrinking() const override {
|
||||
#ifdef NCCL_HAS_COMM_SHRINK
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Backend-style shrink override that returns a Backend instance.
|
||||
c10::intrusive_ptr<Backend> shrink(
|
||||
const std::vector<int64_t>& ranks_to_exclude,
|
||||
int shrink_flags = 0,
|
||||
const c10::intrusive_ptr<Backend::Options>& opts_override =
|
||||
nullptr) override;
|
||||
|
||||
std::shared_ptr<c10::Allocator> getMemAllocator() override;
|
||||
|
||||
// Allocate tensor from communication-optimized memory pool
|
||||
@ -1065,6 +1080,12 @@ class TORCH_API ProcessGroupNCCL : public Backend {
|
||||
int p2pRank = 0,
|
||||
bool isSendRecvSelf = false);
|
||||
|
||||
// Initialize device-specific state (comm, stream, event, bookkeeping) for a
|
||||
// given communicator on this process group instance.
|
||||
void initializeDeviceStateForComm(
|
||||
const at::Device& device,
|
||||
std::shared_ptr<NCCLComm> comm);
|
||||
|
||||
// Wrapper method which can be overridden for tests.
|
||||
virtual std::exception_ptr checkForNCCLErrors(
|
||||
std::shared_ptr<NCCLComm>& ncclComm);
|
||||
|
@ -2730,12 +2730,23 @@ Arguments:
|
||||
"supports_time_estimate",
|
||||
&::c10d::Backend::supportsTimeEstimation,
|
||||
"(test whether the backend supports collective time estimation)")
|
||||
.def_property_readonly(
|
||||
"supports_shrinking",
|
||||
&::c10d::Backend::supportsShrinking,
|
||||
"(test whether the backend supports communicator shrinking)")
|
||||
.def(
|
||||
"set_timeout",
|
||||
&::c10d::Backend::setTimeout,
|
||||
py::arg("timeout"),
|
||||
py::call_guard<py::gil_scoped_release>(),
|
||||
R"(Sets the default timeout for all future operations.)")
|
||||
.def(
|
||||
"shrink",
|
||||
&::c10d::Backend::shrink,
|
||||
py::arg("ranks_to_exclude"),
|
||||
py::arg("shrink_flags") = 0,
|
||||
py::arg("opts_override") = nullptr,
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def(
|
||||
"broadcast",
|
||||
&::c10d::Backend::broadcast,
|
||||
|
@ -130,6 +130,7 @@ __all__ = [
|
||||
"reduce_scatter_tensor",
|
||||
"get_node_local_rank",
|
||||
"split_group",
|
||||
"shrink_group",
|
||||
]
|
||||
|
||||
_MPI_AVAILABLE = True
|
||||
@ -5713,3 +5714,517 @@ def _get_process_group_name(pg: ProcessGroup) -> str:
|
||||
|
||||
def _get_process_group_store(pg: ProcessGroup) -> Store:
|
||||
return _world.pg_map[pg][1]
|
||||
|
||||
|
||||
# Shrink flags for process group backends
|
||||
SHRINK_DEFAULT = 0x00
|
||||
SHRINK_ABORT = 0x01
|
||||
|
||||
|
||||
@_time_logger
|
||||
def shrink_group(
|
||||
ranks_to_exclude: list[int],
|
||||
group: Optional[ProcessGroup] = None,
|
||||
shrink_flags: int = SHRINK_DEFAULT,
|
||||
pg_options: Optional[Any] = None,
|
||||
) -> ProcessGroup:
|
||||
"""
|
||||
Shrinks a process group by excluding specified ranks.
|
||||
|
||||
Creates and returns a new, smaller process group comprising only the ranks
|
||||
from the original group that were not in the ``ranks_to_exclude`` list.
|
||||
|
||||
Args:
|
||||
ranks_to_exclude (List[int]): A list of ranks from the original
|
||||
``group`` to exclude from the new group.
|
||||
group (ProcessGroup, optional): The process group to shrink. If ``None``,
|
||||
the default process group is used. Defaults to ``None``.
|
||||
shrink_flags (int, optional): Flags to control the shrinking behavior.
|
||||
Can be ``SHRINK_DEFAULT`` (default) or ``SHRINK_ABORT``.
|
||||
``SHRINK_ABORT`` will attempt to terminate ongoing operations
|
||||
in the parent communicator before shrinking.
|
||||
Defaults to ``SHRINK_DEFAULT``.
|
||||
pg_options (ProcessGroupOptions, optional): Backend-specific options to apply
|
||||
to the shrunken process group. If provided, the backend will use
|
||||
these options when creating the new group. If omitted, the new group
|
||||
inherits defaults from the parent.
|
||||
|
||||
Returns:
|
||||
ProcessGroup: a new group comprised of the remaining ranks. If the
|
||||
default group was shrunk, the returned group becomes the new default group.
|
||||
|
||||
Raises:
|
||||
TypeError: if the group’s backend does not support shrinking.
|
||||
ValueError: if ``ranks_to_exclude`` is invalid (empty, out of bounds,
|
||||
duplicates, or excludes all ranks).
|
||||
RuntimeError: if an excluded rank calls this function or the backend
|
||||
fails the operation.
|
||||
|
||||
Notes:
|
||||
- Only non-excluded ranks should call this function; excluded ranks
|
||||
must not participate in the shrink operation.
|
||||
- Shrinking the default group destroys all other process groups since
|
||||
rank reassignment makes them inconsistent.
|
||||
"""
|
||||
# Step 1: Validate input parameters with comprehensive error checking
|
||||
_validate_shrink_inputs(ranks_to_exclude, shrink_flags)
|
||||
|
||||
# Step 2: Get target group and essential properties
|
||||
target_group_info = _prepare_shrink_target_group(group)
|
||||
|
||||
# Step 3: Validate backend requirements and availability
|
||||
backend_impl = _validate_shrink_backend_requirements(target_group_info)
|
||||
|
||||
# Step 4: Validate ranks against group and check for duplicates
|
||||
excluded_ranks_set = _validate_and_process_excluded_ranks(
|
||||
ranks_to_exclude, target_group_info
|
||||
)
|
||||
|
||||
# Step 5: Execute the actual shrink operation (backend-specific)
|
||||
new_backend = backend_impl.shrink(
|
||||
sorted(excluded_ranks_set),
|
||||
shrink_flags,
|
||||
pg_options if pg_options is not None else None,
|
||||
)
|
||||
|
||||
# Step 6: Handle cleanup and creation of new process group
|
||||
target_group_info["pg_options_override"] = pg_options
|
||||
return _finalize_shrunk_group(target_group_info, excluded_ranks_set, new_backend)
|
||||
|
||||
|
||||
def _validate_shrink_inputs(ranks_to_exclude: list[int], shrink_flags: int) -> None:
|
||||
"""Validate input parameters for shrink_group."""
|
||||
if not isinstance(ranks_to_exclude, list):
|
||||
raise TypeError(
|
||||
f"ranks_to_exclude must be a list, but got {type(ranks_to_exclude).__name__}. "
|
||||
f"Example: [1, 3, 5] to exclude ranks 1, 3, and 5."
|
||||
)
|
||||
|
||||
if not ranks_to_exclude:
|
||||
raise ValueError(
|
||||
"ranks_to_exclude cannot be empty. To shrink a group, you must specify at least "
|
||||
"one rank to exclude. Example: [failed_rank_id]"
|
||||
)
|
||||
|
||||
# Validate shrink_flags with clear explanation of valid values
|
||||
valid_flags = [SHRINK_DEFAULT, SHRINK_ABORT]
|
||||
if not isinstance(shrink_flags, int) or shrink_flags not in valid_flags:
|
||||
raise ValueError(
|
||||
f"Invalid shrink_flags value: {shrink_flags}. Must be one of: "
|
||||
f"SHRINK_DEFAULT ({SHRINK_DEFAULT}) or SHRINK_ABORT ({SHRINK_ABORT}). "
|
||||
f"Use SHRINK_ABORT to abort ongoing operations before shrinking."
|
||||
)
|
||||
|
||||
|
||||
def _prepare_shrink_target_group(group: Optional[ProcessGroup]) -> dict:
|
||||
"""Prepare and validate the target group for shrinking."""
|
||||
target_pg = group if group is not None else _get_default_group()
|
||||
|
||||
# Cache frequently accessed properties to avoid repeated calls
|
||||
group_size = int(target_pg.size())
|
||||
group_info = {
|
||||
"process_group": target_pg,
|
||||
"is_default_group": (target_pg == _get_default_group()),
|
||||
"group_size": group_size,
|
||||
"current_rank": target_pg.rank(),
|
||||
"group_name": _get_process_group_name(target_pg),
|
||||
}
|
||||
|
||||
# Validate that we have a valid process group
|
||||
if group_size <= 1:
|
||||
raise ValueError(
|
||||
f"Cannot shrink a process group with size {group_size}. "
|
||||
f"Group must have at least 2 ranks to support shrinking."
|
||||
)
|
||||
|
||||
return group_info
|
||||
|
||||
|
||||
def _validate_shrink_backend_requirements(group_info: dict) -> Any:
|
||||
"""Return the backend implementation for the target group or raise if unsupported."""
|
||||
target_pg = group_info["process_group"]
|
||||
group_name = group_info["group_name"]
|
||||
|
||||
# Get the group's backend directly via ProcessGroup API. Prefer a bound device if present,
|
||||
# otherwise try CUDA then fall back to CPU.
|
||||
try:
|
||||
preferred_device = getattr(target_pg, "bound_device_id", None)
|
||||
if preferred_device is not None:
|
||||
backend_impl = target_pg._get_backend(preferred_device)
|
||||
else:
|
||||
# Try CUDA first if available, else CPU
|
||||
try:
|
||||
backend_impl = target_pg._get_backend(torch.device("cuda"))
|
||||
except Exception:
|
||||
backend_impl = target_pg._get_backend(torch.device("cpu"))
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError(
|
||||
f"Cannot access device backend for process group '{group_name}'. "
|
||||
f"Ensure the process group was initialized with a compatible device backend and devices are available."
|
||||
) from e
|
||||
|
||||
try:
|
||||
supports = bool(backend_impl.supports_shrinking)
|
||||
except Exception:
|
||||
supports = False
|
||||
if not supports:
|
||||
raise TypeError(
|
||||
f"Process group backend for '{group_name}' does not support shrinking operations."
|
||||
)
|
||||
|
||||
return backend_impl
|
||||
|
||||
|
||||
def _validate_and_process_excluded_ranks(
|
||||
ranks_to_exclude: list[int], group_info: dict
|
||||
) -> set:
|
||||
"""Validate excluded ranks and convert to set for efficient operations."""
|
||||
group_size = group_info["group_size"]
|
||||
current_rank = group_info["current_rank"]
|
||||
|
||||
# Use set for O(1) duplicate detection and membership testing
|
||||
excluded_ranks_set = set()
|
||||
|
||||
# Validate each rank with detailed error messages
|
||||
for i, rank in enumerate(ranks_to_exclude):
|
||||
if not isinstance(rank, int):
|
||||
raise TypeError(
|
||||
f"All elements in ranks_to_exclude must be integers. "
|
||||
f"Element at index {i} is {type(rank).__name__}: {rank}"
|
||||
)
|
||||
|
||||
if not (0 <= rank < group_size):
|
||||
raise ValueError(
|
||||
f"Rank {rank} at index {i} is out of bounds for group size {group_size}. "
|
||||
f"Valid ranks are in range [0, {group_size - 1}]."
|
||||
)
|
||||
|
||||
if rank in excluded_ranks_set:
|
||||
raise ValueError(
|
||||
f"Duplicate rank {rank} found in ranks_to_exclude at index {i}. "
|
||||
f"Each rank can only be excluded once."
|
||||
)
|
||||
|
||||
excluded_ranks_set.add(rank)
|
||||
|
||||
# Ensure we don't exclude all ranks
|
||||
if len(excluded_ranks_set) >= group_size:
|
||||
raise ValueError(
|
||||
f"Cannot exclude all {group_size} ranks from process group. "
|
||||
f"At least one rank must remain. Excluding {len(excluded_ranks_set)} ranks."
|
||||
)
|
||||
|
||||
# Critical check: current rank should not be in excluded list
|
||||
if current_rank in excluded_ranks_set:
|
||||
raise RuntimeError(
|
||||
f"Current rank {current_rank} is in the exclusion list and should not call shrink_group(). "
|
||||
f"Only non-excluded ranks should participate in the shrinking operation. "
|
||||
f"Excluded ranks should terminate their processes instead."
|
||||
)
|
||||
|
||||
return excluded_ranks_set
|
||||
|
||||
|
||||
def _finalize_shrunk_group(
|
||||
group_info: dict, excluded_ranks_set: set, new_backend
|
||||
) -> ProcessGroup:
|
||||
"""Clean up old group and create new shrunk process group."""
|
||||
target_pg = group_info["process_group"]
|
||||
is_default_group = group_info["is_default_group"]
|
||||
|
||||
# Handle default group dependencies - destroy other groups first
|
||||
if is_default_group:
|
||||
_destroy_all_other_groups(exclude_group=target_pg)
|
||||
|
||||
# Gather original group metadata before cleanup
|
||||
original_group_metadata = _extract_group_metadata(target_pg)
|
||||
|
||||
# Calculate remaining ranks efficiently
|
||||
original_ranks = get_process_group_ranks(target_pg)
|
||||
remaining_ranks = [
|
||||
rank for rank in original_ranks if rank not in excluded_ranks_set
|
||||
]
|
||||
|
||||
# Clean up the original group
|
||||
_cleanup_original_group(target_pg, is_default_group)
|
||||
|
||||
# Create and configure the new process group
|
||||
new_pg = _create_shrunk_process_group(
|
||||
new_backend, remaining_ranks, original_group_metadata, is_default_group
|
||||
)
|
||||
|
||||
# Register the new group in global state
|
||||
if is_default_group:
|
||||
_update_default_pg(new_pg)
|
||||
|
||||
# Update global state with new group information
|
||||
rank_mapping = {
|
||||
global_rank: group_rank
|
||||
for group_rank, global_rank in enumerate(remaining_ranks)
|
||||
}
|
||||
_update_process_group_global_state(
|
||||
pg=new_pg,
|
||||
backend_name=original_group_metadata["backend_name"],
|
||||
store=original_group_metadata["store"],
|
||||
group_name=original_group_metadata["new_group_name"],
|
||||
backend_config=original_group_metadata["backend_config"],
|
||||
rank_mapping=rank_mapping,
|
||||
)
|
||||
|
||||
return new_pg
|
||||
|
||||
|
||||
def _extract_group_metadata(target_pg: ProcessGroup) -> dict:
|
||||
"""Extract metadata from the original group before cleanup."""
|
||||
original_backend_name, original_store = _world.pg_map[target_pg]
|
||||
original_backend_config = _world.pg_backend_config.get(target_pg, "")
|
||||
original_group_name = _get_process_group_name(target_pg)
|
||||
|
||||
# Extract device binding information before cleanup to avoid accessing destroyed group
|
||||
bound_device_id = None
|
||||
if hasattr(target_pg, "bound_device_id"):
|
||||
bound_device_id = target_pg.bound_device_id
|
||||
|
||||
# Generate new group name for the shrunk group; hash for uniqueness across backends
|
||||
remaining_ranks = list(get_process_group_ranks(target_pg))
|
||||
new_group_name = _process_group_name(remaining_ranks, use_hashed_name=True)
|
||||
|
||||
return {
|
||||
"backend_name": original_backend_name,
|
||||
"store": original_store,
|
||||
"backend_config": original_backend_config,
|
||||
"original_group_name": original_group_name,
|
||||
"new_group_name": new_group_name,
|
||||
"bound_device_id": bound_device_id, # Safe to access after cleanup
|
||||
}
|
||||
|
||||
|
||||
def _cleanup_original_group(target_pg: ProcessGroup, is_default_group: bool) -> None:
|
||||
"""Clean up the original process group safely."""
|
||||
try:
|
||||
destroy_process_group(target_pg)
|
||||
except Exception as e:
|
||||
group_type = "default" if is_default_group else "non-default"
|
||||
logger.warning("Failed to destroy %s group during shrinking: %s", group_type, e)
|
||||
|
||||
# Ensure global state cleanup even if destroy_process_group fails
|
||||
_cleanup_process_group_global_state(target_pg)
|
||||
|
||||
|
||||
def _create_shrunk_process_group(
|
||||
new_backend, remaining_ranks: list[int], metadata: dict, is_default_group: bool
|
||||
) -> ProcessGroup:
|
||||
"""Create and configure the new shrunk process group."""
|
||||
# Create new group properties
|
||||
new_group_rank = new_backend.rank()
|
||||
new_group_size = new_backend.size()
|
||||
group_name = metadata["new_group_name"]
|
||||
|
||||
# Generate descriptive group description
|
||||
if is_default_group:
|
||||
group_desc = "default:shrunken"
|
||||
else:
|
||||
group_desc = f"{metadata['original_group_name']}:shrunk"
|
||||
|
||||
# Create process group with new communicator (clone the parent store like split does)
|
||||
prefix_store = PrefixStore(f"{group_name}/", metadata["store"].clone())
|
||||
new_pg = ProcessGroup(prefix_store, new_group_rank, new_group_size)
|
||||
|
||||
# Configure backend using the device type of the new backend's bound device if available,
|
||||
# otherwise derive from the original group's bound device or fall back to CPU.
|
||||
backend_device = metadata.get("bound_device_id")
|
||||
if backend_device is None:
|
||||
# Default to CPU if no bound device is present
|
||||
backend_device = torch.device("cpu")
|
||||
|
||||
# Choose backend enum based on device type
|
||||
if backend_device.type == "cuda":
|
||||
backend_type = ProcessGroup.BackendType.NCCL
|
||||
else:
|
||||
backend_type = ProcessGroup.BackendType.GLOO
|
||||
|
||||
new_pg._register_backend(backend_device, backend_type, new_backend)
|
||||
new_pg._set_default_backend(backend_type)
|
||||
|
||||
# Inherit device binding from original group if it was bound
|
||||
bound_device_id = metadata.get("bound_device_id")
|
||||
if bound_device_id is not None:
|
||||
new_pg.bound_device_id = bound_device_id
|
||||
|
||||
# Set group metadata
|
||||
new_pg._set_group_name(group_name)
|
||||
new_pg._set_group_desc(group_desc)
|
||||
|
||||
# Persist backend configuration overrides (if provided via shrink_group)
|
||||
backend_config_override = metadata.get("backend_config")
|
||||
if backend_config_override is not None:
|
||||
# Store for introspection/debugging and potential backend hooks
|
||||
_world.pg_backend_config[new_pg] = backend_config_override
|
||||
|
||||
return new_pg
|
||||
|
||||
|
||||
def _destroy_all_other_groups(exclude_group: Optional[ProcessGroup] = None) -> None:
|
||||
"""
|
||||
Destroy all process groups except the excluded group and clean up all global state.
|
||||
|
||||
This is necessary when shrinking the default group because global ranks
|
||||
are reassigned by NCCL, making all existing process groups inconsistent.
|
||||
|
||||
Note: Uses abort for non-collective cleanup since excluded ranks may not
|
||||
participate in collective operations. Backend cleanup is handled independently per group.
|
||||
|
||||
Args:
|
||||
exclude_group (ProcessGroup, optional): Process group to exclude from destruction.
|
||||
If None, destroys all process groups.
|
||||
"""
|
||||
# Get list of groups to destroy (avoid modifying dict while iterating)
|
||||
groups_to_destroy = []
|
||||
for pg in list(_world.pg_group_ranks.keys()):
|
||||
if exclude_group is not None and pg == exclude_group:
|
||||
continue
|
||||
groups_to_destroy.append(pg)
|
||||
|
||||
# Warn user about automatic destruction
|
||||
if groups_to_destroy:
|
||||
group_names = [_get_process_group_name(pg) for pg in groups_to_destroy]
|
||||
logger.warning(
|
||||
"Shrinking default group will destroy %d other process groups: %s. "
|
||||
"This is necessary because shrinking the default group reassigns global ranks, "
|
||||
"making existing groups inconsistent.",
|
||||
len(groups_to_destroy),
|
||||
", ".join(group_names),
|
||||
)
|
||||
|
||||
# Destroy each group and clean up global state
|
||||
for pg in groups_to_destroy:
|
||||
try:
|
||||
# First call abort_process_group which handles the C++ cleanup non-collectively
|
||||
_abort_process_group(pg)
|
||||
except Exception as e:
|
||||
# Log but don't fail - some groups might already be destroyed
|
||||
logger.warning(
|
||||
"Failed to abort process group %s: %s",
|
||||
_get_process_group_name(pg),
|
||||
e,
|
||||
)
|
||||
|
||||
# Ensure all global state is cleaned up even if _abort_process_group fails
|
||||
# or doesn't clean up everything
|
||||
_cleanup_process_group_global_state(pg)
|
||||
|
||||
|
||||
def _cleanup_process_group_global_state(pg: ProcessGroup) -> None:
|
||||
"""
|
||||
Clean up all global state associated with a process group.
|
||||
|
||||
This function ensures complete cleanup of process group state from all
|
||||
global dictionaries and registries, even if destroy_process_group fails
|
||||
or doesn't clean up everything. This is critical when destroying multiple
|
||||
groups to prevent inconsistent state.
|
||||
|
||||
The cleanup removes the process group from:
|
||||
- _world.pg_map (backend and store mapping)
|
||||
- _world.pg_names (group name mapping)
|
||||
- _world.pg_group_ranks (rank mappings)
|
||||
- _world.pg_backend_config (backend configuration)
|
||||
- _world.tags_to_pg and _world.pg_to_tag (tag mappings)
|
||||
- _world.pg_coalesce_state (coalescing state)
|
||||
- C++ internal registries via _unregister_process_group
|
||||
|
||||
Args:
|
||||
pg (ProcessGroup): The process group to clean up.
|
||||
"""
|
||||
try:
|
||||
# Clean up main process group mappings
|
||||
_world.pg_map.pop(pg, None)
|
||||
_world.pg_group_ranks.pop(pg, None)
|
||||
_world.pg_backend_config.pop(pg, None)
|
||||
|
||||
# Clean up process group name mapping
|
||||
group_name = _world.pg_names.pop(pg, None)
|
||||
|
||||
# Clean up tag mappings
|
||||
pg_tag = _world.pg_to_tag.pop(pg, None)
|
||||
if pg_tag is not None and pg_tag in _world.tags_to_pg:
|
||||
try:
|
||||
_world.tags_to_pg[pg_tag].remove(pg)
|
||||
# Remove the tag entry if list is empty
|
||||
if not _world.tags_to_pg[pg_tag]:
|
||||
_world.tags_to_pg.pop(pg_tag, None)
|
||||
except (ValueError, KeyError):
|
||||
# Process group was already removed from the list
|
||||
pass
|
||||
|
||||
# Clean up any registered process group names using C++ unregister function
|
||||
if group_name is not None:
|
||||
try:
|
||||
_unregister_process_group(group_name)
|
||||
except Exception:
|
||||
# Process group name might not be registered or already unregistered
|
||||
pass
|
||||
|
||||
# Clean up coalesce state if present
|
||||
_world.pg_coalesce_state.pop(pg, None)
|
||||
|
||||
except Exception as e:
|
||||
# Log cleanup failures but don't propagate - we want to continue with other cleanups
|
||||
logger.warning("Failed to fully clean up global state for process group: %s", e)
|
||||
|
||||
|
||||
def _update_process_group_global_state(
|
||||
pg: ProcessGroup,
|
||||
backend_name: str,
|
||||
store: Store,
|
||||
group_name: str,
|
||||
backend_config: str,
|
||||
rank_mapping: Optional[dict[int, int]] = None,
|
||||
pg_tag: Optional[str] = None,
|
||||
user_tag: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Update all global state dictionaries for a process group.
|
||||
|
||||
This helper function consolidates the common pattern of updating multiple
|
||||
global state dictionaries when creating or modifying process groups.
|
||||
|
||||
Args:
|
||||
pg (ProcessGroup): The process group to update state for.
|
||||
backend_name (str): Backend name for pg_map.
|
||||
store (Store): Store instance for pg_map.
|
||||
group_name (str): Group name for pg_names and registration.
|
||||
backend_config (str): Backend configuration string.
|
||||
rank_mapping (Dict[int, int], optional): Global rank to group rank mapping.
|
||||
If None, skips updating pg_group_ranks.
|
||||
pg_tag (str, optional): Process group tag. If None, defaults to f"ptd:{group_name}".
|
||||
user_tag (str, optional): User-provided tag for special tag handling.
|
||||
If provided, creates "user:{user_tag}" tag and also adds to default "".
|
||||
"""
|
||||
# Update main process group mappings
|
||||
_world.pg_map[pg] = (backend_name, store)
|
||||
_world.pg_names[pg] = group_name
|
||||
_world.pg_backend_config[pg] = backend_config
|
||||
|
||||
# Register the process group name
|
||||
_register_process_group(group_name, pg)
|
||||
|
||||
# Update rank mapping if provided
|
||||
if rank_mapping is not None:
|
||||
_world.pg_group_ranks[pg] = rank_mapping
|
||||
|
||||
# Handle tag management
|
||||
if pg_tag is None:
|
||||
pg_tag = f"ptd:{group_name}"
|
||||
|
||||
if user_tag is not None:
|
||||
# Special handling for user-provided tags
|
||||
# Add to default "" tag first
|
||||
_world.tags_to_pg.setdefault("", []).append(pg)
|
||||
# Then create user-specific tag
|
||||
user_pg_tag = f"user:{user_tag}"
|
||||
_world.tags_to_pg.setdefault(user_pg_tag, []).append(pg)
|
||||
_world.pg_to_tag[pg] = user_pg_tag
|
||||
else:
|
||||
# Standard process group tag
|
||||
_world.tags_to_pg.setdefault(pg_tag, []).append(pg)
|
||||
_world.pg_to_tag[pg] = pg_tag
|
||||
|
@ -238,6 +238,47 @@ def skip_if_lt_x_gpu(x):
|
||||
return decorator
|
||||
|
||||
|
||||
def requires_world_size(n: int):
|
||||
"""
|
||||
Decorator to request a specific world size for a test. The test harness can
|
||||
read this attribute to set the number of ranks to spawn. If there are fewer
|
||||
than `n` CUDA devices available, the test should be skipped by the harness.
|
||||
|
||||
Usage:
|
||||
@require_world_size(3)
|
||||
def test_something(self):
|
||||
...
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
func._required_world_size = n
|
||||
available = torch.cuda.device_count()
|
||||
return unittest.skipUnless(
|
||||
available >= n, f"requires {n} GPUs, found {available}"
|
||||
)(func)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def get_required_world_size(obj: Any, default: int) -> int:
|
||||
"""
|
||||
Returns the requested world size for the currently running unittest method on `obj`
|
||||
if annotated via `@require_world_size(n)`, else returns `default`.
|
||||
"""
|
||||
try:
|
||||
# Try MultiProcessTestCase helper first, then unittest fallback
|
||||
test_name = (
|
||||
obj._current_test_name() # type: ignore[attr-defined]
|
||||
if hasattr(obj, "_current_test_name") and callable(obj._current_test_name)
|
||||
else obj._testMethodName
|
||||
)
|
||||
fn = getattr(obj, test_name)
|
||||
value = fn._required_world_size
|
||||
return int(value)
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
|
||||
# This decorator helps avoiding initializing cuda while testing other backends
|
||||
def nccl_skip_if_lt_x_gpu(backend, x):
|
||||
def decorator(func):
|
||||
@ -367,6 +408,13 @@ def requires_nccl_version(version, msg):
|
||||
)
|
||||
|
||||
|
||||
def requires_nccl_shrink():
|
||||
"""
|
||||
Require NCCL shrink support (NCCL available and version >= 2.27).
|
||||
"""
|
||||
return requires_nccl_version((2, 27), "Need NCCL 2.27+ for shrink_group")
|
||||
|
||||
|
||||
def requires_nccl():
|
||||
return skip_but_pass_in_sandcastle_if(
|
||||
not c10d.is_nccl_available(),
|
||||
|
Reference in New Issue
Block a user