Revert "[distributed] Replace assert statements with AssertionError exceptions (#165216)"

This reverts commit 74db92b21868b7e9e77cc966e5d57a8246723cbd.

Reverted https://github.com/pytorch/pytorch/pull/165216 on behalf of https://github.com/clee2000 due to I think this broke distributed/test_pg_wrapper.py::ProcessGroupNCCLWrapperTest::test_debug_level_detail_no_gloo [GH job link](https://github.com/pytorch/pytorch/actions/runs/18492765290/job/52693842750) [HUD commit link](74db92b218), note to self: bad TD ([comment](https://github.com/pytorch/pytorch/pull/165216#issuecomment-3402838765))
This commit is contained in:
PyTorch MergeBot
2025-10-14 17:05:16 +00:00
parent 5eddbb5e47
commit d2494cbb2b
11 changed files with 136 additions and 222 deletions

View File

@ -1526,8 +1526,7 @@ def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) ->
group = _get_default_group()
if _rank_not_in_group(group):
raise ValueError("Invalid process group specified")
if not isinstance(group, ProcessGroup):
raise AssertionError(f"Expected ProcessGroup, got {type(group)}")
assert isinstance(group, ProcessGroup)
devices = group._device_types
backends = set()
if torch.device("cpu") in devices and is_gloo_available():
@ -1666,14 +1665,13 @@ def init_process_group(
if "torch._dynamo" in sys.modules:
torch._dynamo.trace_rules.clear_lru_cache()
if not ((store is None) or (init_method is None)):
raise AssertionError("Cannot specify both init_method and store.")
assert (store is None) or (init_method is None), (
"Cannot specify both init_method and store."
)
if store is not None:
if not world_size > 0:
raise AssertionError("world_size must be positive if using store")
if not rank >= 0:
raise AssertionError("rank must be non-negative if using store")
assert world_size > 0, "world_size must be positive if using store"
assert rank >= 0, "rank must be non-negative if using store"
elif init_method is None:
init_method = "env://"
@ -1947,8 +1945,7 @@ def _new_process_group_helper(
backend_config = BackendConfig(backend)
# Set the default backend when single backend is passed in.
if "," not in str(backend) and ":" not in str(backend):
if backend not in Backend.backend_type_map:
raise AssertionError(f"Unknown backend type {backend}")
assert backend in Backend.backend_type_map, f"Unknown backend type {backend}"
if backend == Backend.UNDEFINED:
# Currently when backend is UNDEFINED, only one backend will be initialized
# we use nccl (if cuda is available) or gloo as default backend
@ -2018,10 +2015,9 @@ def _new_process_group_helper(
if not is_nccl_available():
raise RuntimeError("Distributed package doesn't have NCCL built in")
if backend_options is not None:
if not isinstance(backend_options, ProcessGroupNCCL.Options):
raise AssertionError(
"Expected backend_options argument to be of type ProcessGroupNCCL.Options"
)
assert isinstance(backend_options, ProcessGroupNCCL.Options), (
"Expected backend_options argument to be of type ProcessGroupNCCL.Options"
)
if backend_options._timeout != timeout:
warnings.warn(
"backend_options._timeout was specified, "
@ -2071,8 +2067,9 @@ def _new_process_group_helper(
)
backend_type = ProcessGroup.BackendType.XCCL
else:
if backend_str.upper() not in Backend._plugins:
raise AssertionError(f"Unknown c10d backend type {backend_str.upper()}")
assert backend_str.upper() in Backend._plugins, (
f"Unknown c10d backend type {backend_str.upper()}"
)
backend_plugin = Backend._plugins[backend_str.upper()]
creator_fn = backend_plugin.creator_fn
@ -2097,16 +2094,10 @@ def _new_process_group_helper(
# Set sequence numbers for gloo and nccl backends.
if backend_str == Backend.GLOO:
if not isinstance(backend_class, ProcessGroupGloo):
raise AssertionError(
f"Expected ProcessGroupGloo, got {type(backend_class)}"
)
assert isinstance(backend_class, ProcessGroupGloo)
backend_class._set_sequence_number_for_group()
elif backend_str == Backend.NCCL:
if not isinstance(backend_class, ProcessGroupNCCL):
raise AssertionError(
f"Expected ProcessGroupNCCL, got {type(backend_class)}"
)
assert isinstance(backend_class, ProcessGroupNCCL)
backend_class._set_sequence_number_for_group()
# If the type is a subclass of ProcessGroup then return this process group immediately
@ -2153,10 +2144,8 @@ def _new_process_group_helper(
pg._register_backend(torch.device(device), backend_type, backend_class)
# set group_name and group_dsec to backend
if group_name is None:
raise AssertionError("group_name must not be None")
if group_desc is None:
raise AssertionError("group_desc must not be None")
assert group_name is not None
assert group_desc is not None
pg._set_group_name(group_name)
pg._set_group_desc(group_desc)
@ -2202,8 +2191,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
else:
pg = group
if pg is None:
raise AssertionError("Process group cannot be None")
assert pg is not None
if _world.pg_map.get(pg, None) is None:
raise ValueError("Invalid process group specified")
@ -2293,8 +2281,7 @@ def _abort_process_group(group: Optional[ProcessGroup] = None):
pg = group or GroupMember.WORLD
if pg is None:
raise AssertionError("Process group cannot be None")
assert pg is not None
if _world.pg_map.get(pg, None) is None:
raise ValueError("Invalid process group specified or has been destroyed.")
@ -3351,8 +3338,7 @@ def gather_object(
if my_group_rank != group_dst:
return
if object_gather_list is None:
raise AssertionError("Must provide object_gather_list on dst rank")
assert object_gather_list is not None, "Must provide object_gather_list on dst rank"
# pyrefly: ignore # unbound-name
for i, tensor in enumerate(output_tensors):
tensor = tensor.type(torch.uint8)
@ -3608,8 +3594,9 @@ def recv_object_list(
rank_objects = get_global_rank(group, group_src)
else:
rank_objects = recv(object_tensor, group=group, group_src=group_src)
if rank_sizes != rank_objects:
raise AssertionError("Mismatch in return ranks for object sizes and objects.")
assert rank_sizes == rank_objects, (
"Mismatch in return ranks for object sizes and objects."
)
# Deserialize objects using their stored sizes.
offset = 0
for i, obj_size in enumerate(object_sizes_tensor):
@ -5016,8 +5003,7 @@ def _create_process_group_wrapper(
world_size: int,
timeout: timedelta = default_pg_timeout,
):
if not _GLOO_AVAILABLE:
raise RuntimeError("ProcessGroupWrapper unsupported without GLOO backend.")
assert _GLOO_AVAILABLE, "ProcessGroupWrapper unsupported without GLOO backend."
# (whc) this appears to be just for the gloo backend? if so, `default_pg_timeout` is appropriate...
@ -5219,10 +5205,9 @@ def split_group(
split_pg.bound_device_id = device_id # type: ignore[union-attr]
split_backend_class = split_pg._get_backend(torch.device("cuda"))
split_backend_class._set_sequence_number_for_group()
if split_pg.group_name != group_name:
raise AssertionError(
f"group name should be set to {group_name} but got {split_pg.group_name}"
)
assert split_pg.group_name == group_name, (
f"group name should be set to {group_name} but got {split_pg.group_name}"
)
# update global state
_world.pg_map[split_pg] = (backend, split_pg.get_group_store())
@ -5354,10 +5339,9 @@ def _new_group_with_tag(
if device_id is None:
device_id = default_pg.bound_device_id
elif default_pg.bound_device_id is not None:
if device_id != default_pg.bound_device_id:
raise AssertionError(
"Mismatched bound device between new pg and the default pg."
)
assert device_id == default_pg.bound_device_id, (
"Mismatched bound device between new pg and the default pg."
)
default_backend, default_store = _world.pg_map[default_pg]
global_rank = default_pg.rank()
global_world_size = default_pg.size()
@ -5671,25 +5655,22 @@ def _find_pg_by_ranks_and_tag(tag: str, ranks: list[int]) -> Optional[ProcessGro
def _find_or_create_pg_by_ranks_and_tag(
tag: str, ranks: list[int], stride: int
) -> ProcessGroup:
if len(ranks) % stride != 0:
raise ValueError(
f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})"
)
assert len(ranks) % stride == 0, (
f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})"
)
my_rank = get_rank()
my_ranks = None
if stride == len(ranks):
my_ranks = ranks.copy()
if my_rank not in my_ranks:
raise RuntimeError("rankset doesn't include the current node")
assert my_rank in my_ranks, "rankset doesn't include the current node"
else:
for i in range(0, len(ranks), stride):
rank_set = ranks[i : i + stride]
if my_rank in rank_set:
my_ranks = rank_set
if my_ranks is None:
raise RuntimeError("rankset doesn't include the current node")
assert my_ranks is not None, "rankset doesn't include the current node"
my_ranks = sorted(my_ranks)