Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "[distributed] Replace assert statements with AssertionError exceptions (#165216)"

This reverts commit 74db92b21868b7e9e77cc966e5d57a8246723cbd.

Reverted https://github.com/pytorch/pytorch/pull/165216 on behalf of https://github.com/clee2000 due to I think this broke distributed/test_pg_wrapper.py::ProcessGroupNCCLWrapperTest::test_debug_level_detail_no_gloo [GH job link](https://github.com/pytorch/pytorch/actions/runs/18492765290/job/52693842750) [HUD commit link](74db92b218), note to self: bad TD ([comment](https://github.com/pytorch/pytorch/pull/165216#issuecomment-3402838765))
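The behavioral difference at the heart of the revert: a bare `assert` is stripped when CPython runs with `-O` (`__debug__` is false), while an explicit `raise AssertionError(...)` always executes, so swapping one for the other is not a pure refactor. A minimal standalone sketch of the difference (illustrative only; not code from this diff):

```python
# optimize_demo.py - compare `python optimize_demo.py` with `python -O optimize_demo.py`
def guard_with_assert(x: int) -> None:
    # Compiled out entirely under -O: no check, no message formatting.
    assert x >= 0, f"expected non-negative, got {x}"

def guard_with_raise(x: int) -> None:
    # Executes regardless of interpreter flags.
    if x < 0:
        raise AssertionError(f"expected non-negative, got {x}")

for guard in (guard_with_assert, guard_with_raise):
    try:
        guard(-1)
        print(guard.__name__, "let the bad value through")  # reached only under -O
    except AssertionError as exc:
        print(guard.__name__, "raised:", exc)
```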
@@ -15,8 +15,7 @@ _module_state_mapping: weakref.WeakKeyDictionary[
 def _insert_module_state(module: nn.Module, state: _State) -> None:
     global _module_state_mapping
-    if module in _module_state_mapping:
-        raise AssertionError(f"Inserting {module} more than once.")
+    assert module not in _module_state_mapping, f"Inserting {module} more than once."
     _module_state_mapping[module] = weakref.ref(state)
 
 
@@ -71,8 +71,7 @@ def _gloo_factory(
 ) -> ProcessGroup:
     from torch.distributed import ProcessGroupGloo
 
-    if len(kwargs) != 0:
-        raise AssertionError("Gloo backend received unexpected kwargs")
+    assert len(kwargs) == 0, "Gloo backend received unexpected kwargs"
 
     backend_class = ProcessGroupGloo(store, rank, world_size, timeout)
     backend_class._set_sequence_number_for_group()
@@ -193,8 +193,7 @@ def all_gather_tensor(
     :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
     that information and perform collective algebraic optimization. Use other forms of input for that.
     """
-    if not self.is_contiguous():
-        raise AssertionError("Tensor must be contiguous for all_gather_tensor")
+    assert self.is_contiguous()
     group_name = _resolve_group_name(group, tag)
     group_size = c10d._get_group_size_by_name(group_name)
     tensor = torch.ops._c10d_functional.all_gather_into_tensor(
@@ -269,10 +268,9 @@ def reduce_scatter_tensor(
     group_name = _resolve_group_name(group, tag)
     group_size = c10d._get_group_size_by_name(group_name)
 
-    if self.size(scatter_dim) % group_size != 0:
-        raise AssertionError(
-            f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size})"
-        )
+    assert self.size(scatter_dim) % group_size == 0, (
+        f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size})"
+    )
     if scatter_dim != 0:
         tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
         self = torch.cat(tensor_list)
@@ -309,10 +307,9 @@ def reduce_scatter_tensor_autograd(
     group_name = _resolve_group_name(group, tag)
     group_size = c10d._get_group_size_by_name(group_name)
 
-    if self.size(scatter_dim) % group_size != 0:
-        raise AssertionError(
-            f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
-        )
+    assert self.size(scatter_dim) % group_size == 0, (
+        f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
+    )
     if scatter_dim != 0:
         tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
         self = torch.cat(tensor_list)
@@ -409,15 +406,11 @@ def reduce_scatter_tensor_coalesced(
     group_name = _resolve_group_name(group, tag)
     group_size = c10d._get_group_size_by_name(group_name)
 
-    if len(scatter_dim) != len(inputs):
-        raise AssertionError(
-            f"Length of scatter_dim ({len(scatter_dim)}) must equal length of inputs ({len(inputs)})"
-        )
+    assert len(scatter_dim) == len(inputs)
     for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)):
-        if tensor.size(dim) % group_size != 0:
-            raise AssertionError(
-                f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}"
-            )
+        assert tensor.size(dim) % group_size == 0, (
+            f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}"
+        )
         if dim != 0:
             tensor_list = torch.chunk(tensor, group_size, dim=dim)
             inputs[idx] = torch.cat(tensor_list)
@@ -435,8 +428,7 @@ def reduce_scatter_tensor_coalesced(
 # This is a bit unsafe: it checks if the first argument in the schema reports as a non-mutable alias.
 # Today, this maps 1:1 with "aten ops that are views".
 def _is_view_op(tgt):
-    if not isinstance(tgt, torch._ops.OpOverload):
-        raise AssertionError(f"Expected torch._ops.OpOverload, got {type(tgt)}")
+    assert isinstance(tgt, torch._ops.OpOverload)
     # Don't apply the view optimization to any `CompositeImplicitAutograd` ops.
     # See issue: https://github.com/pytorch/pytorch/issues/133421
     if torch._C._dispatch_has_kernel_for_dispatch_key(
@@ -473,25 +465,20 @@ def all_to_all_single(
     that information and perform collective algebraic optimization. Use other forms of input for that.
     """
     if output_split_sizes is not None:
-        if not all(
+        assert all(
             isinstance(size, (int, torch.SymInt)) for size in output_split_sizes
-        ):
-            raise AssertionError(
-                f"All output_split_sizes must be int or SymInt, got {output_split_sizes}"
-            )
+        ), output_split_sizes
     if input_split_sizes is not None:
-        if not all(isinstance(size, (int, torch.SymInt)) for size in input_split_sizes):
-            raise AssertionError(
-                f"All input_split_sizes must be int or SymInt, got {input_split_sizes}"
-            )
+        assert all(
+            isinstance(size, (int, torch.SymInt)) for size in input_split_sizes
+        ), input_split_sizes
     group_name = _resolve_group_name(group, tag)
     group_size = c10d._get_group_size_by_name(group_name)
     if output_split_sizes is None or input_split_sizes is None:
-        if not (output_split_sizes is None and input_split_sizes is None):
-            raise AssertionError(
-                "output_split_sizes and input_split_sizes must either be "
-                "specified together or both set to None"
-            )
+        assert output_split_sizes is None and input_split_sizes is None, (
+            "output_split_sizes and input_split_sizes must either be "
+            "specified together or both set to None"
+        )
         output_split_sizes = [self.shape[0] // group_size] * group_size
         input_split_sizes = output_split_sizes
     tensor = torch.ops._c10d_functional.all_to_all_single(  # type: ignore[attr-defined]
@@ -514,26 +501,21 @@ def all_to_all_single_autograd(
     Same as all_to_all_single but supports autograd.
     """
     if output_split_sizes is not None:
-        if not all(
+        assert all(
            isinstance(size, (int, torch.SymInt)) for size in output_split_sizes
-        ):
-            raise AssertionError(
-                f"All output_split_sizes must be int or SymInt, got {output_split_sizes}"
-            )
+        ), output_split_sizes
     if input_split_sizes is not None:
-        if not all(isinstance(size, (int, torch.SymInt)) for size in input_split_sizes):
-            raise AssertionError(
-                f"All input_split_sizes must be int or SymInt, got {input_split_sizes}"
-            )
+        assert all(
+            isinstance(size, (int, torch.SymInt)) for size in input_split_sizes
+        ), input_split_sizes
 
     group_name = _resolve_group_name(group, tag)
     group_size = c10d._get_group_size_by_name(group_name)
     if output_split_sizes is None or input_split_sizes is None:
-        if not (output_split_sizes is None and input_split_sizes is None):
-            raise AssertionError(
-                "output_split_sizes and input_split_sizes must either be "
-                "specified together or both set to None"
-            )
+        assert output_split_sizes is None and input_split_sizes is None, (
+            "output_split_sizes and input_split_sizes must either be "
+            "specified together or both set to None"
+        )
         output_split_sizes = [self.shape[0] // group_size] * group_size
         input_split_sizes = output_split_sizes
     tensor = torch.ops._c10d_functional_autograd.all_to_all_single(  # type: ignore[attr-defined]
@@ -616,10 +598,7 @@ class AsyncCollectiveTensor(torch.Tensor):
 
     @staticmethod
     def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
-        if meta is not None:
-            raise AssertionError(
-                "meta must be None for AsyncCollectiveTensor unflatten"
-            )
+        assert meta is None
         elem = inner_tensors["elem"]
         return AsyncCollectiveTensor(elem)
 
@@ -669,10 +648,7 @@ class AsyncCollectiveTensor(torch.Tensor):
 
         def wrap(e: torch.Tensor):
             # wait_tensor is idepotent and will do stream sync only once
-            if isinstance(e, AsyncCollectiveTensor):
-                raise AssertionError(
-                    "Cannot wrap an AsyncCollectiveTensor inside another AsyncCollectiveTensor"
-                )
+            assert not isinstance(e, AsyncCollectiveTensor)
             res = AsyncCollectiveTensor(e)
             return res
 
@@ -746,10 +722,9 @@ def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, list[int], int
         group_size = len(rankset)
         tag = tag or c10d._get_group_tag(group)
     elif isinstance(group, DeviceMesh):
-        if group.ndim != 1:
-            raise AssertionError(
-                "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
-            )
+        assert group.ndim == 1, (
+            "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
+        )
         # TODO: it should run collective in the whole mesh instead of dim 0
         pg = group.get_group()
         rankset = dist.get_process_group_ranks(pg)
@@ -788,10 +763,9 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
     elif isinstance(group, str):
         return group
     elif isinstance(group, DeviceMesh):
-        if group.ndim != 1:
-            raise AssertionError(
-                "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
-            )
+        assert group.ndim == 1, (
+            "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
+        )
         return group._dim_group_names[0]
     elif isinstance(group, tuple):
         if (
@@ -1081,14 +1055,12 @@ def all_gather_tensor_inplace(
     tag: str = "",
     gather_dim: int = 0,
 ):
-    if async_op:
-        raise AssertionError(
-            "Can't remap async version of inplace op to functional collective"
-        )
+    assert not async_op, (
+        "Can't remap async version of inplace op to functional collective"
+    )
 
     group = group or dist.group.WORLD
-    if group is None:
-        raise AssertionError("group cannot be None")
+    assert group is not None
 
     return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag))
 
@@ -1102,14 +1074,12 @@ def reduce_scatter_tensor_inplace(
     scatter_dim: int = 0,
     tag: str = "",
 ):
-    if async_op:
-        raise AssertionError(
-            "Can't remap async version of inplace op to functional collective"
-        )
+    assert not async_op, (
+        "Can't remap async version of inplace op to functional collective"
+    )
 
     group = group or dist.group.WORLD
-    if group is None:
-        raise AssertionError("group cannot be None")
+    assert group is not None
 
     return output.copy_(reduce_scatter_tensor(input, op, scatter_dim, group, tag))
 
@@ -1133,14 +1103,12 @@ def all_reduce_inplace(
     async_op: bool = False,
     tag: str = "",
 ):
-    if async_op:
-        raise AssertionError(
-            "Can't remap async version of inplace op to functional collective"
-        )
+    assert not async_op, (
+        "Can't remap async version of inplace op to functional collective"
+    )
 
     group = group or dist.group.WORLD
-    if group is None:
-        raise AssertionError("group cannot be None")
+    assert group is not None
 
     return tensor.copy_(all_reduce(tensor, op, group, tag))
 
@@ -1154,14 +1122,12 @@ def all_to_all_inplace(
     async_op=False,
     tag: str = "",
 ):
-    if async_op:
-        raise AssertionError(
-            "Can't remap async version of inplace op to functional collective"
-        )
+    assert not async_op, (
+        "Can't remap async version of inplace op to functional collective"
+    )
 
     group = group or dist.group.WORLD
-    if group is None:
-        raise AssertionError("group cannot be None")
+    assert group is not None
 
     return output.copy_(
         all_to_all_single(
@@ -1181,16 +1147,15 @@ def all_gather_inplace(
     async_op=False,
     tag: str = "",
 ):
-    if async_op:
-        raise AssertionError(
-            "Can't remap async version of inplace op to functional collective"
-        )
-    if tensor.dim() != 0 and not all(t.size(0) == tensor.size(0) for t in tensor_list):
-        raise AssertionError("Remapping variable size all_gather is not yet supported")
+    assert not async_op, (
+        "Can't remap async version of inplace op to functional collective"
+    )
+    assert tensor.dim() == 0 or all(t.size(0) == tensor.size(0) for t in tensor_list), (
+        "Remapping variable size all_gather is not yet supported"
+    )
 
     group = group or dist.group.WORLD
-    if group is None:
-        raise AssertionError("group cannot be None")
+    assert group is not None
 
     output = all_gather_tensor(tensor, 0, group, tag)
 
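Several of the hunks above restore the same `all_to_all_single` contract: `output_split_sizes` and `input_split_sizes` must be passed together or both left as `None`, in which case an even split along dim 0 is derived from the group size. A standalone sketch of that bookkeeping (`normalize_split_sizes` is a hypothetical helper, not part of the diff):

```python
from typing import Optional, Sequence

def normalize_split_sizes(
    dim0: int,
    group_size: int,
    output_split_sizes: Optional[Sequence[int]],
    input_split_sizes: Optional[Sequence[int]],
) -> tuple[list[int], list[int]]:
    # The contract guarded above: both given together, or both omitted.
    if (output_split_sizes is None) != (input_split_sizes is None):
        raise ValueError(
            "output_split_sizes and input_split_sizes must either be "
            "specified together or both set to None"
        )
    if output_split_sizes is None:
        # Even split: every rank sends and receives dim0 // group_size rows.
        even = [dim0 // group_size] * group_size
        return even, even
    return list(output_split_sizes), list(input_split_sizes)

# e.g. an (8, 16) tensor across 4 ranks -> 2 rows to/from each rank
print(normalize_split_sizes(8, 4, None, None))  # ([2, 2, 2, 2], [2, 2, 2, 2])
```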
@@ -97,11 +97,10 @@ def _all_to_all_single(
     group_size: int,
 ):
     if output_split_sizes is None or input_split_sizes is None:
-        if not (output_split_sizes is None and input_split_sizes is None):
-            raise AssertionError(
-                "output_split_sizes and input_split_sizes must either be "
-                "specified together or both set to None"
-            )
+        assert output_split_sizes is None and input_split_sizes is None, (
+            "output_split_sizes and input_split_sizes must either be "
+            "specified together or both set to None"
+        )
         output_split_sizes = [input.shape[0] // group_size] * group_size
         input_split_sizes = output_split_sizes
 
@@ -184,18 +184,12 @@ def _iterate_state_dict(
 
     if companion_obj is not None:
         if isinstance(companion_obj, DTensor):
-            if not isinstance(ret, DTensor):
-                raise AssertionError(
-                    "ret must be a DTensor when companion_obj is a DTensor"
-                )
+            assert isinstance(ret, DTensor)
             companion_obj._local_tensor.copy_(
                 ret._local_tensor, non_blocking=non_blocking
             )
         elif isinstance(companion_obj, ShardedTensor):
-            if not isinstance(ret, ShardedTensor):
-                raise AssertionError(
-                    "ret must be a ShardedTensor when companion_obj is a ShardedTensor"
-                )
+            assert isinstance(ret, ShardedTensor)
             for idx, shard in enumerate(companion_obj.local_shards()):
                 shard.tensor.copy_(
                     ret.local_shards()[idx].tensor, non_blocking=non_blocking
@@ -554,8 +548,7 @@ def _broadcast_tensors(
     for key in keys:
         if dist.get_rank() == 0:
             full_state = full_state_dict[key]
-            if not isinstance(full_state, torch.Tensor):
-                raise AssertionError("full_state must be a torch.Tensor")
+            assert isinstance(full_state, torch.Tensor)
             full_tensor = full_state.detach().to(pg_device)
         else:
             tensor_info = full_state_dict[key]
@@ -714,8 +707,7 @@ def _distribute_state_dict(
         elif value.dim() == 0:
             local_state_dict[key] = value.cpu()
         else:
-            if not isinstance(value, torch.Tensor):
-                raise AssertionError("value must be a torch.Tensor")
+            assert isinstance(value, torch.Tensor)
             local_state = local_state_dict.get(key, None)
             if local_state is None:
                 continue
@@ -104,10 +104,7 @@ def broadcast(
     if pg is not None:
         broadcast_list = [sync_obj]
         dist.broadcast_object_list(broadcast_list, src=rank, group=pg)
-        if len(broadcast_list) != 1:
-            raise AssertionError(
-                f"Expected broadcast_list to have exactly 1 element, got {len(broadcast_list)}"
-            )
+        assert len(broadcast_list) == 1
         sync_obj = broadcast_list[0]
 
     # failure in any rank will trigger a throw in every rank.
@@ -243,10 +240,8 @@ def all_gather_object_enforce_type(
 
 def _summarize_ranks(ranks: Iterable[int]) -> str:
     ranks = sorted(ranks)
-    if min(ranks) < 0:
-        raise AssertionError("ranks should all be positive")
-    if len(set(ranks)) != len(ranks):
-        raise AssertionError("ranks should not contain duplicates")
+    assert min(ranks) >= 0, "ranks should all be positive"
+    assert len(set(ranks)) == len(ranks), "ranks should not contain duplicates"
     curr: Optional[Union[int, range]] = None
     ranges = []
     while ranks:
@@ -260,8 +255,7 @@ def _summarize_ranks(ranks: Iterable[int]) -> str:
                 step = x - curr
                 curr = range(curr, x + step, step)
         else:
-            if not isinstance(curr, range):
-                raise AssertionError("curr must be an instance of range")
+            assert isinstance(curr, range)
             if x == curr.stop:
                 curr = range(curr.start, curr.stop + curr.step, curr.step)
             else:
@@ -213,16 +213,14 @@ else:
                 if _layout
                 else _MeshLayout(self.mesh.size(), self.mesh.stride())
             )
-            if not self._layout.check_non_overlap():
-                raise AssertionError(
-                    "Please use a non-overlapping layout when creating a DeviceMesh."
-                )
+            assert self._layout.check_non_overlap(), (
+                "Please use a non-overlapping layout when creating a DeviceMesh."
+            )
             # Because we still need to support slicing of flattened dim from root mesh, so we don't check stride here.
-            if self._layout.numel() != self.mesh.numel():
-                raise AssertionError(
-                    "Please use a valid layout when creating a DeviceMesh."
-                    f"The layout {self._layout} is not consistent with the mesh size {self.mesh.size()}."
-                )
+            assert self._layout.numel() == self.mesh.numel(), (
+                "Please use a valid layout when creating a DeviceMesh."
+                f"The layout {self._layout} is not consistent with the mesh size {self.mesh.size()}."
+            )
 
             # private field to pre-generate DeviceMesh's hash
             self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
@@ -247,10 +245,7 @@ else:
 
             # calculate the coordinates of the current global rank on the mesh
             rank_coords = (self.mesh == _rank).nonzero()
-            if rank_coords.size(0) not in (0, 1):
-                raise AssertionError(
-                    f"rank_coords.size(0) must be 0 or 1, got {rank_coords.size(0)}"
-                )
+            assert rank_coords.size(0) in (0, 1)
             self._coordinate_on_dim: Optional[list[int]] = (
                 rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
             )
@@ -595,10 +590,7 @@ else:
                 if isinstance(mesh_dim, str)
                 else mesh_dim
             )
-            if not isinstance(mesh_dim, int):
-                raise AssertionError(
-                    f"mesh_dim must be an int, got {type(mesh_dim)}"
-                )
+            assert isinstance(mesh_dim, int)
             return not_none(_resolve_process_group(self._dim_group_names[mesh_dim]))
 
         def get_all_groups(self) -> list[ProcessGroup]:
@@ -717,8 +709,9 @@ else:
             root_mesh = self._get_root_mesh()
             child_mesh_dim_names = self._mesh_dim_names
             if root_mesh and child_mesh_dim_names:
-                if len(child_mesh_dim_names) != 1:
-                    raise AssertionError("The submesh can only be a 1D mesh.")
+                assert len(child_mesh_dim_names) == 1, (
+                    "The submesh can only be a 1D mesh."
+                )
                 child_mesh_dim_name = child_mesh_dim_names[0]
                 return root_mesh._get_mesh_dim_by_name(child_mesh_dim_name)
             return None
@@ -1055,10 +1048,9 @@ else:
                 mesh_dim = 0
 
             mesh_dim_group = not_none(self.get_group(mesh_dim))
-            if not isinstance(mesh_dim_group, ProcessGroup):
-                raise AssertionError(
-                    "We expect ProcessGroup before calling `get_rank`!"
-                )
+            assert isinstance(mesh_dim_group, ProcessGroup), (
+                "We expect ProcessGroup before calling `get_rank`!"
+            )
             return not_none(get_rank(mesh_dim_group))
 
         def get_coordinate(self) -> Optional[list[int]]:
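The DeviceMesh hunks above restore `assert`s on construction invariants (a non-overlapping layout whose size matches the mesh, at most one rank coordinate) and on the types seen by `get_group`/`get_rank`. A small sketch of the public API those checks sit behind, assuming a 4-process gloo launch (`mesh_demo.py` is a hypothetical file name):

```python
# mesh_demo.py - run with: torchrun --nproc_per_node=4 mesh_demo.py
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# A 2x2 mesh over 4 ranks; dim names let us slice out 1D submeshes.
mesh_2d = init_device_mesh("cpu", (2, 2), mesh_dim_names=("dp", "tp"))
tp_mesh = mesh_2d["tp"]  # 1D submesh, the shape collectives expect as a group
print(dist.get_rank(), tp_mesh.get_coordinate())
dist.destroy_process_group()
```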
@@ -1526,8 +1526,7 @@ def _set_pg_timeout(timeout: timedelta, group: Optional[ProcessGroup] = None) ->
         group = _get_default_group()
     if _rank_not_in_group(group):
         raise ValueError("Invalid process group specified")
-    if not isinstance(group, ProcessGroup):
-        raise AssertionError(f"Expected ProcessGroup, got {type(group)}")
+    assert isinstance(group, ProcessGroup)
     devices = group._device_types
     backends = set()
     if torch.device("cpu") in devices and is_gloo_available():
@@ -1666,14 +1665,13 @@ def init_process_group(
     if "torch._dynamo" in sys.modules:
         torch._dynamo.trace_rules.clear_lru_cache()
 
-    if not ((store is None) or (init_method is None)):
-        raise AssertionError("Cannot specify both init_method and store.")
+    assert (store is None) or (init_method is None), (
+        "Cannot specify both init_method and store."
+    )
 
     if store is not None:
-        if not world_size > 0:
-            raise AssertionError("world_size must be positive if using store")
-        if not rank >= 0:
-            raise AssertionError("rank must be non-negative if using store")
+        assert world_size > 0, "world_size must be positive if using store"
+        assert rank >= 0, "rank must be non-negative if using store"
     elif init_method is None:
         init_method = "env://"
 
@@ -1947,8 +1945,7 @@ def _new_process_group_helper(
     backend_config = BackendConfig(backend)
     # Set the default backend when single backend is passed in.
     if "," not in str(backend) and ":" not in str(backend):
-        if backend not in Backend.backend_type_map:
-            raise AssertionError(f"Unknown backend type {backend}")
+        assert backend in Backend.backend_type_map, f"Unknown backend type {backend}"
         if backend == Backend.UNDEFINED:
             # Currently when backend is UNDEFINED, only one backend will be initialized
             # we use nccl (if cuda is available) or gloo as default backend
@@ -2018,10 +2015,9 @@ def _new_process_group_helper(
             if not is_nccl_available():
                 raise RuntimeError("Distributed package doesn't have NCCL built in")
             if backend_options is not None:
-                if not isinstance(backend_options, ProcessGroupNCCL.Options):
-                    raise AssertionError(
-                        "Expected backend_options argument to be of type ProcessGroupNCCL.Options"
-                    )
+                assert isinstance(backend_options, ProcessGroupNCCL.Options), (
+                    "Expected backend_options argument to be of type ProcessGroupNCCL.Options"
+                )
                 if backend_options._timeout != timeout:
                     warnings.warn(
                         "backend_options._timeout was specified, "
@@ -2071,8 +2067,9 @@ def _new_process_group_helper(
                 )
                 backend_type = ProcessGroup.BackendType.XCCL
             else:
-                if backend_str.upper() not in Backend._plugins:
-                    raise AssertionError(f"Unknown c10d backend type {backend_str.upper()}")
+                assert backend_str.upper() in Backend._plugins, (
+                    f"Unknown c10d backend type {backend_str.upper()}"
+                )
 
                 backend_plugin = Backend._plugins[backend_str.upper()]
                 creator_fn = backend_plugin.creator_fn
@@ -2097,16 +2094,10 @@ def _new_process_group_helper(
 
             # Set sequence numbers for gloo and nccl backends.
             if backend_str == Backend.GLOO:
-                if not isinstance(backend_class, ProcessGroupGloo):
-                    raise AssertionError(
-                        f"Expected ProcessGroupGloo, got {type(backend_class)}"
-                    )
+                assert isinstance(backend_class, ProcessGroupGloo)
                 backend_class._set_sequence_number_for_group()
             elif backend_str == Backend.NCCL:
-                if not isinstance(backend_class, ProcessGroupNCCL):
-                    raise AssertionError(
-                        f"Expected ProcessGroupNCCL, got {type(backend_class)}"
-                    )
+                assert isinstance(backend_class, ProcessGroupNCCL)
                 backend_class._set_sequence_number_for_group()
 
             # If the type is a subclass of ProcessGroup then return this process group immediately
@@ -2153,10 +2144,8 @@ def _new_process_group_helper(
         pg._register_backend(torch.device(device), backend_type, backend_class)
 
     # set group_name and group_dsec to backend
-    if group_name is None:
-        raise AssertionError("group_name must not be None")
-    if group_desc is None:
-        raise AssertionError("group_desc must not be None")
+    assert group_name is not None
+    assert group_desc is not None
     pg._set_group_name(group_name)
     pg._set_group_desc(group_desc)
 
@@ -2202,8 +2191,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
     else:
         pg = group
 
-    if pg is None:
-        raise AssertionError("Process group cannot be None")
+    assert pg is not None
     if _world.pg_map.get(pg, None) is None:
         raise ValueError("Invalid process group specified")
 
@@ -2293,8 +2281,7 @@ def _abort_process_group(group: Optional[ProcessGroup] = None):
 
     pg = group or GroupMember.WORLD
 
-    if pg is None:
-        raise AssertionError("Process group cannot be None")
+    assert pg is not None
     if _world.pg_map.get(pg, None) is None:
         raise ValueError("Invalid process group specified or has been destroyed.")
 
@@ -3351,8 +3338,7 @@ def gather_object(
     if my_group_rank != group_dst:
         return
 
-    if object_gather_list is None:
-        raise AssertionError("Must provide object_gather_list on dst rank")
+    assert object_gather_list is not None, "Must provide object_gather_list on dst rank"
     # pyrefly: ignore  # unbound-name
     for i, tensor in enumerate(output_tensors):
         tensor = tensor.type(torch.uint8)
@@ -3608,8 +3594,9 @@ def recv_object_list(
         rank_objects = get_global_rank(group, group_src)
     else:
         rank_objects = recv(object_tensor, group=group, group_src=group_src)
-    if rank_sizes != rank_objects:
-        raise AssertionError("Mismatch in return ranks for object sizes and objects.")
+    assert rank_sizes == rank_objects, (
+        "Mismatch in return ranks for object sizes and objects."
+    )
     # Deserialize objects using their stored sizes.
     offset = 0
     for i, obj_size in enumerate(object_sizes_tensor):
@@ -5016,8 +5003,7 @@ def _create_process_group_wrapper(
     world_size: int,
     timeout: timedelta = default_pg_timeout,
 ):
-    if not _GLOO_AVAILABLE:
-        raise RuntimeError("ProcessGroupWrapper unsupported without GLOO backend.")
+    assert _GLOO_AVAILABLE, "ProcessGroupWrapper unsupported without GLOO backend."
 
     # (whc) this appears to be just for the gloo backend? if so, `default_pg_timeout` is appropriate...
 
@@ -5219,10 +5205,9 @@ def split_group(
         split_pg.bound_device_id = device_id  # type: ignore[union-attr]
         split_backend_class = split_pg._get_backend(torch.device("cuda"))
         split_backend_class._set_sequence_number_for_group()
-        if split_pg.group_name != group_name:
-            raise AssertionError(
-                f"group name should be set to {group_name} but got {split_pg.group_name}"
-            )
+        assert split_pg.group_name == group_name, (
+            f"group name should be set to {group_name} but got {split_pg.group_name}"
+        )
 
         # update global state
         _world.pg_map[split_pg] = (backend, split_pg.get_group_store())
@@ -5354,10 +5339,9 @@ def _new_group_with_tag(
     if device_id is None:
        device_id = default_pg.bound_device_id
     elif default_pg.bound_device_id is not None:
-        if device_id != default_pg.bound_device_id:
-            raise AssertionError(
-                "Mismatched bound device between new pg and the default pg."
-            )
+        assert device_id == default_pg.bound_device_id, (
+            "Mismatched bound device between new pg and the default pg."
+        )
     default_backend, default_store = _world.pg_map[default_pg]
     global_rank = default_pg.rank()
     global_world_size = default_pg.size()
@@ -5671,25 +5655,22 @@ def _find_pg_by_ranks_and_tag(tag: str, ranks: list[int]) -> Optional[ProcessGro
 def _find_or_create_pg_by_ranks_and_tag(
     tag: str, ranks: list[int], stride: int
 ) -> ProcessGroup:
-    if len(ranks) % stride != 0:
-        raise ValueError(
-            f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})"
-        )
+    assert len(ranks) % stride == 0, (
+        f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})"
+    )
 
     my_rank = get_rank()
     my_ranks = None
 
     if stride == len(ranks):
         my_ranks = ranks.copy()
-        if my_rank not in my_ranks:
-            raise RuntimeError("rankset doesn't include the current node")
+        assert my_rank in my_ranks, "rankset doesn't include the current node"
    else:
        for i in range(0, len(ranks), stride):
            rank_set = ranks[i : i + stride]
            if my_rank in rank_set:
                my_ranks = rank_set
-    if my_ranks is None:
-        raise RuntimeError("rankset doesn't include the current node")
+    assert my_ranks is not None, "rankset doesn't include the current node"
 
     my_ranks = sorted(my_ranks)
 
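The `init_process_group` hunk above restores the argument contract as plain `assert`s: `store` and `init_method` are mutually exclusive, and an explicit store requires a positive `world_size` and a non-negative `rank`. A single-process sketch of the store path (gloo backend; the localhost port is an arbitrary choice):

```python
import torch.distributed as dist

# An explicit store replaces init_method; rank/world_size must then be given.
store = dist.TCPStore("127.0.0.1", 29500, world_size=1, is_master=True)
dist.init_process_group(backend="gloo", store=store, rank=0, world_size=1)
# Passing init_method= as well would trip the restored assert above.
dist.destroy_process_group()
```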
@@ -83,10 +83,9 @@ def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwa
         world_size = world_size_opt
     if rank != -1 or world_size != -1 or world_size_opt is None:
         query_dict = _query_to_dict(result.query)
-        if "rank" in query_dict or "world_size" in query_dict:
-            raise AssertionError(
-                f"The url: {url} has node-specific arguments(rank, world_size) already."
-            )
+        assert "rank" not in query_dict and "world_size" not in query_dict, (
+            f"The url: {url} has node-specific arguments(rank, world_size) already."
+        )
         if rank != -1:
             query_dict["rank"] = str(rank)
         if world_size != -1 or world_size_opt is None:
@@ -228,8 +227,7 @@ def _tcp_rendezvous_handler(
     world_size = int(query_dict["world_size"])
     use_libuv = _get_use_libuv_from_query_dict(query_dict)
 
-    if result.hostname is None:
-        raise AssertionError("hostname cannot be None")
+    assert result.hostname is not None
 
     store = _create_c10d_store(
         result.hostname, result.port, rank, world_size, timeout, use_libuv
@@ -792,12 +792,8 @@ def _get_logs_specs_class(logs_specs_name: Optional[str]) -> type[LogsSpecs]:
 def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str]]:
     # If ``args`` not passed, defaults to ``sys.argv[:1]``
     min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
-    if not (0 < min_nodes <= max_nodes):
-        raise AssertionError(
-            f"min_nodes must be > 0 and <= max_nodes, got min_nodes={min_nodes}, max_nodes={max_nodes}"
-        )
-    if args.max_restarts < 0:
-        raise AssertionError("max_restarts must be >= 0")
+    assert 0 < min_nodes <= max_nodes
+    assert args.max_restarts >= 0
 
     if (
         hasattr(args, "master_addr")
@@ -837,8 +833,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str
     if args.local_ranks_filter:
         try:
             ranks = set(map(int, args.local_ranks_filter.split(",")))
-            if not ranks:
-                raise AssertionError("ranks set cannot be empty")
+            assert ranks
         except Exception as e:
             raise ValueError(
                 "--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2"
@@ -69,8 +69,9 @@ def _unpack_kwargs(
     flat_args: tuple[Any, ...], kwarg_keys: tuple[str, ...]
 ) -> tuple[tuple[Any, ...], dict[str, Any]]:
     """See _pack_kwargs."""
-    if len(kwarg_keys) > len(flat_args):
-        raise AssertionError(f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}")
+    assert len(kwarg_keys) <= len(flat_args), (
+        f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}"
+    )
     if len(kwarg_keys) == 0:
         return flat_args, {}
     args = flat_args[: -len(kwarg_keys)]
@@ -126,8 +127,7 @@ def _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies):
             if isinstance(obj, PackedSequence):
                 output.data.record_stream(current_stream)  # type: ignore[arg-type]
             else:
-                if not isinstance(output, torch.Tensor):
-                    raise AssertionError("output must be a torch.Tensor")
+                assert isinstance(output, torch.Tensor)
                 output.record_stream(current_stream)  # type: ignore[arg-type]
             return (output,)
 