Revert "shrink_group implementation to expose ncclCommShrink API (#164518)"

This reverts commit fa0db212e717b6cb225159cb32ea3d83baa52381.

Reverted https://github.com/pytorch/pytorch/pull/164518 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/164518#issuecomment-3419893217))
PyTorch MergeBot
2025-10-19 19:20:44 +00:00
parent fa0db212e7
commit 633a3b7f67
11 changed files with 2 additions and 1503 deletions

View File

@@ -79,23 +79,6 @@ class TORCH_API Backend : public torch::CustomClassHolder {
return false;
}
virtual bool supportsShrinking() const {
return false;
}
// Shrink the backend by excluding specified ranks. Backends that support
// communicator shrinking should override this and return a new backend
// instance representing the shrunken group. Backends may use opts_override
// to supply backend-specific options for the new group.
virtual c10::intrusive_ptr<Backend> shrink(
const std::vector<int64_t>& /*ranks_to_exclude*/,
int /*shrink_flags*/ = 0,
const c10::intrusive_ptr<Options>& /*opts_override*/ = nullptr) {
TORCH_CHECK(
false,
c10::str("Backend ", getBackendName(), " does not support shrink"));
}
virtual void setTimeout(std::chrono::milliseconds timeout) {
TORCH_CHECK(
false,

View File

@@ -259,65 +259,6 @@ std::shared_ptr<NCCLComm> NCCLComm::split(
}
#endif
#ifdef NCCL_HAS_COMM_SHRINK
std::shared_ptr<NCCLComm> NCCLComm::shrink(
NCCLComm* source,
std::vector<int>& ranks_to_exclude,
ncclConfig_t* config,
int shrinkFlags) {
// Preconditions are validated in ProcessGroupNCCL::shrink
LOG(INFO) << "Rank " << source->rank_ << ": shrinking comm " << source->repr()
<< " excluding " << ranks_to_exclude.size() << " ranks";
at::cuda::OptionalCUDAGuard gpuGuard(source->deviceIndex_);
auto comm = std::make_shared<NCCLComm>();
// This call will block until the source communicator is initialized
auto sourceComm = source->getNcclComm();
C10D_NCCL_CHECK_NONBLOCKING(
ncclCommShrink(
sourceComm,
ranks_to_exclude.data(),
ranks_to_exclude.size(),
reinterpret_cast<ncclComm_t*>(&(comm->ncclComm_)),
config,
shrinkFlags),
source->getNcclCommFailureReason());
// Wait for the child communicator to be ready
source->waitReady(true);
comm->initialized_ = true;
// NCCL automatically assigns rank during shrink - query it efficiently
int assigned_rank;
try {
C10D_NCCL_CHECK(
ncclCommUserRank(comm->ncclComm_, &assigned_rank), std::nullopt);
comm->rank_ = assigned_rank;
} catch (const std::exception& e) {
// Fallback: if ncclCommUserRank fails, we can't determine the rank
LOG(ERROR) << "Failed to query NCCL-assigned rank: " << e.what();
throw;
}
// Child comm should be on the same device as parent comm
comm->deviceIndex_ = source->deviceIndex_;
if (config != nullptr) {
comm->nonBlocking_ = config->blocking == 0;
} else {
// Inherit parent behavior if no config provided
comm->nonBlocking_ = source->nonBlocking_;
}
LOG(INFO) << "Rank " << source->rank_ << ": created shrunken comm "
<< comm->repr() << " with NCCL-assigned rank " << assigned_rank;
return comm;
}
#endif
void NCCLComm::finalize() {
LockType lock(mutex_);
if (aborted_) {

View File

@@ -90,10 +90,6 @@ static_assert(
#define NCCL_HAS_NVLS_CTAS
#endif
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0)
#define NCCL_HAS_COMM_SHRINK
#endif
// Macro to throw on a non-successful NCCL return value.
#define C10D_NCCL_CHECK(cmd, failureReason) \
do { \
@@ -298,14 +294,6 @@ class NCCLComm {
ncclConfig_t& config);
#endif // NCCL_HAS_COMM_SPLIT
#ifdef NCCL_HAS_COMM_SHRINK
static std::shared_ptr<NCCLComm> shrink(
NCCLComm* source,
std::vector<int>& ranks_to_exclude,
ncclConfig_t* config,
int shrinkFlags = 0);
#endif // NCCL_HAS_COMM_SHRINK
#if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP)
std::unordered_map<std::string, std::string> ncclCommDump();
#endif

View File

@@ -165,7 +165,7 @@ ncclRedOpRAII getNcclReduceOp(
}
// Get a key string from device
inline std::string getKeyFromDevice(const at::Device& device) {
inline std::string getKeyFromDevice(at::Device& device) {
return std::to_string(device.index());
}
@@ -5838,139 +5838,6 @@ at::Tensor ProcessGroupNCCL::allocateTensor(
return tensor;
}
#ifdef NCCL_HAS_COMM_SHRINK
c10::intrusive_ptr<Backend> ProcessGroupNCCL::shrink(
const std::vector<int64_t>& ranks_to_exclude,
int shrink_flags,
const c10::intrusive_ptr<Backend::Options>& opts_override) {
// Runtime version check with better error message
auto runtime_version = torch::cuda::nccl::version();
TORCH_CHECK(
runtime_version >= NCCL_VERSION(2, 27, 0),
"ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later. "
"Found version: ",
runtime_version);
// Early validation with detailed error messages
TORCH_CHECK_VALUE(
!ranks_to_exclude.empty(), "ranks_to_exclude cannot be empty");
TORCH_CHECK_VALUE(
static_cast<int>(ranks_to_exclude.size()) < size_,
"Cannot exclude all ranks (",
ranks_to_exclude.size(),
" >= ",
size_,
")");
// Validate ranks and convert to int efficiently
std::vector<int> int_ranks_to_exclude;
int_ranks_to_exclude.reserve(ranks_to_exclude.size());
for (int64_t rank : ranks_to_exclude) {
TORCH_CHECK_VALUE(
rank >= 0 && rank < size_,
"Invalid rank ",
rank,
" for group size ",
size_);
int_ranks_to_exclude.push_back(static_cast<int>(rank));
}
// Get primary communicator with better error context
auto primary_device_index = guessDeviceId();
auto primary_device = at::Device(at::kCUDA, primary_device_index);
const auto primary_key = getKeyFromDevice(primary_device);
std::shared_ptr<NCCLComm> primary_comm = getNCCLComm(primary_key);
TORCH_CHECK(
primary_comm,
"Primary NCCL communicator for device ",
primary_device,
" (key: ",
primary_key,
") is not initialized");
// Cache device index before shrink operation
at::DeviceIndex parent_device_index = primary_comm->getDeviceIndex();
ncclConfig_t* config = nullptr;
// Default to inheriting from parent options
bool high_priority_stream = options_->is_high_priority_stream;
if (opts_override) {
auto nccl_opts =
c10::static_intrusive_pointer_cast<ProcessGroupNCCL::Options>(
opts_override);
config = &nccl_opts->config;
// If user provided override options, honor is_high_priority_stream as well
high_priority_stream = nccl_opts->is_high_priority_stream;
}
std::shared_ptr<NCCLComm> shrunk_comm = NCCLComm::shrink(
primary_comm.get(),
int_ranks_to_exclude,
(config != nullptr ? config : &options_->config),
shrink_flags);
// Calculate new size and get NCCL-assigned rank
int new_size = size_ - static_cast<int>(ranks_to_exclude.size());
int new_rank = shrunk_comm->rank_;
// Create new ProcessGroupNCCL with optimized options cloning
auto new_store = store_->clone();
auto new_opts = ProcessGroupNCCL::Options::create(high_priority_stream);
new_opts->timeout = options_->timeout;
if (config != nullptr) {
new_opts->config = *config;
} else {
new_opts->config = options_->config;
}
auto new_pg = c10::make_intrusive<ProcessGroupNCCL>(
new_store, new_rank, new_size, new_opts);
// Set up the new process group with optimized device setup
new_pg->initializeDeviceStateForComm(
at::Device(at::kCUDA, parent_device_index), shrunk_comm);
return c10::static_intrusive_pointer_cast<Backend>(new_pg);
}
#else // !NCCL_HAS_COMM_SHRINK
// Backend interface override: raise consistent error when shrink is
// unsupported.
c10::intrusive_ptr<Backend> ProcessGroupNCCL::shrink(
const std::vector<int64_t>& /*ranks_to_exclude*/,
int /*shrink_flags*/,
const c10::intrusive_ptr<Backend::Options>& /*opts_override*/) {
TORCH_CHECK(
false,
"ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later, "
"but PyTorch was built with an older version or without NCCL shrink support.");
}
#endif // NCCL_HAS_COMM_SHRINK
void ProcessGroupNCCL::initializeDeviceStateForComm(
const at::Device& device,
std::shared_ptr<NCCLComm> comm) {
const auto key = getKeyFromDevice(device);
std::unique_lock<std::mutex> lock(mutex_);
at::cuda::OptionalCUDAGuard gpuGuard(device);
bool force_high = getCvarBool(TORCH_NCCL_HIGH_PRIORITY, false);
auto stream = at::cuda::getStreamFromPool(
options_->is_high_priority_stream || force_high);
devNCCLCommMap_[key] = comm;
ncclStreams_.emplace(key, stream);
ncclEvents_.emplace(key, at::cuda::CUDAEvent(cudaEventDisableTiming));
usedDeviceIdxs_.insert(device.index());
if (shouldAllCommunicatorsRegisterAllTensors()) {
std::lock_guard<std::mutex> map_lock(ncclCommMemPoolMapMutex);
ncclCommMemPoolMap.emplace(std::move(comm), MemPoolSet{});
}
}
} // namespace c10d
#endif // USE_C10D_NCCL

View File

@@ -997,21 +997,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
ErrorType getError() override;
bool supportsShrinking() const override {
#ifdef NCCL_HAS_COMM_SHRINK
return true;
#else
return false;
#endif
}
// Backend-style shrink override that returns a Backend instance.
c10::intrusive_ptr<Backend> shrink(
const std::vector<int64_t>& ranks_to_exclude,
int shrink_flags = 0,
const c10::intrusive_ptr<Backend::Options>& opts_override =
nullptr) override;
std::shared_ptr<c10::Allocator> getMemAllocator() override;
// Allocate tensor from communication-optimized memory pool
@@ -1080,12 +1065,6 @@ class TORCH_API ProcessGroupNCCL : public Backend {
int p2pRank = 0,
bool isSendRecvSelf = false);
// Initialize device-specific state (comm, stream, event, bookkeeping) for a
// given communicator on this process group instance.
void initializeDeviceStateForComm(
const at::Device& device,
std::shared_ptr<NCCLComm> comm);
// Wrapper method which can be overridden for tests.
virtual std::exception_ptr checkForNCCLErrors(
std::shared_ptr<NCCLComm>& ncclComm);

View File

@@ -2730,23 +2730,12 @@ Arguments:
"supports_time_estimate",
&::c10d::Backend::supportsTimeEstimation,
"(test whether the backend supports collective time estimation)")
.def_property_readonly(
"supports_shrinking",
&::c10d::Backend::supportsShrinking,
"(test whether the backend supports communicator shrinking)")
.def(
"set_timeout",
&::c10d::Backend::setTimeout,
py::arg("timeout"),
py::call_guard<py::gil_scoped_release>(),
R"(Sets the default timeout for all future operations.)")
.def(
"shrink",
&::c10d::Backend::shrink,
py::arg("ranks_to_exclude"),
py::arg("shrink_flags") = 0,
py::arg("opts_override") = nullptr,
py::call_guard<py::gil_scoped_release>())
.def(
"broadcast",
&::c10d::Backend::broadcast,

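For orientation, the removed bindings above exposed shrinking directly on a backend object. Below is a minimal sketch of how they could be exercised from Python, assuming the reverted patch is applied, NCCL >= 2.27, and a torchrun-style launch with the usual environment variables; `_get_default_group` and `_get_backend` are existing internals, and the excluded rank is purely illustrative.

import os
import torch
import torch.distributed as dist

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group("nccl")  # env:// rendezvous, e.g. under torchrun
pg = dist.distributed_c10d._get_default_group()
backend = pg._get_backend(torch.device("cuda"))
if backend.supports_shrinking:
    # Surviving ranks collectively exclude rank 2; the excluded rank must not call this.
    new_backend = backend.shrink(ranks_to_exclude=[2], shrink_flags=0)
    print(new_backend.rank(), new_backend.size())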
View File

@@ -130,7 +130,6 @@ __all__ = [
"reduce_scatter_tensor",
"get_node_local_rank",
"split_group",
"shrink_group",
]
_MPI_AVAILABLE = True
@@ -5714,517 +5713,3 @@ def _get_process_group_name(pg: ProcessGroup) -> str:
def _get_process_group_store(pg: ProcessGroup) -> Store:
return _world.pg_map[pg][1]
# Shrink flags for process group backends
SHRINK_DEFAULT = 0x00
SHRINK_ABORT = 0x01
@_time_logger
def shrink_group(
ranks_to_exclude: list[int],
group: Optional[ProcessGroup] = None,
shrink_flags: int = SHRINK_DEFAULT,
pg_options: Optional[Any] = None,
) -> ProcessGroup:
"""
Shrinks a process group by excluding specified ranks.
Creates and returns a new, smaller process group comprising only the ranks
from the original group that were not in the ``ranks_to_exclude`` list.
Args:
ranks_to_exclude (List[int]): A list of ranks from the original
``group`` to exclude from the new group.
group (ProcessGroup, optional): The process group to shrink. If ``None``,
the default process group is used. Defaults to ``None``.
shrink_flags (int, optional): Flags to control the shrinking behavior.
Can be ``SHRINK_DEFAULT`` (default) or ``SHRINK_ABORT``.
``SHRINK_ABORT`` will attempt to terminate ongoing operations
in the parent communicator before shrinking.
Defaults to ``SHRINK_DEFAULT``.
pg_options (ProcessGroupOptions, optional): Backend-specific options to apply
to the shrunken process group. If provided, the backend will use
these options when creating the new group. If omitted, the new group
inherits defaults from the parent.
Returns:
ProcessGroup: a new group comprised of the remaining ranks. If the
default group was shrunk, the returned group becomes the new default group.
Raises:
TypeError: if the group's backend does not support shrinking.
ValueError: if ``ranks_to_exclude`` is invalid (empty, out of bounds,
duplicates, or excludes all ranks).
RuntimeError: if an excluded rank calls this function or the backend
fails the operation.
Notes:
- Only non-excluded ranks should call this function; excluded ranks
must not participate in the shrink operation.
- Shrinking the default group destroys all other process groups since
rank reassignment makes them inconsistent.
"""
# Step 1: Validate input parameters with comprehensive error checking
_validate_shrink_inputs(ranks_to_exclude, shrink_flags)
# Step 2: Get target group and essential properties
target_group_info = _prepare_shrink_target_group(group)
# Step 3: Validate backend requirements and availability
backend_impl = _validate_shrink_backend_requirements(target_group_info)
# Step 4: Validate ranks against group and check for duplicates
excluded_ranks_set = _validate_and_process_excluded_ranks(
ranks_to_exclude, target_group_info
)
# Step 5: Execute the actual shrink operation (backend-specific)
new_backend = backend_impl.shrink(
sorted(excluded_ranks_set),
shrink_flags,
pg_options if pg_options is not None else None,
)
# Step 6: Handle cleanup and creation of new process group
target_group_info["pg_options_override"] = pg_options
return _finalize_shrunk_group(target_group_info, excluded_ranks_set, new_backend)
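# Illustrative usage sketch (not part of the reverted patch): assumes every
# surviving rank has detected that rank `failed_rank` is down and calls
# shrink_group() collectively, while the excluded rank does not participate.
#
#     failed_rank = 3
#     if get_rank() != failed_rank:
#         new_pg = shrink_group([failed_rank], shrink_flags=SHRINK_ABORT)
#         # The default group was shrunk, so it now reflects only the survivors.
#         assert new_pg.size() == get_world_size()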
def _validate_shrink_inputs(ranks_to_exclude: list[int], shrink_flags: int) -> None:
"""Validate input parameters for shrink_group."""
if not isinstance(ranks_to_exclude, list):
raise TypeError(
f"ranks_to_exclude must be a list, but got {type(ranks_to_exclude).__name__}. "
f"Example: [1, 3, 5] to exclude ranks 1, 3, and 5."
)
if not ranks_to_exclude:
raise ValueError(
"ranks_to_exclude cannot be empty. To shrink a group, you must specify at least "
"one rank to exclude. Example: [failed_rank_id]"
)
# Validate shrink_flags with clear explanation of valid values
valid_flags = [SHRINK_DEFAULT, SHRINK_ABORT]
if not isinstance(shrink_flags, int) or shrink_flags not in valid_flags:
raise ValueError(
f"Invalid shrink_flags value: {shrink_flags}. Must be one of: "
f"SHRINK_DEFAULT ({SHRINK_DEFAULT}) or SHRINK_ABORT ({SHRINK_ABORT}). "
f"Use SHRINK_ABORT to abort ongoing operations before shrinking."
)
def _prepare_shrink_target_group(group: Optional[ProcessGroup]) -> dict:
"""Prepare and validate the target group for shrinking."""
target_pg = group if group is not None else _get_default_group()
# Cache frequently accessed properties to avoid repeated calls
group_size = int(target_pg.size())
group_info = {
"process_group": target_pg,
"is_default_group": (target_pg == _get_default_group()),
"group_size": group_size,
"current_rank": target_pg.rank(),
"group_name": _get_process_group_name(target_pg),
}
# Validate that we have a valid process group
if group_size <= 1:
raise ValueError(
f"Cannot shrink a process group with size {group_size}. "
f"Group must have at least 2 ranks to support shrinking."
)
return group_info
def _validate_shrink_backend_requirements(group_info: dict) -> Any:
"""Return the backend implementation for the target group or raise if unsupported."""
target_pg = group_info["process_group"]
group_name = group_info["group_name"]
# Get the group's backend directly via ProcessGroup API. Prefer a bound device if present,
# otherwise try CUDA then fall back to CPU.
try:
preferred_device = getattr(target_pg, "bound_device_id", None)
if preferred_device is not None:
backend_impl = target_pg._get_backend(preferred_device)
else:
# Try CUDA first if available, else CPU
try:
backend_impl = target_pg._get_backend(torch.device("cuda"))
except Exception:
backend_impl = target_pg._get_backend(torch.device("cpu"))
except RuntimeError as e:
raise RuntimeError(
f"Cannot access device backend for process group '{group_name}'. "
f"Ensure the process group was initialized with a compatible device backend and devices are available."
) from e
try:
supports = bool(backend_impl.supports_shrinking)
except Exception:
supports = False
if not supports:
raise TypeError(
f"Process group backend for '{group_name}' does not support shrinking operations."
)
return backend_impl
def _validate_and_process_excluded_ranks(
ranks_to_exclude: list[int], group_info: dict
) -> set:
"""Validate excluded ranks and convert to set for efficient operations."""
group_size = group_info["group_size"]
current_rank = group_info["current_rank"]
# Use set for O(1) duplicate detection and membership testing
excluded_ranks_set = set()
# Validate each rank with detailed error messages
for i, rank in enumerate(ranks_to_exclude):
if not isinstance(rank, int):
raise TypeError(
f"All elements in ranks_to_exclude must be integers. "
f"Element at index {i} is {type(rank).__name__}: {rank}"
)
if not (0 <= rank < group_size):
raise ValueError(
f"Rank {rank} at index {i} is out of bounds for group size {group_size}. "
f"Valid ranks are in range [0, {group_size - 1}]."
)
if rank in excluded_ranks_set:
raise ValueError(
f"Duplicate rank {rank} found in ranks_to_exclude at index {i}. "
f"Each rank can only be excluded once."
)
excluded_ranks_set.add(rank)
# Ensure we don't exclude all ranks
if len(excluded_ranks_set) >= group_size:
raise ValueError(
f"Cannot exclude all {group_size} ranks from process group. "
f"At least one rank must remain. Excluding {len(excluded_ranks_set)} ranks."
)
# Critical check: current rank should not be in excluded list
if current_rank in excluded_ranks_set:
raise RuntimeError(
f"Current rank {current_rank} is in the exclusion list and should not call shrink_group(). "
f"Only non-excluded ranks should participate in the shrinking operation. "
f"Excluded ranks should terminate their processes instead."
)
return excluded_ranks_set
def _finalize_shrunk_group(
group_info: dict, excluded_ranks_set: set, new_backend
) -> ProcessGroup:
"""Clean up old group and create new shrunk process group."""
target_pg = group_info["process_group"]
is_default_group = group_info["is_default_group"]
# Handle default group dependencies - destroy other groups first
if is_default_group:
_destroy_all_other_groups(exclude_group=target_pg)
# Gather original group metadata before cleanup
original_group_metadata = _extract_group_metadata(target_pg)
# Calculate remaining ranks efficiently
original_ranks = get_process_group_ranks(target_pg)
remaining_ranks = [
rank for rank in original_ranks if rank not in excluded_ranks_set
]
# Clean up the original group
_cleanup_original_group(target_pg, is_default_group)
# Create and configure the new process group
new_pg = _create_shrunk_process_group(
new_backend, remaining_ranks, original_group_metadata, is_default_group
)
# Register the new group in global state
if is_default_group:
_update_default_pg(new_pg)
# Update global state with new group information
rank_mapping = {
global_rank: group_rank
for group_rank, global_rank in enumerate(remaining_ranks)
}
_update_process_group_global_state(
pg=new_pg,
backend_name=original_group_metadata["backend_name"],
store=original_group_metadata["store"],
group_name=original_group_metadata["new_group_name"],
backend_config=original_group_metadata["backend_config"],
rank_mapping=rank_mapping,
)
return new_pg
def _extract_group_metadata(target_pg: ProcessGroup) -> dict:
"""Extract metadata from the original group before cleanup."""
original_backend_name, original_store = _world.pg_map[target_pg]
original_backend_config = _world.pg_backend_config.get(target_pg, "")
original_group_name = _get_process_group_name(target_pg)
# Extract device binding information before cleanup to avoid accessing destroyed group
bound_device_id = None
if hasattr(target_pg, "bound_device_id"):
bound_device_id = target_pg.bound_device_id
# Generate new group name for the shrunk group; hash for uniqueness across backends
remaining_ranks = list(get_process_group_ranks(target_pg))
new_group_name = _process_group_name(remaining_ranks, use_hashed_name=True)
return {
"backend_name": original_backend_name,
"store": original_store,
"backend_config": original_backend_config,
"original_group_name": original_group_name,
"new_group_name": new_group_name,
"bound_device_id": bound_device_id, # Safe to access after cleanup
}
def _cleanup_original_group(target_pg: ProcessGroup, is_default_group: bool) -> None:
"""Clean up the original process group safely."""
try:
destroy_process_group(target_pg)
except Exception as e:
group_type = "default" if is_default_group else "non-default"
logger.warning("Failed to destroy %s group during shrinking: %s", group_type, e)
# Ensure global state cleanup even if destroy_process_group fails
_cleanup_process_group_global_state(target_pg)
def _create_shrunk_process_group(
new_backend, remaining_ranks: list[int], metadata: dict, is_default_group: bool
) -> ProcessGroup:
"""Create and configure the new shrunk process group."""
# Create new group properties
new_group_rank = new_backend.rank()
new_group_size = new_backend.size()
group_name = metadata["new_group_name"]
# Generate descriptive group description
if is_default_group:
group_desc = "default:shrunken"
else:
group_desc = f"{metadata['original_group_name']}:shrunk"
# Create process group with new communicator (clone the parent store like split does)
prefix_store = PrefixStore(f"{group_name}/", metadata["store"].clone())
new_pg = ProcessGroup(prefix_store, new_group_rank, new_group_size)
# Configure backend using the device type of the new backend's bound device if available,
# otherwise derive from the original group's bound device or fall back to CPU.
backend_device = metadata.get("bound_device_id")
if backend_device is None:
# Default to CPU if no bound device is present
backend_device = torch.device("cpu")
# Choose backend enum based on device type
if backend_device.type == "cuda":
backend_type = ProcessGroup.BackendType.NCCL
else:
backend_type = ProcessGroup.BackendType.GLOO
new_pg._register_backend(backend_device, backend_type, new_backend)
new_pg._set_default_backend(backend_type)
# Inherit device binding from original group if it was bound
bound_device_id = metadata.get("bound_device_id")
if bound_device_id is not None:
new_pg.bound_device_id = bound_device_id
# Set group metadata
new_pg._set_group_name(group_name)
new_pg._set_group_desc(group_desc)
# Persist backend configuration overrides (if provided via shrink_group)
backend_config_override = metadata.get("backend_config")
if backend_config_override is not None:
# Store for introspection/debugging and potential backend hooks
_world.pg_backend_config[new_pg] = backend_config_override
return new_pg
def _destroy_all_other_groups(exclude_group: Optional[ProcessGroup] = None) -> None:
"""
Destroy all process groups except the excluded group and clean up all global state.
This is necessary when shrinking the default group because global ranks
are reassigned by NCCL, making all existing process groups inconsistent.
Note: Uses abort for non-collective cleanup since excluded ranks may not
participate in collective operations. Backend cleanup is handled independently per group.
Args:
exclude_group (ProcessGroup, optional): Process group to exclude from destruction.
If None, destroys all process groups.
"""
# Get list of groups to destroy (avoid modifying dict while iterating)
groups_to_destroy = []
for pg in list(_world.pg_group_ranks.keys()):
if exclude_group is not None and pg == exclude_group:
continue
groups_to_destroy.append(pg)
# Warn user about automatic destruction
if groups_to_destroy:
group_names = [_get_process_group_name(pg) for pg in groups_to_destroy]
logger.warning(
"Shrinking default group will destroy %d other process groups: %s. "
"This is necessary because shrinking the default group reassigns global ranks, "
"making existing groups inconsistent.",
len(groups_to_destroy),
", ".join(group_names),
)
# Destroy each group and clean up global state
for pg in groups_to_destroy:
try:
# First call abort_process_group which handles the C++ cleanup non-collectively
_abort_process_group(pg)
except Exception as e:
# Log but don't fail - some groups might already be destroyed
logger.warning(
"Failed to abort process group %s: %s",
_get_process_group_name(pg),
e,
)
# Ensure all global state is cleaned up even if _abort_process_group fails
# or doesn't clean up everything
_cleanup_process_group_global_state(pg)
def _cleanup_process_group_global_state(pg: ProcessGroup) -> None:
"""
Clean up all global state associated with a process group.
This function ensures complete cleanup of process group state from all
global dictionaries and registries, even if destroy_process_group fails
or doesn't clean up everything. This is critical when destroying multiple
groups to prevent inconsistent state.
The cleanup removes the process group from:
- _world.pg_map (backend and store mapping)
- _world.pg_names (group name mapping)
- _world.pg_group_ranks (rank mappings)
- _world.pg_backend_config (backend configuration)
- _world.tags_to_pg and _world.pg_to_tag (tag mappings)
- _world.pg_coalesce_state (coalescing state)
- C++ internal registries via _unregister_process_group
Args:
pg (ProcessGroup): The process group to clean up.
"""
try:
# Clean up main process group mappings
_world.pg_map.pop(pg, None)
_world.pg_group_ranks.pop(pg, None)
_world.pg_backend_config.pop(pg, None)
# Clean up process group name mapping
group_name = _world.pg_names.pop(pg, None)
# Clean up tag mappings
pg_tag = _world.pg_to_tag.pop(pg, None)
if pg_tag is not None and pg_tag in _world.tags_to_pg:
try:
_world.tags_to_pg[pg_tag].remove(pg)
# Remove the tag entry if list is empty
if not _world.tags_to_pg[pg_tag]:
_world.tags_to_pg.pop(pg_tag, None)
except (ValueError, KeyError):
# Process group was already removed from the list
pass
# Clean up any registered process group names using C++ unregister function
if group_name is not None:
try:
_unregister_process_group(group_name)
except Exception:
# Process group name might not be registered or already unregistered
pass
# Clean up coalesce state if present
_world.pg_coalesce_state.pop(pg, None)
except Exception as e:
# Log cleanup failures but don't propagate - we want to continue with other cleanups
logger.warning("Failed to fully clean up global state for process group: %s", e)
def _update_process_group_global_state(
pg: ProcessGroup,
backend_name: str,
store: Store,
group_name: str,
backend_config: str,
rank_mapping: Optional[dict[int, int]] = None,
pg_tag: Optional[str] = None,
user_tag: Optional[str] = None,
) -> None:
"""
Update all global state dictionaries for a process group.
This helper function consolidates the common pattern of updating multiple
global state dictionaries when creating or modifying process groups.
Args:
pg (ProcessGroup): The process group to update state for.
backend_name (str): Backend name for pg_map.
store (Store): Store instance for pg_map.
group_name (str): Group name for pg_names and registration.
backend_config (str): Backend configuration string.
rank_mapping (Dict[int, int], optional): Global rank to group rank mapping.
If None, skips updating pg_group_ranks.
pg_tag (str, optional): Process group tag. If None, defaults to f"ptd:{group_name}".
user_tag (str, optional): User-provided tag for special tag handling.
If provided, creates "user:{user_tag}" tag and also adds to default "".
"""
# Update main process group mappings
_world.pg_map[pg] = (backend_name, store)
_world.pg_names[pg] = group_name
_world.pg_backend_config[pg] = backend_config
# Register the process group name
_register_process_group(group_name, pg)
# Update rank mapping if provided
if rank_mapping is not None:
_world.pg_group_ranks[pg] = rank_mapping
# Handle tag management
if pg_tag is None:
pg_tag = f"ptd:{group_name}"
if user_tag is not None:
# Special handling for user-provided tags
# Add to default "" tag first
_world.tags_to_pg.setdefault("", []).append(pg)
# Then create user-specific tag
user_pg_tag = f"user:{user_tag}"
_world.tags_to_pg.setdefault(user_pg_tag, []).append(pg)
_world.pg_to_tag[pg] = user_pg_tag
else:
# Standard process group tag
_world.tags_to_pg.setdefault(pg_tag, []).append(pg)
_world.pg_to_tag[pg] = pg_tag

View File

@@ -238,47 +238,6 @@ def skip_if_lt_x_gpu(x):
return decorator
def requires_world_size(n: int):
"""
Decorator to request a specific world size for a test. The test harness can
read the `_required_world_size` attribute to decide how many ranks to spawn.
If fewer than `n` CUDA devices are available, the test is skipped.
Usage:
@requires_world_size(3)
def test_something(self):
...
"""
def decorator(func):
func._required_world_size = n
available = torch.cuda.device_count()
return unittest.skipUnless(
available >= n, f"requires {n} GPUs, found {available}"
)(func)
return decorator
def get_required_world_size(obj: Any, default: int) -> int:
"""
Returns the requested world size for the currently running unittest method on `obj`
if annotated via `@requires_world_size(n)`, else returns `default`.
"""
try:
# Try MultiProcessTestCase helper first, then unittest fallback
test_name = (
obj._current_test_name() # type: ignore[attr-defined]
if hasattr(obj, "_current_test_name") and callable(obj._current_test_name)
else obj._testMethodName
)
fn = getattr(obj, test_name)
value = fn._required_world_size
return int(value)
except Exception:
return default
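# Illustrative harness hook (not part of the reverted patch): a sketch of how a
# MultiProcessTestCase subclass could combine the two helpers above; the
# `world_size` property is what the spawning logic consults, and the test body
# is assumed.
#
#     class ShrinkGroupTest(MultiProcessTestCase):
#         @property
#         def world_size(self) -> int:
#             return get_required_world_size(self, 2)
#
#         @requires_world_size(4)
#         def test_shrink_two_ranks(self):
#             ...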
# This decorator helps avoiding initializing cuda while testing other backends
def nccl_skip_if_lt_x_gpu(backend, x):
def decorator(func):
@@ -408,13 +367,6 @@ def requires_nccl_version(version, msg):
)
def requires_nccl_shrink():
"""
Require NCCL shrink support (NCCL available and version >= 2.27).
"""
return requires_nccl_version((2, 27), "Need NCCL 2.27+ for shrink_group")
def requires_nccl():
return skip_but_pass_in_sandcastle_if(
not c10d.is_nccl_available(),