Revert "[distributed] Replace assert statements with AssertionError exceptions (#165216)"

This reverts commit 74db92b21868b7e9e77cc966e5d57a8246723cbd.

Reverted https://github.com/pytorch/pytorch/pull/165216 on behalf of https://github.com/clee2000 due to I think this broke distributed/test_pg_wrapper.py::ProcessGroupNCCLWrapperTest::test_debug_level_detail_no_gloo [GH job link](https://github.com/pytorch/pytorch/actions/runs/18492765290/job/52693842750) [HUD commit link](74db92b218), note to self: bad TD ([comment](https://github.com/pytorch/pytorch/pull/165216#issuecomment-3402838765))
PyTorch MergeBot
2025-10-14 17:05:16 +00:00
parent 5eddbb5e47
commit d2494cbb2b
11 changed files with 136 additions and 222 deletions
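
For context on the pattern this commit toggles: #165216 rewrote plain `assert` statements as explicit `if ...: raise AssertionError(...)` blocks, and this revert restores the `assert` form throughout. Both spellings raise AssertionError on failure under a normal interpreter, but `assert` statements are stripped when Python runs with the `-O` flag, while an explicit `raise` always executes. A minimal sketch of the two forms, using a hypothetical rank check rather than code from the diff:

```python
# Hypothetical example; the function names and the rank check are illustrative,
# not taken from the PyTorch sources touched by this commit.


def check_rank_with_assert(rank: int) -> None:
    # Form restored by this revert: skipped entirely under `python -O`.
    assert rank >= 0, f"rank must be non-negative, got {rank}"


def check_rank_with_raise(rank: int) -> None:
    # Form introduced by #165216: enforced even under `python -O`.
    if rank < 0:
        raise AssertionError(f"rank must be non-negative, got {rank}")


if __name__ == "__main__":
    check_rank_with_assert(3)
    check_rank_with_raise(3)
    try:
        check_rank_with_raise(-1)
    except AssertionError as exc:
        print(f"caught: {exc}")
```

As the revert message above notes, the motivation for reverting is the broken ProcessGroupNCCLWrapperTest job rather than this semantic difference.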

View File

@@ -15,8 +15,7 @@ _module_state_mapping: weakref.WeakKeyDictionary[
def _insert_module_state(module: nn.Module, state: _State) -> None:
global _module_state_mapping
if module in _module_state_mapping:
raise AssertionError(f"Inserting {module} more than once.")
assert module not in _module_state_mapping, f"Inserting {module} more than once."
_module_state_mapping[module] = weakref.ref(state)

View File

@@ -71,8 +71,7 @@ def _gloo_factory(
) -> ProcessGroup:
from torch.distributed import ProcessGroupGloo
if len(kwargs) != 0:
raise AssertionError("Gloo backend received unexpected kwargs")
assert len(kwargs) == 0, "Gloo backend received unexpected kwargs"
backend_class = ProcessGroupGloo(store, rank, world_size, timeout)
backend_class._set_sequence_number_for_group()

View File

@@ -193,8 +193,7 @@ def all_gather_tensor(
:: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
that information and perform collective algebraic optimization. Use other forms of input for that.
"""
if not self.is_contiguous():
raise AssertionError("Tensor must be contiguous for all_gather_tensor")
assert self.is_contiguous()
group_name = _resolve_group_name(group, tag)
group_size = c10d._get_group_size_by_name(group_name)
tensor = torch.ops._c10d_functional.all_gather_into_tensor(
@@ -269,10 +268,9 @@ def reduce_scatter_tensor(
group_name = _resolve_group_name(group, tag)
group_size = c10d._get_group_size_by_name(group_name)
if self.size(scatter_dim) % group_size != 0:
raise AssertionError(
f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size})"
)
assert self.size(scatter_dim) % group_size == 0, (
f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size})"
)
if scatter_dim != 0:
tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
self = torch.cat(tensor_list)
@@ -309,10 +307,9 @@ def reduce_scatter_tensor_autograd(
group_name = _resolve_group_name(group, tag)
group_size = c10d._get_group_size_by_name(group_name)
if self.size(scatter_dim) % group_size != 0:
raise AssertionError(
f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
)
assert self.size(scatter_dim) % group_size == 0, (
f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
)
if scatter_dim != 0:
tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
self = torch.cat(tensor_list)
@@ -409,15 +406,11 @@ def reduce_scatter_tensor_coalesced(
group_name = _resolve_group_name(group, tag)
group_size = c10d._get_group_size_by_name(group_name)
if len(scatter_dim) != len(inputs):
raise AssertionError(
f"Length of scatter_dim ({len(scatter_dim)}) must equal length of inputs ({len(inputs)})"
)
assert len(scatter_dim) == len(inputs)
for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)):
if tensor.size(dim) % group_size != 0:
raise AssertionError(
f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}"
)
assert tensor.size(dim) % group_size == 0, (
f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}"
)
if dim != 0:
tensor_list = torch.chunk(tensor, group_size, dim=dim)
inputs[idx] = torch.cat(tensor_list)
@@ -435,8 +428,7 @@ def reduce_scatter_tensor_coalesced(
# This is a bit unsafe: it checks if the first argument in the schema reports as a non-mutable alias.
# Today, this maps 1:1 with "aten ops that are views".
def _is_view_op(tgt):
if not isinstance(tgt, torch._ops.OpOverload):
raise AssertionError(f"Expected torch._ops.OpOverload, got {type(tgt)}")
assert isinstance(tgt, torch._ops.OpOverload)
# Don't apply the view optimization to any `CompositeImplicitAutograd` ops.
# See issue: https://github.com/pytorch/pytorch/issues/133421
if torch._C._dispatch_has_kernel_for_dispatch_key(
@@ -473,25 +465,20 @@ def all_to_all_single(
that information and perform collective algebraic optimization. Use other forms of input for that.
"""
if output_split_sizes is not None:
if not all(
assert all(
isinstance(size, (int, torch.SymInt)) for size in output_split_sizes
):
raise AssertionError(
f"All output_split_sizes must be int or SymInt, got {output_split_sizes}"
)
), output_split_sizes
if input_split_sizes is not None:
if not all(isinstance(size, (int, torch.SymInt)) for size in input_split_sizes):
raise AssertionError(
f"All input_split_sizes must be int or SymInt, got {input_split_sizes}"
)
assert all(
isinstance(size, (int, torch.SymInt)) for size in input_split_sizes
), input_split_sizes
group_name = _resolve_group_name(group, tag)
group_size = c10d._get_group_size_by_name(group_name)
if output_split_sizes is None or input_split_sizes is None:
if not (output_split_sizes is None and input_split_sizes is None):
raise AssertionError(
"output_split_sizes and input_split_sizes must either be "
"specified together or both set to None"
)
assert output_split_sizes is None and input_split_sizes is None, (
"output_split_sizes and input_split_sizes must either be "
"specified together or both set to None"
)
output_split_sizes = [self.shape[0] // group_size] * group_size
input_split_sizes = output_split_sizes
tensor = torch.ops._c10d_functional.all_to_all_single( # type: ignore[attr-defined]
@@ -514,26 +501,21 @@ def all_to_all_single_autograd(
Same as all_to_all_single but supports autograd.
"""
if output_split_sizes is not None:
if not all(
assert all(
isinstance(size, (int, torch.SymInt)) for size in output_split_sizes
):
raise AssertionError(
f"All output_split_sizes must be int or SymInt, got {output_split_sizes}"
)
), output_split_sizes
if input_split_sizes is not None:
if not all(isinstance(size, (int, torch.SymInt)) for size in input_split_sizes):
raise AssertionError(
f"All input_split_sizes must be int or SymInt, got {input_split_sizes}"
)
assert all(
isinstance(size, (int, torch.SymInt)) for size in input_split_sizes
), input_split_sizes
group_name = _resolve_group_name(group, tag)
group_size = c10d._get_group_size_by_name(group_name)
if output_split_sizes is None or input_split_sizes is None:
if not (output_split_sizes is None and input_split_sizes is None):
raise AssertionError(
"output_split_sizes and input_split_sizes must either be "
"specified together or both set to None"
)
assert output_split_sizes is None and input_split_sizes is None, (
"output_split_sizes and input_split_sizes must either be "
"specified together or both set to None"
)
output_split_sizes = [self.shape[0] // group_size] * group_size
input_split_sizes = output_split_sizes
tensor = torch.ops._c10d_functional_autograd.all_to_all_single( # type: ignore[attr-defined]
@@ -616,10 +598,7 @@ class AsyncCollectiveTensor(torch.Tensor):
@staticmethod
def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
if meta is not None:
raise AssertionError(
"meta must be None for AsyncCollectiveTensor unflatten"
)
assert meta is None
elem = inner_tensors["elem"]
return AsyncCollectiveTensor(elem)
@@ -669,10 +648,7 @@ class AsyncCollectiveTensor(torch.Tensor):
def wrap(e: torch.Tensor):
# wait_tensor is idepotent and will do stream sync only once
if isinstance(e, AsyncCollectiveTensor):
raise AssertionError(
"Cannot wrap an AsyncCollectiveTensor inside another AsyncCollectiveTensor"
)
assert not isinstance(e, AsyncCollectiveTensor)
res = AsyncCollectiveTensor(e)
return res
@@ -746,10 +722,9 @@ def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, list[int], int
group_size = len(rankset)
tag = tag or c10d._get_group_tag(group)
elif isinstance(group, DeviceMesh):
if group.ndim != 1:
raise AssertionError(
"Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
)
assert group.ndim == 1, (
"Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
)
# TODO: it should run collective in the whole mesh instead of dim 0
pg = group.get_group()
rankset = dist.get_process_group_ranks(pg)
@@ -788,10 +763,9 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
elif isinstance(group, str):
return group
elif isinstance(group, DeviceMesh):
if group.ndim != 1:
raise AssertionError(
"Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
)
assert group.ndim == 1, (
"Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
)
return group._dim_group_names[0]
elif isinstance(group, tuple):
if (
@@ -1081,14 +1055,12 @@ def all_gather_tensor_inplace(
tag: str = "",
gather_dim: int = 0,
):
if async_op:
raise AssertionError(
"Can't remap async version of inplace op to functional collective"
)
assert not async_op, (
"Can't remap async version of inplace op to functional collective"
)
group = group or dist.group.WORLD
if group is None:
raise AssertionError("group cannot be None")
assert group is not None
return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag))
@@ -1102,14 +1074,12 @@ def reduce_scatter_tensor_inplace(
scatter_dim: int = 0,
tag: str = "",
):
if async_op:
raise AssertionError(
"Can't remap async version of inplace op to functional collective"
)
assert not async_op, (
"Can't remap async version of inplace op to functional collective"
)
group = group or dist.group.WORLD
if group is None:
raise AssertionError("group cannot be None")
assert group is not None
return output.copy_(reduce_scatter_tensor(input, op, scatter_dim, group, tag))
@@ -1133,14 +1103,12 @@ def all_reduce_inplace(
async_op: bool = False,
tag: str = "",
):
if async_op:
raise AssertionError(
"Can't remap async version of inplace op to functional collective"
)
assert not async_op, (
"Can't remap async version of inplace op to functional collective"
)
group = group or dist.group.WORLD
if group is None:
raise AssertionError("group cannot be None")
assert group is not None
return tensor.copy_(all_reduce(tensor, op, group, tag))
@@ -1154,14 +1122,12 @@ def all_to_all_inplace(
async_op=False,
tag: str = "",
):
if async_op:
raise AssertionError(
"Can't remap async version of inplace op to functional collective"
)
assert not async_op, (
"Can't remap async version of inplace op to functional collective"
)
group = group or dist.group.WORLD
if group is None:
raise AssertionError("group cannot be None")
assert group is not None
return output.copy_(
all_to_all_single(
@@ -1181,16 +1147,15 @@ def all_gather_inplace(
async_op=False,
tag: str = "",
):
if async_op:
raise AssertionError(
"Can't remap async version of inplace op to functional collective"
)
if tensor.dim() != 0 and not all(t.size(0) == tensor.size(0) for t in tensor_list):
raise AssertionError("Remapping variable size all_gather is not yet supported")
assert not async_op, (
"Can't remap async version of inplace op to functional collective"
)
assert tensor.dim() == 0 or all(t.size(0) == tensor.size(0) for t in tensor_list), (
"Remapping variable size all_gather is not yet supported"
)
group = group or dist.group.WORLD
if group is None:
raise AssertionError("group cannot be None")
assert group is not None
output = all_gather_tensor(tensor, 0, group, tag)

View File

@@ -97,11 +97,10 @@ def _all_to_all_single(
group_size: int,
):
if output_split_sizes is None or input_split_sizes is None:
if not (output_split_sizes is None and input_split_sizes is None):
raise AssertionError(
"output_split_sizes and input_split_sizes must either be "
"specified together or both set to None"
)
assert output_split_sizes is None and input_split_sizes is None, (
"output_split_sizes and input_split_sizes must either be "
"specified together or both set to None"
)
output_split_sizes = [input.shape[0] // group_size] * group_size
input_split_sizes = output_split_sizes

View File

@@ -184,18 +184,12 @@ def _iterate_state_dict(
if companion_obj is not None:
if isinstance(companion_obj, DTensor):
if not isinstance(ret, DTensor):
raise AssertionError(
"ret must be a DTensor when companion_obj is a DTensor"
)
assert isinstance(ret, DTensor)
companion_obj._local_tensor.copy_(
ret._local_tensor, non_blocking=non_blocking
)
elif isinstance(companion_obj, ShardedTensor):
if not isinstance(ret, ShardedTensor):
raise AssertionError(
"ret must be a ShardedTensor when companion_obj is a ShardedTensor"
)
assert isinstance(ret, ShardedTensor)
for idx, shard in enumerate(companion_obj.local_shards()):
shard.tensor.copy_(
ret.local_shards()[idx].tensor, non_blocking=non_blocking
@@ -554,8 +548,7 @@ def _broadcast_tensors(
for key in keys:
if dist.get_rank() == 0:
full_state = full_state_dict[key]
if not isinstance(full_state, torch.Tensor):
raise AssertionError("full_state must be a torch.Tensor")
assert isinstance(full_state, torch.Tensor)
full_tensor = full_state.detach().to(pg_device)
else:
tensor_info = full_state_dict[key]
@@ -714,8 +707,7 @@ def _distribute_state_dict(
elif value.dim() == 0:
local_state_dict[key] = value.cpu()
else:
if not isinstance(value, torch.Tensor):
raise AssertionError("value must be a torch.Tensor")
assert isinstance(value, torch.Tensor)
local_state = local_state_dict.get(key, None)
if local_state is None:
continue

View File

@@ -104,10 +104,7 @@ def broadcast(
if pg is not None:
broadcast_list = [sync_obj]
dist.broadcast_object_list(broadcast_list, src=rank, group=pg)
if len(broadcast_list) != 1:
raise AssertionError(
f"Expected broadcast_list to have exactly 1 element, got {len(broadcast_list)}"
)
assert len(broadcast_list) == 1
sync_obj = broadcast_list[0]
# failure in any rank will trigger a throw in every rank.
@@ -243,10 +240,8 @@ def all_gather_object_enforce_type(
def _summarize_ranks(ranks: Iterable[int]) -> str:
ranks = sorted(ranks)
if min(ranks) < 0:
raise AssertionError("ranks should all be positive")
if len(set(ranks)) != len(ranks):
raise AssertionError("ranks should not contain duplicates")
assert min(ranks) >= 0, "ranks should all be positive"
assert len(set(ranks)) == len(ranks), "ranks should not contain duplicates"
curr: Optional[Union[int, range]] = None
ranges = []
while ranks:
@@ -260,8 +255,7 @@ def _summarize_ranks(ranks: Iterable[int]) -> str:
step = x - curr
curr = range(curr, x + step, step)
else:
if not isinstance(curr, range):
raise AssertionError("curr must be an instance of range")
assert isinstance(curr, range)
if x == curr.stop:
curr = range(curr.start, curr.stop + curr.step, curr.step)
else:

View File

@@ -213,16 +213,14 @@ else:
if _layout
else _MeshLayout(self.mesh.size(), self.mesh.stride())
)
if not self._layout.check_non_overlap():
raise AssertionError(
"Please use a non-overlapping layout when creating a DeviceMesh."
)
assert self._layout.check_non_overlap(), (
"Please use a non-overlapping layout when creating a DeviceMesh."
)
# Because we still need to support slicing of flattened dim from root mesh, so we don't check stride here.
if self._layout.numel() != self.mesh.numel():
raise AssertionError(
"Please use a valid layout when creating a DeviceMesh."
f"The layout {self._layout} is not consistent with the mesh size {self.mesh.size()}."
)
assert self._layout.numel() == self.mesh.numel(), (
"Please use a valid layout when creating a DeviceMesh."
f"The layout {self._layout} is not consistent with the mesh size {self.mesh.size()}."
)
# private field to pre-generate DeviceMesh's hash
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
@@ -247,10 +245,7 @@ else:
# calculate the coordinates of the current global rank on the mesh
rank_coords = (self.mesh == _rank).nonzero()
if rank_coords.size(0) not in (0, 1):
raise AssertionError(
f"rank_coords.size(0) must be 0 or 1, got {rank_coords.size(0)}"
)
assert rank_coords.size(0) in (0, 1)
self._coordinate_on_dim: Optional[list[int]] = (
rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
)
@@ -595,10 +590,7 @@ else:
if isinstance(mesh_dim, str)
else mesh_dim
)
if not isinstance(mesh_dim, int):
raise AssertionError(
f"mesh_dim must be an int, got {type(mesh_dim)}"
)
assert isinstance(mesh_dim, int)
return not_none(_resolve_process_group(self._dim_group_names[mesh_dim]))
def get_all_groups(self) -> list[ProcessGroup]:
@@ -717,8 +709,9 @@ else:
root_mesh = self._get_root_mesh()
child_mesh_dim_names = self._mesh_dim_names
if root_mesh and child_mesh_dim_names:
if len(child_mesh_dim_names) != 1:
raise AssertionError("The submesh can only be a 1D mesh.")
assert len(child_mesh_dim_names) == 1, (
"The submesh can only be a 1D mesh."
)
child_mesh_dim_name = child_mesh_dim_names[0]
return root_mesh._get_mesh_dim_by_name(child_mesh_dim_name)
return None
@@ -1055,10 +1048,9 @@ else:
mesh_dim = 0
mesh_dim_group = not_none(self.get_group(mesh_dim))
if not isinstance(mesh_dim_group, ProcessGroup):
raise AssertionError(
"We expect ProcessGroup before calling `get_rank`!"
)
assert isinstance(mesh_dim_group, ProcessGroup), (
"We expect ProcessGroup before calling `get_rank`!"
)
return not_none(get_rank(mesh_dim_group))
def get_coordinate(self) -> Optional[list[int]]:

View File

@@ -1526,8 +1526,7 @@ def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) ->
group = _get_default_group()
if _rank_not_in_group(group):
raise ValueError("Invalid process group specified")
if not isinstance(group, ProcessGroup):
raise AssertionError(f"Expected ProcessGroup, got {type(group)}")
assert isinstance(group, ProcessGroup)
devices = group._device_types
backends = set()
if torch.device("cpu") in devices and is_gloo_available():
@@ -1666,14 +1665,13 @@ def init_process_group(
if "torch._dynamo" in sys.modules:
torch._dynamo.trace_rules.clear_lru_cache()
if not ((store is None) or (init_method is None)):
raise AssertionError("Cannot specify both init_method and store.")
assert (store is None) or (init_method is None), (
"Cannot specify both init_method and store."
)
if store is not None:
if not world_size > 0:
raise AssertionError("world_size must be positive if using store")
if not rank >= 0:
raise AssertionError("rank must be non-negative if using store")
assert world_size > 0, "world_size must be positive if using store"
assert rank >= 0, "rank must be non-negative if using store"
elif init_method is None:
init_method = "env://"
@@ -1947,8 +1945,7 @@ def _new_process_group_helper(
backend_config = BackendConfig(backend)
# Set the default backend when single backend is passed in.
if "," not in str(backend) and ":" not in str(backend):
if backend not in Backend.backend_type_map:
raise AssertionError(f"Unknown backend type {backend}")
assert backend in Backend.backend_type_map, f"Unknown backend type {backend}"
if backend == Backend.UNDEFINED:
# Currently when backend is UNDEFINED, only one backend will be initialized
# we use nccl (if cuda is available) or gloo as default backend
@@ -2018,10 +2015,9 @@ def _new_process_group_helper(
if not is_nccl_available():
raise RuntimeError("Distributed package doesn't have NCCL built in")
if backend_options is not None:
if not isinstance(backend_options, ProcessGroupNCCL.Options):
raise AssertionError(
"Expected backend_options argument to be of type ProcessGroupNCCL.Options"
)
assert isinstance(backend_options, ProcessGroupNCCL.Options), (
"Expected backend_options argument to be of type ProcessGroupNCCL.Options"
)
if backend_options._timeout != timeout:
warnings.warn(
"backend_options._timeout was specified, "
@@ -2071,8 +2067,9 @@ def _new_process_group_helper(
)
backend_type = ProcessGroup.BackendType.XCCL
else:
if backend_str.upper() not in Backend._plugins:
raise AssertionError(f"Unknown c10d backend type {backend_str.upper()}")
assert backend_str.upper() in Backend._plugins, (
f"Unknown c10d backend type {backend_str.upper()}"
)
backend_plugin = Backend._plugins[backend_str.upper()]
creator_fn = backend_plugin.creator_fn
@@ -2097,16 +2094,10 @@ def _new_process_group_helper(
# Set sequence numbers for gloo and nccl backends.
if backend_str == Backend.GLOO:
if not isinstance(backend_class, ProcessGroupGloo):
raise AssertionError(
f"Expected ProcessGroupGloo, got {type(backend_class)}"
)
assert isinstance(backend_class, ProcessGroupGloo)
backend_class._set_sequence_number_for_group()
elif backend_str == Backend.NCCL:
if not isinstance(backend_class, ProcessGroupNCCL):
raise AssertionError(
f"Expected ProcessGroupNCCL, got {type(backend_class)}"
)
assert isinstance(backend_class, ProcessGroupNCCL)
backend_class._set_sequence_number_for_group()
# If the type is a subclass of ProcessGroup then return this process group immediately
@@ -2153,10 +2144,8 @@ def _new_process_group_helper(
pg._register_backend(torch.device(device), backend_type, backend_class)
# set group_name and group_dsec to backend
if group_name is None:
raise AssertionError("group_name must not be None")
if group_desc is None:
raise AssertionError("group_desc must not be None")
assert group_name is not None
assert group_desc is not None
pg._set_group_name(group_name)
pg._set_group_desc(group_desc)
@@ -2202,8 +2191,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
else:
pg = group
if pg is None:
raise AssertionError("Process group cannot be None")
assert pg is not None
if _world.pg_map.get(pg, None) is None:
raise ValueError("Invalid process group specified")
@@ -2293,8 +2281,7 @@ def _abort_process_group(group: Optional[ProcessGroup] = None):
pg = group or GroupMember.WORLD
if pg is None:
raise AssertionError("Process group cannot be None")
assert pg is not None
if _world.pg_map.get(pg, None) is None:
raise ValueError("Invalid process group specified or has been destroyed.")
@@ -3351,8 +3338,7 @@ def gather_object(
if my_group_rank != group_dst:
return
if object_gather_list is None:
raise AssertionError("Must provide object_gather_list on dst rank")
assert object_gather_list is not None, "Must provide object_gather_list on dst rank"
# pyrefly: ignore # unbound-name
for i, tensor in enumerate(output_tensors):
tensor = tensor.type(torch.uint8)
@@ -3608,8 +3594,9 @@ def recv_object_list(
rank_objects = get_global_rank(group, group_src)
else:
rank_objects = recv(object_tensor, group=group, group_src=group_src)
if rank_sizes != rank_objects:
raise AssertionError("Mismatch in return ranks for object sizes and objects.")
assert rank_sizes == rank_objects, (
"Mismatch in return ranks for object sizes and objects."
)
# Deserialize objects using their stored sizes.
offset = 0
for i, obj_size in enumerate(object_sizes_tensor):
@@ -5016,8 +5003,7 @@ def _create_process_group_wrapper(
world_size: int,
timeout: timedelta = default_pg_timeout,
):
if not _GLOO_AVAILABLE:
raise RuntimeError("ProcessGroupWrapper unsupported without GLOO backend.")
assert _GLOO_AVAILABLE, "ProcessGroupWrapper unsupported without GLOO backend."
# (whc) this appears to be just for the gloo backend? if so, `default_pg_timeout` is appropriate...
@@ -5219,10 +5205,9 @@ def split_group(
split_pg.bound_device_id = device_id # type: ignore[union-attr]
split_backend_class = split_pg._get_backend(torch.device("cuda"))
split_backend_class._set_sequence_number_for_group()
if split_pg.group_name != group_name:
raise AssertionError(
f"group name should be set to {group_name} but got {split_pg.group_name}"
)
assert split_pg.group_name == group_name, (
f"group name should be set to {group_name} but got {split_pg.group_name}"
)
# update global state
_world.pg_map[split_pg] = (backend, split_pg.get_group_store())
@@ -5354,10 +5339,9 @@ def _new_group_with_tag(
if device_id is None:
device_id = default_pg.bound_device_id
elif default_pg.bound_device_id is not None:
if device_id != default_pg.bound_device_id:
raise AssertionError(
"Mismatched bound device between new pg and the default pg."
)
assert device_id == default_pg.bound_device_id, (
"Mismatched bound device between new pg and the default pg."
)
default_backend, default_store = _world.pg_map[default_pg]
global_rank = default_pg.rank()
global_world_size = default_pg.size()
@@ -5671,25 +5655,22 @@ def _find_pg_by_ranks_and_tag(tag: str, ranks: list[int]) -> Optional[ProcessGro
def _find_or_create_pg_by_ranks_and_tag(
tag: str, ranks: list[int], stride: int
) -> ProcessGroup:
if len(ranks) % stride != 0:
raise ValueError(
f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})"
)
assert len(ranks) % stride == 0, (
f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})"
)
my_rank = get_rank()
my_ranks = None
if stride == len(ranks):
my_ranks = ranks.copy()
if my_rank not in my_ranks:
raise RuntimeError("rankset doesn't include the current node")
assert my_rank in my_ranks, "rankset doesn't include the current node"
else:
for i in range(0, len(ranks), stride):
rank_set = ranks[i : i + stride]
if my_rank in rank_set:
my_ranks = rank_set
if my_ranks is None:
raise RuntimeError("rankset doesn't include the current node")
assert my_ranks is not None, "rankset doesn't include the current node"
my_ranks = sorted(my_ranks)

View File

@@ -83,10 +83,9 @@ def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwa
world_size = world_size_opt
if rank != -1 or world_size != -1 or world_size_opt is None:
query_dict = _query_to_dict(result.query)
if "rank" in query_dict or "world_size" in query_dict:
raise AssertionError(
f"The url: {url} has node-specific arguments(rank, world_size) already."
)
assert "rank" not in query_dict and "world_size" not in query_dict, (
f"The url: {url} has node-specific arguments(rank, world_size) already."
)
if rank != -1:
query_dict["rank"] = str(rank)
if world_size != -1 or world_size_opt is None:
@@ -228,8 +227,7 @@ def _tcp_rendezvous_handler(
world_size = int(query_dict["world_size"])
use_libuv = _get_use_libuv_from_query_dict(query_dict)
if result.hostname is None:
raise AssertionError("hostname cannot be None")
assert result.hostname is not None
store = _create_c10d_store(
result.hostname, result.port, rank, world_size, timeout, use_libuv

View File

@@ -792,12 +792,8 @@ def _get_logs_specs_class(logs_specs_name: Optional[str]) -> type[LogsSpecs]:
def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str]]:
# If ``args`` not passed, defaults to ``sys.argv[:1]``
min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
if not (0 < min_nodes <= max_nodes):
raise AssertionError(
f"min_nodes must be > 0 and <= max_nodes, got min_nodes={min_nodes}, max_nodes={max_nodes}"
)
if args.max_restarts < 0:
raise AssertionError("max_restarts must be >= 0")
assert 0 < min_nodes <= max_nodes
assert args.max_restarts >= 0
if (
hasattr(args, "master_addr")
@@ -837,8 +833,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str
if args.local_ranks_filter:
try:
ranks = set(map(int, args.local_ranks_filter.split(",")))
if not ranks:
raise AssertionError("ranks set cannot be empty")
assert ranks
except Exception as e:
raise ValueError(
"--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2"

View File

@@ -69,8 +69,9 @@ def _unpack_kwargs(
flat_args: tuple[Any, ...], kwarg_keys: tuple[str, ...]
) -> tuple[tuple[Any, ...], dict[str, Any]]:
"""See _pack_kwargs."""
if len(kwarg_keys) > len(flat_args):
raise AssertionError(f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}")
assert len(kwarg_keys) <= len(flat_args), (
f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}"
)
if len(kwarg_keys) == 0:
return flat_args, {}
args = flat_args[: -len(kwarg_keys)]
@@ -126,8 +127,7 @@ def _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies):
if isinstance(obj, PackedSequence):
output.data.record_stream(current_stream) # type: ignore[arg-type]
else:
if not isinstance(output, torch.Tensor):
raise AssertionError("output must be a torch.Tensor")
assert isinstance(output, torch.Tensor)
output.record_stream(current_stream) # type: ignore[arg-type]
return (output,)