diff --git a/torch/distributed/_composable_state.py b/torch/distributed/_composable_state.py index b90a1007e763..507db1bf7fc6 100644 --- a/torch/distributed/_composable_state.py +++ b/torch/distributed/_composable_state.py @@ -15,8 +15,7 @@ _module_state_mapping: weakref.WeakKeyDictionary[ def _insert_module_state(module: nn.Module, state: _State) -> None: global _module_state_mapping - if module in _module_state_mapping: - raise AssertionError(f"Inserting {module} more than once.") + assert module not in _module_state_mapping, f"Inserting {module} more than once." _module_state_mapping[module] = weakref.ref(state) diff --git a/torch/distributed/_dist2.py b/torch/distributed/_dist2.py index d9ed7003ccfd..ce5cb8d7e0cc 100644 --- a/torch/distributed/_dist2.py +++ b/torch/distributed/_dist2.py @@ -71,8 +71,7 @@ def _gloo_factory( ) -> ProcessGroup: from torch.distributed import ProcessGroupGloo - if len(kwargs) != 0: - raise AssertionError("Gloo backend received unexpected kwargs") + assert len(kwargs) == 0, "Gloo backend received unexpected kwargs" backend_class = ProcessGroupGloo(store, rank, world_size, timeout) backend_class._set_sequence_number_for_group() diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index f1d59ca7655d..5dd56fc006c4 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -193,8 +193,7 @@ def all_gather_tensor( :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover that information and perform collective algebraic optimization. Use other forms of input for that. """ - if not self.is_contiguous(): - raise AssertionError("Tensor must be contiguous for all_gather_tensor") + assert self.is_contiguous() group_name = _resolve_group_name(group, tag) group_size = c10d._get_group_size_by_name(group_name) tensor = torch.ops._c10d_functional.all_gather_into_tensor( @@ -269,10 +268,9 @@ def reduce_scatter_tensor( group_name = _resolve_group_name(group, tag) group_size = c10d._get_group_size_by_name(group_name) - if self.size(scatter_dim) % group_size != 0: - raise AssertionError( - f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size})" - ) + assert self.size(scatter_dim) % group_size == 0, ( + f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size})" + ) if scatter_dim != 0: tensor_list = torch.chunk(self, group_size, dim=scatter_dim) self = torch.cat(tensor_list) @@ -309,10 +307,9 @@ def reduce_scatter_tensor_autograd( group_name = _resolve_group_name(group, tag) group_size = c10d._get_group_size_by_name(group_name) - if self.size(scatter_dim) % group_size != 0: - raise AssertionError( - f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}" - ) + assert self.size(scatter_dim) % group_size == 0, ( + f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}" + ) if scatter_dim != 0: tensor_list = torch.chunk(self, group_size, dim=scatter_dim) self = torch.cat(tensor_list) @@ -409,15 +406,11 @@ def reduce_scatter_tensor_coalesced( group_name = _resolve_group_name(group, tag) group_size = c10d._get_group_size_by_name(group_name) - if len(scatter_dim) != len(inputs): - raise AssertionError( - f"Length of scatter_dim ({len(scatter_dim)}) must equal length of inputs ({len(inputs)})" - ) + assert len(scatter_dim) == len(inputs) for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)): - if tensor.size(dim) % 
group_size != 0: - raise AssertionError( - f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}" - ) + assert tensor.size(dim) % group_size == 0, ( + f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}" + ) if dim != 0: tensor_list = torch.chunk(tensor, group_size, dim=dim) inputs[idx] = torch.cat(tensor_list) @@ -435,8 +428,7 @@ def reduce_scatter_tensor_coalesced( # This is a bit unsafe: it checks if the first argument in the schema reports as a non-mutable alias. # Today, this maps 1:1 with "aten ops that are views". def _is_view_op(tgt): - if not isinstance(tgt, torch._ops.OpOverload): - raise AssertionError(f"Expected torch._ops.OpOverload, got {type(tgt)}") + assert isinstance(tgt, torch._ops.OpOverload) # Don't apply the view optimization to any `CompositeImplicitAutograd` ops. # See issue: https://github.com/pytorch/pytorch/issues/133421 if torch._C._dispatch_has_kernel_for_dispatch_key( @@ -473,25 +465,20 @@ def all_to_all_single( that information and perform collective algebraic optimization. Use other forms of input for that. """ if output_split_sizes is not None: - if not all( + assert all( isinstance(size, (int, torch.SymInt)) for size in output_split_sizes - ): - raise AssertionError( - f"All output_split_sizes must be int or SymInt, got {output_split_sizes}" - ) + ), output_split_sizes if input_split_sizes is not None: - if not all(isinstance(size, (int, torch.SymInt)) for size in input_split_sizes): - raise AssertionError( - f"All input_split_sizes must be int or SymInt, got {input_split_sizes}" - ) + assert all( + isinstance(size, (int, torch.SymInt)) for size in input_split_sizes + ), input_split_sizes group_name = _resolve_group_name(group, tag) group_size = c10d._get_group_size_by_name(group_name) if output_split_sizes is None or input_split_sizes is None: - if not (output_split_sizes is None and input_split_sizes is None): - raise AssertionError( - "output_split_sizes and input_split_sizes must either be " - "specified together or both set to None" - ) + assert output_split_sizes is None and input_split_sizes is None, ( + "output_split_sizes and input_split_sizes must either be " + "specified together or both set to None" + ) output_split_sizes = [self.shape[0] // group_size] * group_size input_split_sizes = output_split_sizes tensor = torch.ops._c10d_functional.all_to_all_single( # type: ignore[attr-defined] @@ -514,26 +501,21 @@ def all_to_all_single_autograd( Same as all_to_all_single but supports autograd. 
""" if output_split_sizes is not None: - if not all( + assert all( isinstance(size, (int, torch.SymInt)) for size in output_split_sizes - ): - raise AssertionError( - f"All output_split_sizes must be int or SymInt, got {output_split_sizes}" - ) + ), output_split_sizes if input_split_sizes is not None: - if not all(isinstance(size, (int, torch.SymInt)) for size in input_split_sizes): - raise AssertionError( - f"All input_split_sizes must be int or SymInt, got {input_split_sizes}" - ) + assert all( + isinstance(size, (int, torch.SymInt)) for size in input_split_sizes + ), input_split_sizes group_name = _resolve_group_name(group, tag) group_size = c10d._get_group_size_by_name(group_name) if output_split_sizes is None or input_split_sizes is None: - if not (output_split_sizes is None and input_split_sizes is None): - raise AssertionError( - "output_split_sizes and input_split_sizes must either be " - "specified together or both set to None" - ) + assert output_split_sizes is None and input_split_sizes is None, ( + "output_split_sizes and input_split_sizes must either be " + "specified together or both set to None" + ) output_split_sizes = [self.shape[0] // group_size] * group_size input_split_sizes = output_split_sizes tensor = torch.ops._c10d_functional_autograd.all_to_all_single( # type: ignore[attr-defined] @@ -616,10 +598,7 @@ class AsyncCollectiveTensor(torch.Tensor): @staticmethod def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride): - if meta is not None: - raise AssertionError( - "meta must be None for AsyncCollectiveTensor unflatten" - ) + assert meta is None elem = inner_tensors["elem"] return AsyncCollectiveTensor(elem) @@ -669,10 +648,7 @@ class AsyncCollectiveTensor(torch.Tensor): def wrap(e: torch.Tensor): # wait_tensor is idepotent and will do stream sync only once - if isinstance(e, AsyncCollectiveTensor): - raise AssertionError( - "Cannot wrap an AsyncCollectiveTensor inside another AsyncCollectiveTensor" - ) + assert not isinstance(e, AsyncCollectiveTensor) res = AsyncCollectiveTensor(e) return res @@ -746,10 +722,9 @@ def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, list[int], int group_size = len(rankset) tag = tag or c10d._get_group_tag(group) elif isinstance(group, DeviceMesh): - if group.ndim != 1: - raise AssertionError( - "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" - ) + assert group.ndim == 1, ( + "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" + ) # TODO: it should run collective in the whole mesh instead of dim 0 pg = group.get_group() rankset = dist.get_process_group_ranks(pg) @@ -788,10 +763,9 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str: elif isinstance(group, str): return group elif isinstance(group, DeviceMesh): - if group.ndim != 1: - raise AssertionError( - "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" - ) + assert group.ndim == 1, ( + "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" + ) return group._dim_group_names[0] elif isinstance(group, tuple): if ( @@ -1081,14 +1055,12 @@ def all_gather_tensor_inplace( tag: str = "", gather_dim: int = 0, ): - if async_op: - raise AssertionError( - "Can't remap async version of inplace op to functional collective" - ) + assert not async_op, ( + "Can't remap async version of inplace op to functional collective" + ) group = group or dist.group.WORLD - if group is None: - raise AssertionError("group cannot be None") + assert group is not None 
return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag)) @@ -1102,14 +1074,12 @@ def reduce_scatter_tensor_inplace( scatter_dim: int = 0, tag: str = "", ): - if async_op: - raise AssertionError( - "Can't remap async version of inplace op to functional collective" - ) + assert not async_op, ( + "Can't remap async version of inplace op to functional collective" + ) group = group or dist.group.WORLD - if group is None: - raise AssertionError("group cannot be None") + assert group is not None return output.copy_(reduce_scatter_tensor(input, op, scatter_dim, group, tag)) @@ -1133,14 +1103,12 @@ def all_reduce_inplace( async_op: bool = False, tag: str = "", ): - if async_op: - raise AssertionError( - "Can't remap async version of inplace op to functional collective" - ) + assert not async_op, ( + "Can't remap async version of inplace op to functional collective" + ) group = group or dist.group.WORLD - if group is None: - raise AssertionError("group cannot be None") + assert group is not None return tensor.copy_(all_reduce(tensor, op, group, tag)) @@ -1154,14 +1122,12 @@ def all_to_all_inplace( async_op=False, tag: str = "", ): - if async_op: - raise AssertionError( - "Can't remap async version of inplace op to functional collective" - ) + assert not async_op, ( + "Can't remap async version of inplace op to functional collective" + ) group = group or dist.group.WORLD - if group is None: - raise AssertionError("group cannot be None") + assert group is not None return output.copy_( all_to_all_single( @@ -1181,16 +1147,15 @@ def all_gather_inplace( async_op=False, tag: str = "", ): - if async_op: - raise AssertionError( - "Can't remap async version of inplace op to functional collective" - ) - if tensor.dim() != 0 and not all(t.size(0) == tensor.size(0) for t in tensor_list): - raise AssertionError("Remapping variable size all_gather is not yet supported") + assert not async_op, ( + "Can't remap async version of inplace op to functional collective" + ) + assert tensor.dim() == 0 or all(t.size(0) == tensor.size(0) for t in tensor_list), ( + "Remapping variable size all_gather is not yet supported" + ) group = group or dist.group.WORLD - if group is None: - raise AssertionError("group cannot be None") + assert group is not None output = all_gather_tensor(tensor, 0, group, tag) diff --git a/torch/distributed/_functional_collectives_impl.py b/torch/distributed/_functional_collectives_impl.py index e6174c11cd61..0c1ac0a079de 100644 --- a/torch/distributed/_functional_collectives_impl.py +++ b/torch/distributed/_functional_collectives_impl.py @@ -97,11 +97,10 @@ def _all_to_all_single( group_size: int, ): if output_split_sizes is None or input_split_sizes is None: - if not (output_split_sizes is None and input_split_sizes is None): - raise AssertionError( - "output_split_sizes and input_split_sizes must either be " - "specified together or both set to None" - ) + assert output_split_sizes is None and input_split_sizes is None, ( + "output_split_sizes and input_split_sizes must either be " + "specified together or both set to None" + ) output_split_sizes = [input.shape[0] // group_size] * group_size input_split_sizes = output_split_sizes diff --git a/torch/distributed/_state_dict_utils.py b/torch/distributed/_state_dict_utils.py index 30562afda2a8..cea7903bd0e2 100644 --- a/torch/distributed/_state_dict_utils.py +++ b/torch/distributed/_state_dict_utils.py @@ -184,18 +184,12 @@ def _iterate_state_dict( if companion_obj is not None: if isinstance(companion_obj, DTensor): - if not 
isinstance(ret, DTensor): - raise AssertionError( - "ret must be a DTensor when companion_obj is a DTensor" - ) + assert isinstance(ret, DTensor) companion_obj._local_tensor.copy_( ret._local_tensor, non_blocking=non_blocking ) elif isinstance(companion_obj, ShardedTensor): - if not isinstance(ret, ShardedTensor): - raise AssertionError( - "ret must be a ShardedTensor when companion_obj is a ShardedTensor" - ) + assert isinstance(ret, ShardedTensor) for idx, shard in enumerate(companion_obj.local_shards()): shard.tensor.copy_( ret.local_shards()[idx].tensor, non_blocking=non_blocking @@ -554,8 +548,7 @@ def _broadcast_tensors( for key in keys: if dist.get_rank() == 0: full_state = full_state_dict[key] - if not isinstance(full_state, torch.Tensor): - raise AssertionError("full_state must be a torch.Tensor") + assert isinstance(full_state, torch.Tensor) full_tensor = full_state.detach().to(pg_device) else: tensor_info = full_state_dict[key] @@ -714,8 +707,7 @@ def _distribute_state_dict( elif value.dim() == 0: local_state_dict[key] = value.cpu() else: - if not isinstance(value, torch.Tensor): - raise AssertionError("value must be a torch.Tensor") + assert isinstance(value, torch.Tensor) local_state = local_state_dict.get(key, None) if local_state is None: continue diff --git a/torch/distributed/collective_utils.py b/torch/distributed/collective_utils.py index 50e0517ca844..b61155274bc8 100644 --- a/torch/distributed/collective_utils.py +++ b/torch/distributed/collective_utils.py @@ -104,10 +104,7 @@ def broadcast( if pg is not None: broadcast_list = [sync_obj] dist.broadcast_object_list(broadcast_list, src=rank, group=pg) - if len(broadcast_list) != 1: - raise AssertionError( - f"Expected broadcast_list to have exactly 1 element, got {len(broadcast_list)}" - ) + assert len(broadcast_list) == 1 sync_obj = broadcast_list[0] # failure in any rank will trigger a throw in every rank. @@ -243,10 +240,8 @@ def all_gather_object_enforce_type( def _summarize_ranks(ranks: Iterable[int]) -> str: ranks = sorted(ranks) - if min(ranks) < 0: - raise AssertionError("ranks should all be positive") - if len(set(ranks)) != len(ranks): - raise AssertionError("ranks should not contain duplicates") + assert min(ranks) >= 0, "ranks should all be positive" + assert len(set(ranks)) == len(ranks), "ranks should not contain duplicates" curr: Optional[Union[int, range]] = None ranges = [] while ranks: @@ -260,8 +255,7 @@ def _summarize_ranks(ranks: Iterable[int]) -> str: step = x - curr curr = range(curr, x + step, step) else: - if not isinstance(curr, range): - raise AssertionError("curr must be an instance of range") + assert isinstance(curr, range) if x == curr.stop: curr = range(curr.start, curr.stop + curr.step, curr.step) else: diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 2063f24b584e..e30965cf3205 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -213,16 +213,14 @@ else: if _layout else _MeshLayout(self.mesh.size(), self.mesh.stride()) ) - if not self._layout.check_non_overlap(): - raise AssertionError( - "Please use a non-overlapping layout when creating a DeviceMesh." - ) + assert self._layout.check_non_overlap(), ( + "Please use a non-overlapping layout when creating a DeviceMesh." + ) # Because we still need to support slicing of flattened dim from root mesh, so we don't check stride here. - if self._layout.numel() != self.mesh.numel(): - raise AssertionError( - "Please use a valid layout when creating a DeviceMesh." 
- f"The layout {self._layout} is not consistent with the mesh size {self.mesh.size()}." - ) + assert self._layout.numel() == self.mesh.numel(), ( + "Please use a valid layout when creating a DeviceMesh." + f"The layout {self._layout} is not consistent with the mesh size {self.mesh.size()}." + ) # private field to pre-generate DeviceMesh's hash self._flatten_mesh_list = tuple(self.mesh.flatten().tolist()) @@ -247,10 +245,7 @@ else: # calculate the coordinates of the current global rank on the mesh rank_coords = (self.mesh == _rank).nonzero() - if rank_coords.size(0) not in (0, 1): - raise AssertionError( - f"rank_coords.size(0) must be 0 or 1, got {rank_coords.size(0)}" - ) + assert rank_coords.size(0) in (0, 1) self._coordinate_on_dim: Optional[list[int]] = ( rank_coords[0].tolist() if rank_coords.size(0) > 0 else None ) @@ -595,10 +590,7 @@ else: if isinstance(mesh_dim, str) else mesh_dim ) - if not isinstance(mesh_dim, int): - raise AssertionError( - f"mesh_dim must be an int, got {type(mesh_dim)}" - ) + assert isinstance(mesh_dim, int) return not_none(_resolve_process_group(self._dim_group_names[mesh_dim])) def get_all_groups(self) -> list[ProcessGroup]: @@ -717,8 +709,9 @@ else: root_mesh = self._get_root_mesh() child_mesh_dim_names = self._mesh_dim_names if root_mesh and child_mesh_dim_names: - if len(child_mesh_dim_names) != 1: - raise AssertionError("The submesh can only be a 1D mesh.") + assert len(child_mesh_dim_names) == 1, ( + "The submesh can only be a 1D mesh." + ) child_mesh_dim_name = child_mesh_dim_names[0] return root_mesh._get_mesh_dim_by_name(child_mesh_dim_name) return None @@ -1055,10 +1048,9 @@ else: mesh_dim = 0 mesh_dim_group = not_none(self.get_group(mesh_dim)) - if not isinstance(mesh_dim_group, ProcessGroup): - raise AssertionError( - "We expect ProcessGroup before calling `get_rank`!" - ) + assert isinstance(mesh_dim_group, ProcessGroup), ( + "We expect ProcessGroup before calling `get_rank`!" + ) return not_none(get_rank(mesh_dim_group)) def get_coordinate(self) -> Optional[list[int]]: diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index dff669a21f8e..ea194a6ebe9a 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -1526,8 +1526,7 @@ def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) -> group = _get_default_group() if _rank_not_in_group(group): raise ValueError("Invalid process group specified") - if not isinstance(group, ProcessGroup): - raise AssertionError(f"Expected ProcessGroup, got {type(group)}") + assert isinstance(group, ProcessGroup) devices = group._device_types backends = set() if torch.device("cpu") in devices and is_gloo_available(): @@ -1666,14 +1665,13 @@ def init_process_group( if "torch._dynamo" in sys.modules: torch._dynamo.trace_rules.clear_lru_cache() - if not ((store is None) or (init_method is None)): - raise AssertionError("Cannot specify both init_method and store.") + assert (store is None) or (init_method is None), ( + "Cannot specify both init_method and store." 
+ ) if store is not None: - if not world_size > 0: - raise AssertionError("world_size must be positive if using store") - if not rank >= 0: - raise AssertionError("rank must be non-negative if using store") + assert world_size > 0, "world_size must be positive if using store" + assert rank >= 0, "rank must be non-negative if using store" elif init_method is None: init_method = "env://" @@ -1947,8 +1945,7 @@ def _new_process_group_helper( backend_config = BackendConfig(backend) # Set the default backend when single backend is passed in. if "," not in str(backend) and ":" not in str(backend): - if backend not in Backend.backend_type_map: - raise AssertionError(f"Unknown backend type {backend}") + assert backend in Backend.backend_type_map, f"Unknown backend type {backend}" if backend == Backend.UNDEFINED: # Currently when backend is UNDEFINED, only one backend will be initialized # we use nccl (if cuda is available) or gloo as default backend @@ -2018,10 +2015,9 @@ def _new_process_group_helper( if not is_nccl_available(): raise RuntimeError("Distributed package doesn't have NCCL built in") if backend_options is not None: - if not isinstance(backend_options, ProcessGroupNCCL.Options): - raise AssertionError( - "Expected backend_options argument to be of type ProcessGroupNCCL.Options" - ) + assert isinstance(backend_options, ProcessGroupNCCL.Options), ( + "Expected backend_options argument to be of type ProcessGroupNCCL.Options" + ) if backend_options._timeout != timeout: warnings.warn( "backend_options._timeout was specified, " @@ -2071,8 +2067,9 @@ def _new_process_group_helper( ) backend_type = ProcessGroup.BackendType.XCCL else: - if backend_str.upper() not in Backend._plugins: - raise AssertionError(f"Unknown c10d backend type {backend_str.upper()}") + assert backend_str.upper() in Backend._plugins, ( + f"Unknown c10d backend type {backend_str.upper()}" + ) backend_plugin = Backend._plugins[backend_str.upper()] creator_fn = backend_plugin.creator_fn @@ -2097,16 +2094,10 @@ def _new_process_group_helper( # Set sequence numbers for gloo and nccl backends. 
if backend_str == Backend.GLOO: - if not isinstance(backend_class, ProcessGroupGloo): - raise AssertionError( - f"Expected ProcessGroupGloo, got {type(backend_class)}" - ) + assert isinstance(backend_class, ProcessGroupGloo) backend_class._set_sequence_number_for_group() elif backend_str == Backend.NCCL: - if not isinstance(backend_class, ProcessGroupNCCL): - raise AssertionError( - f"Expected ProcessGroupNCCL, got {type(backend_class)}" - ) + assert isinstance(backend_class, ProcessGroupNCCL) backend_class._set_sequence_number_for_group() # If the type is a subclass of ProcessGroup then return this process group immediately @@ -2153,10 +2144,8 @@ def _new_process_group_helper( pg._register_backend(torch.device(device), backend_type, backend_class) # set group_name and group_dsec to backend - if group_name is None: - raise AssertionError("group_name must not be None") - if group_desc is None: - raise AssertionError("group_desc must not be None") + assert group_name is not None + assert group_desc is not None pg._set_group_name(group_name) pg._set_group_desc(group_desc) @@ -2202,8 +2191,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None): else: pg = group - if pg is None: - raise AssertionError("Process group cannot be None") + assert pg is not None if _world.pg_map.get(pg, None) is None: raise ValueError("Invalid process group specified") @@ -2293,8 +2281,7 @@ def _abort_process_group(group: Optional[ProcessGroup] = None): pg = group or GroupMember.WORLD - if pg is None: - raise AssertionError("Process group cannot be None") + assert pg is not None if _world.pg_map.get(pg, None) is None: raise ValueError("Invalid process group specified or has been destroyed.") @@ -3351,8 +3338,7 @@ def gather_object( if my_group_rank != group_dst: return - if object_gather_list is None: - raise AssertionError("Must provide object_gather_list on dst rank") + assert object_gather_list is not None, "Must provide object_gather_list on dst rank" # pyrefly: ignore # unbound-name for i, tensor in enumerate(output_tensors): tensor = tensor.type(torch.uint8) @@ -3608,8 +3594,9 @@ def recv_object_list( rank_objects = get_global_rank(group, group_src) else: rank_objects = recv(object_tensor, group=group, group_src=group_src) - if rank_sizes != rank_objects: - raise AssertionError("Mismatch in return ranks for object sizes and objects.") + assert rank_sizes == rank_objects, ( + "Mismatch in return ranks for object sizes and objects." + ) # Deserialize objects using their stored sizes. offset = 0 for i, obj_size in enumerate(object_sizes_tensor): @@ -5016,8 +5003,7 @@ def _create_process_group_wrapper( world_size: int, timeout: timedelta = default_pg_timeout, ): - if not _GLOO_AVAILABLE: - raise RuntimeError("ProcessGroupWrapper unsupported without GLOO backend.") + assert _GLOO_AVAILABLE, "ProcessGroupWrapper unsupported without GLOO backend." # (whc) this appears to be just for the gloo backend? if so, `default_pg_timeout` is appropriate... 
@@ -5219,10 +5205,9 @@ def split_group( split_pg.bound_device_id = device_id # type: ignore[union-attr] split_backend_class = split_pg._get_backend(torch.device("cuda")) split_backend_class._set_sequence_number_for_group() - if split_pg.group_name != group_name: - raise AssertionError( - f"group name should be set to {group_name} but got {split_pg.group_name}" - ) + assert split_pg.group_name == group_name, ( + f"group name should be set to {group_name} but got {split_pg.group_name}" + ) # update global state _world.pg_map[split_pg] = (backend, split_pg.get_group_store()) @@ -5354,10 +5339,9 @@ def _new_group_with_tag( if device_id is None: device_id = default_pg.bound_device_id elif default_pg.bound_device_id is not None: - if device_id != default_pg.bound_device_id: - raise AssertionError( - "Mismatched bound device between new pg and the default pg." - ) + assert device_id == default_pg.bound_device_id, ( + "Mismatched bound device between new pg and the default pg." + ) default_backend, default_store = _world.pg_map[default_pg] global_rank = default_pg.rank() global_world_size = default_pg.size() @@ -5671,25 +5655,22 @@ def _find_pg_by_ranks_and_tag(tag: str, ranks: list[int]) -> Optional[ProcessGro def _find_or_create_pg_by_ranks_and_tag( tag: str, ranks: list[int], stride: int ) -> ProcessGroup: - if len(ranks) % stride != 0: - raise ValueError( - f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})" - ) + assert len(ranks) % stride == 0, ( + f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})" + ) my_rank = get_rank() my_ranks = None if stride == len(ranks): my_ranks = ranks.copy() - if my_rank not in my_ranks: - raise RuntimeError("rankset doesn't include the current node") + assert my_rank in my_ranks, "rankset doesn't include the current node" else: for i in range(0, len(ranks), stride): rank_set = ranks[i : i + stride] if my_rank in rank_set: my_ranks = rank_set - if my_ranks is None: - raise RuntimeError("rankset doesn't include the current node") + assert my_ranks is not None, "rankset doesn't include the current node" my_ranks = sorted(my_ranks) diff --git a/torch/distributed/rendezvous.py b/torch/distributed/rendezvous.py index 602456ca6831..4d5e58778164 100644 --- a/torch/distributed/rendezvous.py +++ b/torch/distributed/rendezvous.py @@ -83,10 +83,9 @@ def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwa world_size = world_size_opt if rank != -1 or world_size != -1 or world_size_opt is None: query_dict = _query_to_dict(result.query) - if "rank" in query_dict or "world_size" in query_dict: - raise AssertionError( - f"The url: {url} has node-specific arguments(rank, world_size) already." - ) + assert "rank" not in query_dict and "world_size" not in query_dict, ( + f"The url: {url} has node-specific arguments(rank, world_size) already." 
+ ) if rank != -1: query_dict["rank"] = str(rank) if world_size != -1 or world_size_opt is None: @@ -228,8 +227,7 @@ def _tcp_rendezvous_handler( world_size = int(query_dict["world_size"]) use_libuv = _get_use_libuv_from_query_dict(query_dict) - if result.hostname is None: - raise AssertionError("hostname cannot be None") + assert result.hostname is not None store = _create_c10d_store( result.hostname, result.port, rank, world_size, timeout, use_libuv diff --git a/torch/distributed/run.py b/torch/distributed/run.py index 67947e44ea66..c312b9dc9a0d 100644 --- a/torch/distributed/run.py +++ b/torch/distributed/run.py @@ -792,12 +792,8 @@ def _get_logs_specs_class(logs_specs_name: Optional[str]) -> type[LogsSpecs]: def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str]]: # If ``args`` not passed, defaults to ``sys.argv[:1]`` min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes) - if not (0 < min_nodes <= max_nodes): - raise AssertionError( - f"min_nodes must be > 0 and <= max_nodes, got min_nodes={min_nodes}, max_nodes={max_nodes}" - ) - if args.max_restarts < 0: - raise AssertionError("max_restarts must be >= 0") + assert 0 < min_nodes <= max_nodes + assert args.max_restarts >= 0 if ( hasattr(args, "master_addr") @@ -837,8 +833,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str if args.local_ranks_filter: try: ranks = set(map(int, args.local_ranks_filter.split(","))) - if not ranks: - raise AssertionError("ranks set cannot be empty") + assert ranks except Exception as e: raise ValueError( "--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2" diff --git a/torch/distributed/utils.py b/torch/distributed/utils.py index 8b77867de459..1dc123b50dbe 100644 --- a/torch/distributed/utils.py +++ b/torch/distributed/utils.py @@ -69,8 +69,9 @@ def _unpack_kwargs( flat_args: tuple[Any, ...], kwarg_keys: tuple[str, ...] ) -> tuple[tuple[Any, ...], dict[str, Any]]: """See _pack_kwargs.""" - if len(kwarg_keys) > len(flat_args): - raise AssertionError(f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}") + assert len(kwarg_keys) <= len(flat_args), ( + f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}" + ) if len(kwarg_keys) == 0: return flat_args, {} args = flat_args[: -len(kwarg_keys)] @@ -126,8 +127,7 @@ def _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies): if isinstance(obj, PackedSequence): output.data.record_stream(current_stream) # type: ignore[arg-type] else: - if not isinstance(output, torch.Tensor): - raise AssertionError("output must be a torch.Tensor") + assert isinstance(output, torch.Tensor) output.record_stream(current_stream) # type: ignore[arg-type] return (output,)
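
The hunks above uniformly rewrite `if <cond>: raise AssertionError(msg)` guards as `assert <cond>, msg`. For context, here is a minimal standalone sketch of that pattern; it is illustrative only, and the function names below are hypothetical rather than taken from the patch:

    import torch

    # Style removed by this patch: the guard always runs, even under `python -O`.
    def _check_contiguous_explicit(t: torch.Tensor) -> None:
        if not t.is_contiguous():
            raise AssertionError("Tensor must be contiguous")

    # Style introduced by this patch: the whole statement is skipped when
    # __debug__ is False (i.e. under `python -O` / PYTHONOPTIMIZE), so it
    # should only guard internal invariants, not user-facing validation.
    def _check_contiguous_assert(t: torch.Tensor) -> None:
        assert t.is_contiguous(), "Tensor must be contiguous"

    if __name__ == "__main__":
        x = torch.arange(6).reshape(2, 3).t()  # transposed view -> not contiguous
        for check in (_check_contiguous_explicit, _check_contiguous_assert):
            try:
                check(x)
            except AssertionError as exc:
                print(f"{check.__name__}: AssertionError: {exc}")

The practical difference is that plain `assert` statements are compiled out when the interpreter runs with optimizations enabled, whereas the explicit raise always executes. Note also that a few conversions drop the original message (for example `assert self.is_contiguous()` in `all_gather_tensor` and `assert meta is None` in `__tensor_unflatten__`), so those failures will surface without descriptive text.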