Revert "[distributed] Replace assert statements with AssertionError exceptions (#165216)"
This reverts commit 74db92b21868b7e9e77cc966e5d57a8246723cbd.
Reverted https://github.com/pytorch/pytorch/pull/165216 on behalf of https://github.com/clee2000 due to I think this broke distributed/test_pg_wrapper.py::ProcessGroupNCCLWrapperTest::test_debug_level_detail_no_gloo [GH job link](https://github.com/pytorch/pytorch/actions/runs/18492765290/job/52693842750) [HUD commit link](74db92b218), note to self: bad TD ([comment](https://github.com/pytorch/pytorch/pull/165216#issuecomment-3402838765))
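
For context on what the revert restores: the reverted PR replaced plain `assert` statements with explicit `raise AssertionError(...)` checks. The observable difference is that `assert` statements are stripped when Python runs with the `-O` flag, while an explicit raise always executes. A minimal sketch of the two patterns (the helper names below are illustrative, not code from the PR):

```python
# Minimal sketch, not code from the PR; function names are hypothetical.

def check_group_with_assert(group):
    # Compiled out entirely under `python -O`, so the check can silently vanish.
    assert group is not None, "Process group cannot be None"


def check_group_with_raise(group):
    # Always evaluated, regardless of interpreter optimization flags.
    if group is None:
        raise AssertionError("Process group cannot be None")
```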
@@ -1526,8 +1526,7 @@ def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) ->
         group = _get_default_group()
     if _rank_not_in_group(group):
         raise ValueError("Invalid process group specified")
-    if not isinstance(group, ProcessGroup):
-        raise AssertionError(f"Expected ProcessGroup, got {type(group)}")
+    assert isinstance(group, ProcessGroup)
     devices = group._device_types
     backends = set()
     if torch.device("cpu") in devices and is_gloo_available():
@@ -1666,14 +1665,13 @@ def init_process_group(
     if "torch._dynamo" in sys.modules:
         torch._dynamo.trace_rules.clear_lru_cache()
 
-    if not ((store is None) or (init_method is None)):
-        raise AssertionError("Cannot specify both init_method and store.")
+    assert (store is None) or (init_method is None), (
+        "Cannot specify both init_method and store."
+    )
 
     if store is not None:
-        if not world_size > 0:
-            raise AssertionError("world_size must be positive if using store")
-        if not rank >= 0:
-            raise AssertionError("rank must be non-negative if using store")
+        assert world_size > 0, "world_size must be positive if using store"
+        assert rank >= 0, "rank must be non-negative if using store"
     elif init_method is None:
         init_method = "env://"
 
@@ -1947,8 +1945,7 @@ def _new_process_group_helper(
     backend_config = BackendConfig(backend)
     # Set the default backend when single backend is passed in.
     if "," not in str(backend) and ":" not in str(backend):
-        if backend not in Backend.backend_type_map:
-            raise AssertionError(f"Unknown backend type {backend}")
+        assert backend in Backend.backend_type_map, f"Unknown backend type {backend}"
         if backend == Backend.UNDEFINED:
             # Currently when backend is UNDEFINED, only one backend will be initialized
             # we use nccl (if cuda is available) or gloo as default backend
@@ -2018,10 +2015,9 @@ def _new_process_group_helper(
             if not is_nccl_available():
                 raise RuntimeError("Distributed package doesn't have NCCL built in")
             if backend_options is not None:
-                if not isinstance(backend_options, ProcessGroupNCCL.Options):
-                    raise AssertionError(
-                        "Expected backend_options argument to be of type ProcessGroupNCCL.Options"
-                    )
+                assert isinstance(backend_options, ProcessGroupNCCL.Options), (
+                    "Expected backend_options argument to be of type ProcessGroupNCCL.Options"
+                )
                 if backend_options._timeout != timeout:
                     warnings.warn(
                         "backend_options._timeout was specified, "
@@ -2071,8 +2067,9 @@ def _new_process_group_helper(
             )
             backend_type = ProcessGroup.BackendType.XCCL
         else:
-            if backend_str.upper() not in Backend._plugins:
-                raise AssertionError(f"Unknown c10d backend type {backend_str.upper()}")
+            assert backend_str.upper() in Backend._plugins, (
+                f"Unknown c10d backend type {backend_str.upper()}"
+            )
 
             backend_plugin = Backend._plugins[backend_str.upper()]
             creator_fn = backend_plugin.creator_fn
@@ -2097,16 +2094,10 @@ def _new_process_group_helper(
 
         # Set sequence numbers for gloo and nccl backends.
         if backend_str == Backend.GLOO:
-            if not isinstance(backend_class, ProcessGroupGloo):
-                raise AssertionError(
-                    f"Expected ProcessGroupGloo, got {type(backend_class)}"
-                )
+            assert isinstance(backend_class, ProcessGroupGloo)
             backend_class._set_sequence_number_for_group()
         elif backend_str == Backend.NCCL:
-            if not isinstance(backend_class, ProcessGroupNCCL):
-                raise AssertionError(
-                    f"Expected ProcessGroupNCCL, got {type(backend_class)}"
-                )
+            assert isinstance(backend_class, ProcessGroupNCCL)
            backend_class._set_sequence_number_for_group()
 
        # If the type is a subclass of ProcessGroup then return this process group immediately
@@ -2153,10 +2144,8 @@ def _new_process_group_helper(
         pg._register_backend(torch.device(device), backend_type, backend_class)
 
     # set group_name and group_dsec to backend
-    if group_name is None:
-        raise AssertionError("group_name must not be None")
-    if group_desc is None:
-        raise AssertionError("group_desc must not be None")
+    assert group_name is not None
+    assert group_desc is not None
     pg._set_group_name(group_name)
     pg._set_group_desc(group_desc)
 
@@ -2202,8 +2191,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
     else:
         pg = group
 
-    if pg is None:
-        raise AssertionError("Process group cannot be None")
+    assert pg is not None
     if _world.pg_map.get(pg, None) is None:
         raise ValueError("Invalid process group specified")
 
@@ -2293,8 +2281,7 @@ def _abort_process_group(group: Optional[ProcessGroup] = None):
 
     pg = group or GroupMember.WORLD
 
-    if pg is None:
-        raise AssertionError("Process group cannot be None")
+    assert pg is not None
     if _world.pg_map.get(pg, None) is None:
         raise ValueError("Invalid process group specified or has been destroyed.")
 
@@ -3351,8 +3338,7 @@ def gather_object(
     if my_group_rank != group_dst:
         return
 
-    if object_gather_list is None:
-        raise AssertionError("Must provide object_gather_list on dst rank")
+    assert object_gather_list is not None, "Must provide object_gather_list on dst rank"
     # pyrefly: ignore # unbound-name
     for i, tensor in enumerate(output_tensors):
         tensor = tensor.type(torch.uint8)
@@ -3608,8 +3594,9 @@ def recv_object_list(
         rank_objects = get_global_rank(group, group_src)
     else:
         rank_objects = recv(object_tensor, group=group, group_src=group_src)
-    if rank_sizes != rank_objects:
-        raise AssertionError("Mismatch in return ranks for object sizes and objects.")
+    assert rank_sizes == rank_objects, (
+        "Mismatch in return ranks for object sizes and objects."
+    )
     # Deserialize objects using their stored sizes.
     offset = 0
     for i, obj_size in enumerate(object_sizes_tensor):
@@ -5016,8 +5003,7 @@ def _create_process_group_wrapper(
     world_size: int,
     timeout: timedelta = default_pg_timeout,
 ):
-    if not _GLOO_AVAILABLE:
-        raise RuntimeError("ProcessGroupWrapper unsupported without GLOO backend.")
+    assert _GLOO_AVAILABLE, "ProcessGroupWrapper unsupported without GLOO backend."
 
     # (whc) this appears to be just for the gloo backend? if so, `default_pg_timeout` is appropriate...
 
@@ -5219,10 +5205,9 @@ def split_group(
     split_pg.bound_device_id = device_id  # type: ignore[union-attr]
     split_backend_class = split_pg._get_backend(torch.device("cuda"))
     split_backend_class._set_sequence_number_for_group()
-    if split_pg.group_name != group_name:
-        raise AssertionError(
-            f"group name should be set to {group_name} but got {split_pg.group_name}"
-        )
+    assert split_pg.group_name == group_name, (
+        f"group name should be set to {group_name} but got {split_pg.group_name}"
+    )
 
     # update global state
     _world.pg_map[split_pg] = (backend, split_pg.get_group_store())
@@ -5354,10 +5339,9 @@ def _new_group_with_tag(
     if device_id is None:
         device_id = default_pg.bound_device_id
     elif default_pg.bound_device_id is not None:
-        if device_id != default_pg.bound_device_id:
-            raise AssertionError(
-                "Mismatched bound device between new pg and the default pg."
-            )
+        assert device_id == default_pg.bound_device_id, (
+            "Mismatched bound device between new pg and the default pg."
+        )
     default_backend, default_store = _world.pg_map[default_pg]
     global_rank = default_pg.rank()
     global_world_size = default_pg.size()
@@ -5671,25 +5655,22 @@ def _find_pg_by_ranks_and_tag(tag: str, ranks: list[int]) -> Optional[ProcessGro
 def _find_or_create_pg_by_ranks_and_tag(
     tag: str, ranks: list[int], stride: int
 ) -> ProcessGroup:
-    if len(ranks) % stride != 0:
-        raise ValueError(
-            f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})"
-        )
+    assert len(ranks) % stride == 0, (
+        f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})"
+    )
 
     my_rank = get_rank()
     my_ranks = None
 
     if stride == len(ranks):
         my_ranks = ranks.copy()
-        if my_rank not in my_ranks:
-            raise RuntimeError("rankset doesn't include the current node")
+        assert my_rank in my_ranks, "rankset doesn't include the current node"
     else:
         for i in range(0, len(ranks), stride):
            rank_set = ranks[i : i + stride]
            if my_rank in rank_set:
                my_ranks = rank_set
-        if my_ranks is None:
-            raise RuntimeError("rankset doesn't include the current node")
+        assert my_ranks is not None, "rankset doesn't include the current node"
 
     my_ranks = sorted(my_ranks)
 
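
For a sense of the restored behavior at runtime, here is a hypothetical single-process snippet (not a test from this PR) that trips the re-added `assert` in `init_process_group` by passing both `store` and `init_method`:

```python
# Hypothetical repro sketch; assumes a single local process and the gloo backend.
import torch.distributed as dist

store = dist.HashStore()  # in-memory store, no networking required
try:
    dist.init_process_group(
        backend="gloo",
        store=store,
        init_method="env://",  # supplying both store and init_method trips the assert
        rank=0,
        world_size=1,
    )
except AssertionError as exc:
    print(exc)  # Cannot specify both init_method and store.
```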