Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[distributed] Replace 164 assert statements in fsdp directory (#165235)
Replace assert statements with explicit if/raise patterns across 20 files:

- _optim_utils.py (38 asserts)
- _flat_param.py (25 asserts)
- _fully_shard/_fsdp_param.py (23 asserts)
- sharded_grad_scaler.py (12 asserts)
- fully_sharded_data_parallel.py (11 asserts)
- wrap.py (10 asserts)
- _state_dict_utils.py (9 asserts)
- _fully_shard/_fsdp_param_group.py (8 asserts)
- _runtime_utils.py (6 asserts)
- _init_utils.py (6 asserts)
- 10 additional files (16 asserts)

This prevents the checks from being disabled with the Python -O flag.

Partially fixes #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165235
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: 6918f17114
Commit: c4565c3b94
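The rationale is easiest to see with a small standalone sketch (illustrative only; the helper below is hypothetical and not part of the FSDP sources). Under `python -O`, `__debug__` is False and `assert` statements are stripped from the bytecode, so the check silently disappears; an explicit if/raise is ordinary control flow and always runs:

    # Minimal sketch of the pattern applied throughout this PR
    # (hypothetical helper, not code from the FSDP sources).

    # Before: removed entirely under `python -O`, so an invalid value passes through.
    def set_world_size_with_assert(world_size):
        assert world_size is not None, "Expected world_size to not be None"
        return world_size

    # After: the check survives -O because it is ordinary control flow.
    def set_world_size_with_raise(world_size):
        if world_size is None:
            raise AssertionError("Expected world_size to not be None")
        return world_size

Running both with `python -O` shows the difference: the assert-based version accepts None without complaint, while the if/raise version still raises. Raising AssertionError (rather than, say, ValueError) keeps the observable exception type the same as what the original asserts produced.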
@@ -203,9 +203,10 @@ def _module_handle(state: _FSDPState, module: nn.Module) -> Optional["FlatParamH
# handles, meaning no entry in `_fully_sharded_module_to_handles`
if state._handle is None:
return None
assert module in state._fully_sharded_module_to_handle, (
f"Expects a fully sharded module but got {module} on rank {state.rank}"
)
if module not in state._fully_sharded_module_to_handle:
raise AssertionError(
f"Expects a fully sharded module but got {module} on rank {state.rank}"
)
return state._fully_sharded_module_to_handle[module]
else:
# NOTE: This assumes `module` is a `FullyShardedDataParallel` instance.

@@ -258,9 +259,10 @@ def _named_parameters_with_duplicates(
This API is required as some modules overwrite `named_parameters()` but do not support
`remove_duplicate`.
"""
assert "remove_duplicate" not in kwargs, (
"_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument."
)
if "remove_duplicate" in kwargs:
raise AssertionError(
"_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument."
)
kwargs["remove_duplicate"] = False
try:
ret = list(module.named_parameters(**kwargs))

@@ -39,11 +39,12 @@ class SimpleProfiler:
@classmethod
@contextmanager
def profile(cls, profile_type: str) -> Iterator[None]:
assert profile_type not in cls.profiling, (
f"{profile_type} is already being profiled. "
"SimpleProfiler does not support profiling multiple instances at "
"the same time. "
)
if profile_type in cls.profiling:
raise AssertionError(
f"{profile_type} is already being profiled. "
"SimpleProfiler does not support profiling multiple instances at "
"the same time. "
)

cls.profiling.add(profile_type)
begin = time.monotonic()

@@ -129,7 +130,8 @@ def _get_sharded_module_tree_with_module_name_to_fqns(

if handle:
param = handle.flat_param
assert isinstance(param, flat_param_file.FlatParameter)
if not isinstance(param, flat_param_file.FlatParameter):
raise AssertionError(f"Expected FlatParameter, got {type(param)}")
global_fqns = [
clean_tensor_name(prefix + name) for name in param._fqns
] # prefixed from the top level `model` (i.e. including `prefix`)

@@ -214,7 +214,8 @@ class _ExecOrderData:
# parameters
# TODO (awgu): Since every module has at most one handle in the
# current implementation, this should never raise the error.
assert self.world_size is not None # mypy
if self.world_size is None:
raise AssertionError("Expected world_size to not be None")
if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
# TODO(voz): Don't graph break on this - dynamo hates the n1 != n2
# tensor comparison control flow.

@@ -360,7 +360,8 @@ class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta):
_is_padding_mask: list[bool]

def __new__(cls, data=None, requires_grad=True):
assert cls is FlatParameter, "subclasses FlatParameter not supported"
if cls is not FlatParameter:
raise AssertionError("subclasses FlatParameter not supported")
r = nn.Parameter.__new__(nn.Parameter, data, requires_grad) # type: ignore[call-arg]
r._is_flat_param = True # type: ignore[attr-defined]
return r
@ -398,11 +399,26 @@ class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta):
|
||||
Args:
|
||||
See the Attributes in the class docstring.
|
||||
"""
|
||||
assert len(param_infos) == len(shapes)
|
||||
assert len(param_infos) == len(strides)
|
||||
assert len(param_infos) == len(contiguities)
|
||||
assert len(param_infos) == len(fqns)
|
||||
assert len(param_infos) == len(param_extensions)
|
||||
if len(param_infos) != len(shapes):
|
||||
raise AssertionError(
|
||||
f"Expected param_infos length {len(param_infos)} to match shapes length {len(shapes)}"
|
||||
)
|
||||
if len(param_infos) != len(strides):
|
||||
raise AssertionError(
|
||||
f"Expected param_infos length {len(param_infos)} to match strides length {len(strides)}"
|
||||
)
|
||||
if len(param_infos) != len(contiguities):
|
||||
raise AssertionError(
|
||||
f"Expected param_infos length {len(param_infos)} to match contiguities length {len(contiguities)}"
|
||||
)
|
||||
if len(param_infos) != len(fqns):
|
||||
raise AssertionError(
|
||||
f"Expected param_infos length {len(param_infos)} to match fqns length {len(fqns)}"
|
||||
)
|
||||
if len(param_infos) != len(param_extensions):
|
||||
raise AssertionError(
|
||||
f"Expected param_infos length {len(param_infos)} to match param_extensions length {len(param_extensions)}"
|
||||
)
|
||||
self._num_params = len(param_infos)
|
||||
self._param_infos = param_infos
|
||||
self._shapes = shapes
|
||||
@ -418,22 +434,32 @@ class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta):
|
||||
numels_without_padding.append(numel)
|
||||
self._numels = tuple(numels_without_padding)
|
||||
self._numels_with_padding = tuple(numels)
|
||||
assert len(self._numels) == self._num_params
|
||||
if len(self._numels) != self._num_params:
|
||||
raise AssertionError(
|
||||
f"Expected _numels length {len(self._numels)} to equal _num_params {self._num_params}"
|
||||
)
|
||||
|
||||
self._shared_param_infos = tuple(shared_param_infos)
|
||||
self._modules = {pi.module for pi in self._param_infos}.union(
|
||||
{spi.module for spi in self._shared_param_infos}
|
||||
)
|
||||
assert (params is None) == (shared_params is None)
|
||||
if params is not None:
|
||||
assert shared_params is not None and len(shared_params) == len(
|
||||
shared_param_infos
|
||||
if (params is None) != (shared_params is None):
|
||||
raise AssertionError(
|
||||
"Expected params and shared_params to both be None or both be not None"
|
||||
)
|
||||
if params is not None:
|
||||
if shared_params is None or len(shared_params) != len(shared_param_infos):
|
||||
raise AssertionError(
|
||||
f"Expected shared_params to be not None and have length {len(shared_param_infos)}, got {shared_params}"
|
||||
)
|
||||
self._params = []
|
||||
for param, is_padding in zip(params, is_padding_mask):
|
||||
if not is_padding:
|
||||
self._params.append(param)
|
||||
self._shared_params = shared_params
|
||||
if shared_params is not None:
|
||||
self._shared_params = shared_params
|
||||
else:
|
||||
self._shared_params = []
|
||||
# Mark the original parameters to avoid flattening them into
|
||||
# another `FlatParameter` during recursive construction
|
||||
for param in chain(self._params, self._shared_params):
|
||||
@ -579,7 +605,8 @@ class FlatParamHandle:
|
||||
# before `_init_flat_param()`, which performs the actual validation
|
||||
self._orig_param_dtype = params[0].dtype
|
||||
self._init_param_reduce_dtypes(mp_param_dtype, mp_reduce_dtype)
|
||||
assert self._fwd_bwd_param_dtype is not None # mypy
|
||||
if self._fwd_bwd_param_dtype is None:
|
||||
raise AssertionError("Expected _fwd_bwd_param_dtype to be not None") # mypy
|
||||
self._aligned_numel = (
|
||||
_get_aligned_numel(unsharded_dtype=self._fwd_bwd_param_dtype)
|
||||
if align_addresses
|
||||
@ -807,7 +834,8 @@ class FlatParamHandle:
|
||||
dtype = tensor.dtype
|
||||
flat_param_requires_grad = flat_param_requires_grad or tensor.requires_grad
|
||||
device = tensor.device
|
||||
assert flat_param_requires_grad is not None, "Requires non-empty `tensors` list"
|
||||
if flat_param_requires_grad is None:
|
||||
raise AssertionError("Requires non-empty `tensors` list")
|
||||
return dtype, flat_param_requires_grad, device
|
||||
|
||||
def flatten_tensors(
|
||||
@ -908,8 +936,10 @@ class FlatParamHandle:
|
||||
else:
|
||||
self._fwd_bwd_param_dtype = mp_param_dtype or self._orig_param_dtype
|
||||
self._reduce_dtype = mp_reduce_dtype or self._orig_param_dtype
|
||||
assert self._fwd_bwd_param_dtype is not None
|
||||
assert self._reduce_dtype is not None
|
||||
if self._fwd_bwd_param_dtype is None:
|
||||
raise AssertionError("Expected _fwd_bwd_param_dtype to be not None")
|
||||
if self._reduce_dtype is None:
|
||||
raise AssertionError("Expected _reduce_dtype to be not None")
|
||||
|
||||
###################################
|
||||
# SHARD INITIALIZATION & METADATA #
|
||||
@ -985,9 +1015,10 @@ class FlatParamHandle:
|
||||
shard_param_infos = self._get_shard_metadata(
|
||||
unsharded_start_idx, unsharded_end_idx
|
||||
)
|
||||
assert len(shard_param_infos) == flat_param._num_params, (
|
||||
f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}"
|
||||
)
|
||||
if len(shard_param_infos) != flat_param._num_params:
|
||||
raise AssertionError(
|
||||
f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}"
|
||||
)
|
||||
flat_param._shard_param_infos = shard_param_infos # type: ignore[attr-defined]
|
||||
flat_param._shard_numel_padded = numel_padded # type: ignore[attr-defined]
|
||||
|
||||
@ -1003,9 +1034,10 @@ class FlatParamHandle:
|
||||
unsharded flat parameter specifying the shard.
|
||||
"""
|
||||
flat_param_offsets = self._get_flat_param_offsets()
|
||||
assert len(flat_param_offsets) == len(self.flat_param._numels_with_padding), (
|
||||
f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
|
||||
)
|
||||
if len(flat_param_offsets) != len(self.flat_param._numels_with_padding):
|
||||
raise AssertionError(
|
||||
f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
|
||||
)
|
||||
shard_param_infos: list[_ShardParamInfo] = []
|
||||
sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1
|
||||
# `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices
|
||||
@@ -1033,12 +1065,13 @@ class FlatParamHandle:
unsharded_start_idx - unsharded_param_start_idx
)
offset_in_shard = 0
assert (
offset_in_shard >= 0 and offset_in_shard < sharded_flat_param_numel
), (
f"Invalid `offset_in_shard` of {offset_in_shard} for "
f"sharded flat parameter with {sharded_flat_param_numel} numel"
)
if not (
offset_in_shard >= 0 and offset_in_shard < sharded_flat_param_numel
):
raise AssertionError(
f"Invalid `offset_in_shard` of {offset_in_shard} for "
f"sharded flat parameter with {sharded_flat_param_numel} numel"
)
intra_param_end_idx = (
min(unsharded_param_end_idx, unsharded_end_idx)
- unsharded_param_start_idx
@ -1082,9 +1115,10 @@ class FlatParamHandle:
|
||||
else:
|
||||
chunk = chunks[rank]
|
||||
numel_to_pad = chunks[0].numel() - chunk.numel()
|
||||
assert numel_to_pad >= 0, (
|
||||
"Chunk's size should be at most the first chunk's size"
|
||||
)
|
||||
if numel_to_pad < 0:
|
||||
raise AssertionError(
|
||||
"Chunk's size should be at most the first chunk's size"
|
||||
)
|
||||
return chunk, numel_to_pad
|
||||
|
||||
@staticmethod
|
||||
@ -1115,12 +1149,16 @@ class FlatParamHandle:
|
||||
This requires ``tensor`` to have 1D shape and ensures that the returned
|
||||
shape is 1D.
|
||||
"""
|
||||
assert len(tensor.shape) == 1, f"{tensor.shape}"
|
||||
if len(tensor.shape) != 1:
|
||||
raise AssertionError(f"Expected 1D tensor shape, got {tensor.shape}")
|
||||
unpadded_sharded_tensor, numel_to_pad = FlatParamHandle._get_unpadded_shard(
|
||||
tensor, rank, world_size
|
||||
)
|
||||
unpadded_sharded_size = unpadded_sharded_tensor.size()
|
||||
assert len(unpadded_sharded_size) == 1, f"{unpadded_sharded_size}"
|
||||
if len(unpadded_sharded_size) != 1:
|
||||
raise AssertionError(
|
||||
f"Expected 1D unpadded_sharded_size, got {unpadded_sharded_size}"
|
||||
)
|
||||
return torch.Size([unpadded_sharded_size[0] + numel_to_pad])
|
||||
|
||||
def _get_flat_param_offsets(self) -> list[tuple[int, int]]:
|
||||
@ -2059,7 +2097,7 @@ class FlatParamHandle:
|
||||
_p_assert(
|
||||
hasattr(module, param_name),
|
||||
f"{module_name + '.' + param_name if module_name else param_name} is missing",
|
||||
) # did not save FQN info in `_shared_param_infos`
|
||||
)
|
||||
param = getattr(module, param_name)
|
||||
prim_param = getattr(prim_module, prim_param_name)
|
||||
if (
|
||||
@ -2130,7 +2168,8 @@ class FlatParamHandle:
|
||||
offset = shard_param_info.offset_in_shard
|
||||
numel_in_shard = shard_param_info.numel_in_shard
|
||||
param.data = flat_param[offset : offset + numel_in_shard]
|
||||
assert self.flat_param._shared_params is not None
|
||||
if self.flat_param._shared_params is None:
|
||||
raise AssertionError("Expected _shared_params to be not None")
|
||||
for i, (
|
||||
param,
|
||||
(param_name, module, _, prim_param_name, prim_module, _),
|
||||
@ -2194,7 +2233,8 @@ class FlatParamHandle:
|
||||
)
|
||||
else:
|
||||
param.grad = None
|
||||
assert flat_param._shared_params is not None
|
||||
if flat_param._shared_params is None:
|
||||
raise AssertionError("Expected _shared_params to be not None")
|
||||
for param, (_, _, _, prim_param_name, prim_module, _) in zip(
|
||||
flat_param._shared_params, flat_param._shared_param_infos
|
||||
):
|
||||
@ -2408,7 +2448,8 @@ class FlatParamHandle:
|
||||
dst_tensor[offset : offset + expected_shape.numel()].copy_(src_tensor)
|
||||
else:
|
||||
dst_tensor[offset : offset + expected_shape.numel()].zero_()
|
||||
assert self.flat_param._is_grad_none_mask is not None
|
||||
if self.flat_param._is_grad_none_mask is None:
|
||||
raise AssertionError("Expected _is_grad_none_mask to be not None")
|
||||
self.flat_param._is_grad_none_mask[tensor_index] = True
|
||||
|
||||
def _reset_flat_param_grad_info_if_needed(self):
|
||||
@ -2427,7 +2468,8 @@ class FlatParamHandle:
|
||||
if not self._use_orig_params:
|
||||
return
|
||||
flat_param = self.flat_param
|
||||
assert flat_param._params is not None # mypy
|
||||
if flat_param._params is None:
|
||||
raise AssertionError("Expected _params to be not None") # mypy
|
||||
all_grad_none = True
|
||||
requires_grad = False
|
||||
for param in flat_param._params:
|
||||
@ -2571,12 +2613,16 @@ class FlatParamHandle:
|
||||
"Expects to only be called in the post-backward after gradient computation",
|
||||
)
|
||||
flat_param = self.flat_param
|
||||
assert flat_param._params is not None # mypy
|
||||
if flat_param._params is None:
|
||||
raise AssertionError("Expected _params to be not None") # mypy
|
||||
for i, param in enumerate(flat_param._params): # type: ignore[arg-type]
|
||||
# As long as the parameter requires gradient, it should receive a
|
||||
# meaningful gradient (even if the gradient happens to be zeros)
|
||||
if param.requires_grad:
|
||||
assert flat_param._is_grad_none_mask is not None # mypy
|
||||
if flat_param._is_grad_none_mask is None:
|
||||
raise AssertionError(
|
||||
"Expected _is_grad_none_mask to be not None"
|
||||
) # mypy
|
||||
flat_param._is_grad_none_mask[i] = False
|
||||
|
||||
#######################
|
||||
|
@ -161,7 +161,8 @@ def _ext_pre_load_state_dict_transform(
|
||||
if fsdp_extension is not None:
|
||||
return fsdp_extension.pre_load_state_dict_transform(tensor)
|
||||
|
||||
assert type(tensor) is ShardedTensor
|
||||
if type(tensor) is not ShardedTensor:
|
||||
raise AssertionError(f"Expected ShardedTensor, got {type(tensor)}")
|
||||
shards = tensor.local_shards()
|
||||
return (tensor, shards)
|
||||
|
||||
|
@ -502,9 +502,10 @@ def foreach_reduce(
|
||||
):
|
||||
if (shard_dim := fsdp_param.fsdp_placement.dim) == 0:
|
||||
continue
|
||||
assert unsharded_grad.size(shard_dim) % world_size == 0, (
|
||||
f"Shard({shard_dim}) requires even sharding: {unsharded_grad.size()=} {world_size=}"
|
||||
)
|
||||
if unsharded_grad.size(shard_dim) % world_size != 0:
|
||||
raise AssertionError(
|
||||
f"Shard({shard_dim}) requires even sharding: {unsharded_grad.size()=} {world_size=}"
|
||||
)
|
||||
chunks = torch.chunk(unsharded_grad, world_size, dim=shard_dim)
|
||||
unsharded_grads[i] = torch.cat(chunks, dim=0)
|
||||
|
||||
@ -621,7 +622,10 @@ def foreach_reduce(
|
||||
# ensure that the D2H copy finishes before the optimizer
|
||||
fsdp_param.grad_offload_event = post_reduce_stream.record_event()
|
||||
if to_accumulate_grad:
|
||||
assert isinstance(fsdp_param.sharded_param.grad, DTensor)
|
||||
if not isinstance(fsdp_param.sharded_param.grad, DTensor):
|
||||
raise AssertionError(
|
||||
f"Expected fsdp_param.sharded_param.grad to be DTensor, got {type(fsdp_param.sharded_param.grad)}"
|
||||
)
|
||||
fsdp_param.sharded_param.grad._local_tensor += new_sharded_grad
|
||||
else:
|
||||
new_sharded_dtensor_grad = fsdp_param.to_sharded_dtensor(
|
||||
|
@ -17,9 +17,10 @@ _compiled_autograd_enabled: bool = False
|
||||
|
||||
|
||||
def detect_compiled_autograd():
|
||||
assert not torch.compiler.is_compiling(), (
|
||||
"`detect_compiled_autograd()` is designed to be called in eager mode"
|
||||
)
|
||||
if torch.compiler.is_compiling():
|
||||
raise AssertionError(
|
||||
"`detect_compiled_autograd()` is designed to be called in eager mode"
|
||||
)
|
||||
global _compiled_autograd_enabled
|
||||
import torch._dynamo.compiled_autograd as ca
|
||||
|
||||
|
@ -275,7 +275,10 @@ class FSDPParam:
|
||||
fsdp_placement = Shard(0)
|
||||
elif fsdp_placement.dim < 0:
|
||||
fsdp_placement = Shard(fsdp_placement.dim + param.ndim)
|
||||
assert isinstance(fsdp_placement, Shard), f"{fsdp_placement}"
|
||||
if not isinstance(fsdp_placement, Shard):
|
||||
raise AssertionError(
|
||||
f"Expected Shard, got {type(fsdp_placement)}: {fsdp_placement}"
|
||||
)
|
||||
self.fsdp_placement = fsdp_placement
|
||||
shard_dim = fsdp_placement.dim
|
||||
# TODO: Replace the sharded DTensor parameter construction logic with
|
||||
@ -296,8 +299,10 @@ class FSDPParam:
|
||||
f"DP's global mesh: {dp_global_mesh}\nTP/EP's global mesh: {tp_global_mesh}"
|
||||
)
|
||||
name_dims_error = "FSDP requires named DeviceMesh dims for ND parallelism"
|
||||
assert dp_mesh.mesh_dim_names is not None, name_dims_error
|
||||
assert tp_mesh.mesh_dim_names is not None, name_dims_error
|
||||
if dp_mesh.mesh_dim_names is None:
|
||||
raise AssertionError(name_dims_error)
|
||||
if tp_mesh.mesh_dim_names is None:
|
||||
raise AssertionError(name_dims_error)
|
||||
submesh_names = dp_mesh.mesh_dim_names + tp_mesh.mesh_dim_names
|
||||
self._spmd_mesh = dp_global_mesh[submesh_names]
|
||||
if len(self._tp_spec.placements) > 2:
|
||||
@ -305,10 +310,11 @@ class FSDPParam:
|
||||
f"FSDP only supports 1D TP/EP or 2D EP+TP, not {self._tp_spec.placements}"
|
||||
)
|
||||
split_factor = self._tp_spec.num_shards_map[shard_dim]
|
||||
assert 2 <= self._spmd_mesh.ndim <= 4, (
|
||||
"_spmd_mesh.ndim can only be 2 (FSDP+TP/EP), 3 (FSDP+EP+TP, HSDP+TP/EP), "
|
||||
f"or 4 (HSDP+EP+TP) but got {self._spmd_mesh.ndim}."
|
||||
)
|
||||
if not (2 <= self._spmd_mesh.ndim <= 4):
|
||||
raise AssertionError(
|
||||
"_spmd_mesh.ndim can only be 2 (FSDP+TP/EP), 3 (FSDP+EP+TP, HSDP+TP/EP), "
|
||||
f"or 4 (HSDP+EP+TP) but got {self._spmd_mesh.ndim}."
|
||||
)
|
||||
self._spmd_placements: tuple[Placement, ...]
|
||||
dp_shard_tp_placement = (
|
||||
(
|
||||
@ -321,7 +327,10 @@ class FSDPParam:
|
||||
if dp_mesh.ndim == 1: # FSDP
|
||||
self._spmd_placements = dp_shard_tp_placement
|
||||
else: # HSDP
|
||||
assert self.mesh_info.replicate_mesh_dim == 0
|
||||
if self.mesh_info.replicate_mesh_dim != 0:
|
||||
raise AssertionError(
|
||||
f"Expected replicate_mesh_dim to be 0, got {self.mesh_info.replicate_mesh_dim}"
|
||||
)
|
||||
self._spmd_placements = (Replicate(),) + dp_shard_tp_placement
|
||||
self._sharding_spec = DTensorSpec(
|
||||
self._spmd_mesh,
|
||||
@ -341,7 +350,10 @@ class FSDPParam:
|
||||
tensor_meta=TensorMeta(param.size(), param.stride(), param.dtype),
|
||||
)
|
||||
param_data = param
|
||||
assert param_data.is_contiguous(), f"{param_data.shape=} {param_data.stride()=}"
|
||||
if not param_data.is_contiguous():
|
||||
raise AssertionError(
|
||||
f"Expected contiguous tensor, got {param_data.shape=} {param_data.stride()=}"
|
||||
)
|
||||
shard_dim = fsdp_placement.dim
|
||||
if shard_dim >= param_data.ndim:
|
||||
raise AssertionError(
|
||||
@ -383,7 +395,10 @@ class FSDPParam:
|
||||
sharded_param = padded_sharded_param.narrow(
|
||||
dim=shard_dim, start=0, length=length
|
||||
)
|
||||
assert sharded_param.is_contiguous(), f"{self.fsdp_placement=}"
|
||||
if not sharded_param.is_contiguous():
|
||||
raise AssertionError(
|
||||
f"Expected contiguous tensor with {self.fsdp_placement=}"
|
||||
)
|
||||
self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param))
|
||||
self.sharded_param.requires_grad_(param.requires_grad)
|
||||
# Let `param_data` be freed normally when its ref count reaches 0 when
|
||||
@ -393,7 +408,8 @@ class FSDPParam:
|
||||
|
||||
def _init_sharded_post_forward_param_metadata(self, param: torch.Tensor) -> None:
|
||||
mesh_info = self.post_forward_mesh_info
|
||||
assert mesh_info is not None # mypy
|
||||
if mesh_info is None:
|
||||
raise AssertionError("Expected post_forward_mesh_info to not be None")
|
||||
param_data = param._local_tensor if isinstance(param, DTensor) else param
|
||||
chunks = _chunk_with_empty(param_data, mesh_info.shard_mesh_size, dim=0)
|
||||
self.sharded_post_forward_size = _get_dim_chunked_size(
|
||||
@ -498,7 +514,10 @@ class FSDPParam:
|
||||
else:
|
||||
# For the default path (no post-all-gather), the all-gather output
|
||||
# gives the unsharded parameter data directly
|
||||
assert len(self.all_gather_outputs) == 1, f"{len(self.all_gather_outputs)}"
|
||||
if len(self.all_gather_outputs) != 1:
|
||||
raise AssertionError(
|
||||
f"Expected 1 all_gather_output, got {len(self.all_gather_outputs)}"
|
||||
)
|
||||
unsharded_tensor = self.all_gather_outputs[0]
|
||||
unsharded_param = torch.as_strided(
|
||||
unsharded_tensor,
|
||||
@ -509,7 +528,8 @@ class FSDPParam:
|
||||
if self.is_dtensor:
|
||||
unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec)
|
||||
if hasattr(self, "_unsharded_param"):
|
||||
assert compiled_autograd_enabled()
|
||||
if not compiled_autograd_enabled():
|
||||
raise AssertionError("Expected compiled_autograd to be enabled")
|
||||
with (
|
||||
torch.no_grad(),
|
||||
torch.autograd._unsafe_preserve_version_counter(self._unsharded_param),
|
||||
@ -546,8 +566,12 @@ class FSDPParam:
|
||||
"Resharding to smaller mesh with TP is not supported yet"
|
||||
)
|
||||
self._assert_in_states(ShardedState.UNSHARDED)
|
||||
assert self.post_forward_mesh_info is not None # mypy
|
||||
assert len(self.all_gather_outputs) == 1
|
||||
if self.post_forward_mesh_info is None:
|
||||
raise AssertionError("Expected post_forward_mesh_info to not be None")
|
||||
if len(self.all_gather_outputs) != 1:
|
||||
raise AssertionError(
|
||||
f"Expected 1 all_gather_output, got {len(self.all_gather_outputs)}"
|
||||
)
|
||||
shard_world_size = self.post_forward_mesh_info.shard_mesh_size
|
||||
if (numel := self.all_gather_outputs[0].numel()) % shard_world_size != 0:
|
||||
_raise_assert_with_print(
|
||||
@ -616,7 +640,10 @@ class FSDPParam:
|
||||
_raise_assert_with_print(
|
||||
f"Expects size {self.sharded_post_forward_size} but got {tensor.shape}"
|
||||
)
|
||||
assert isinstance(self.post_forward_mesh_info, HSDPMeshInfo)
|
||||
if not isinstance(self.post_forward_mesh_info, HSDPMeshInfo):
|
||||
raise AssertionError(
|
||||
f"Expected HSDPMeshInfo, got {type(self.post_forward_mesh_info)}"
|
||||
)
|
||||
# TODO: Prefer this DTensor to be read-only and generalize the
|
||||
# placement once we support TP.
|
||||
post_forward_sharding_spec = DTensorSpec(
|
||||
@ -691,15 +718,13 @@ class FSDPParam:
|
||||
)
|
||||
num_fn_params = len(pre_all_gather_signature.parameters)
|
||||
# Old signature only passes mesh; keep for BC for now
|
||||
assert num_fn_params in (
|
||||
1,
|
||||
5,
|
||||
), (
|
||||
f"Invalid fsdp_pre_all_gather: {pre_all_gather_signature}\n"
|
||||
"Expects fsdp_pre_all_gather(self, mesh: DeviceMesh, "
|
||||
"outer_size: torch.Size, outer_stride: tuple[int, ...], "
|
||||
"module: nn.Module, mp_policy: MixedPrecisionPolicy)"
|
||||
)
|
||||
if num_fn_params not in (1, 5):
|
||||
raise AssertionError(
|
||||
f"Invalid fsdp_pre_all_gather: {pre_all_gather_signature}\n"
|
||||
"Expects fsdp_pre_all_gather(self, mesh: DeviceMesh, "
|
||||
"outer_size: torch.Size, outer_stride: tuple[int, ...], "
|
||||
"module: nn.Module, mp_policy: MixedPrecisionPolicy)"
|
||||
)
|
||||
if num_fn_params == 1:
|
||||
(
|
||||
all_gather_inputs,
|
||||
@ -765,25 +790,29 @@ class FSDPParam:
|
||||
@property
|
||||
def unsharded_grad_data(self) -> torch.Tensor:
|
||||
grad = self.unsharded_param.grad
|
||||
assert grad is not None, "Expects unsharded_param.grad to not be None"
|
||||
if grad is None:
|
||||
raise AssertionError("Expects unsharded_param.grad to not be None")
|
||||
return self._get_grad_inner_tensor(grad)
|
||||
|
||||
@property
|
||||
def unsharded_accumulated_grad_data(self) -> torch.Tensor:
|
||||
grad = self.unsharded_accumulated_grad
|
||||
assert grad is not None, "Expects unsharded_accumulated_grad to not be None"
|
||||
if grad is None:
|
||||
raise AssertionError("Expects unsharded_accumulated_grad to not be None")
|
||||
return self._get_grad_inner_tensor(grad)
|
||||
|
||||
def _get_grad_inner_tensor(self, grad: torch.Tensor) -> torch.Tensor:
|
||||
if self.is_dtensor:
|
||||
if isinstance(grad, AsyncCollectiveTensor):
|
||||
grad = grad.wait()
|
||||
assert isinstance(grad, DTensor), f"{type(grad)}"
|
||||
if not isinstance(grad, DTensor):
|
||||
raise AssertionError(f"Expected DTensor, got {type(grad)}")
|
||||
placements = self._tp_spec.placements
|
||||
if placements != grad.placements:
|
||||
assert len(self._tp_spec.placements) == len(grad.placements), (
|
||||
f"{self._tp_spec=} {grad.placements=}"
|
||||
)
|
||||
if len(self._tp_spec.placements) != len(grad.placements):
|
||||
raise AssertionError(
|
||||
f"Expected same placement length: {self._tp_spec=} {grad.placements=}"
|
||||
)
|
||||
grad = grad.redistribute(placements=placements)
|
||||
grad = grad._local_tensor
|
||||
return grad
|
||||
@ -798,7 +827,8 @@ class FSDPParam:
|
||||
if mesh.ndim == 1:
|
||||
return mesh
|
||||
elif mesh.ndim == 2:
|
||||
assert mesh.mesh_dim_names is not None
|
||||
if mesh.mesh_dim_names is None:
|
||||
raise AssertionError("Expected mesh_dim_names to not be None")
|
||||
return mesh[mesh.mesh_dim_names[-1]]
|
||||
raise ValueError(f"Invalid mesh: {mesh}")
|
||||
|
||||
@ -809,7 +839,8 @@ class FSDPParam:
|
||||
if mesh.ndim == 1:
|
||||
return mesh
|
||||
else:
|
||||
assert mesh.mesh_dim_names is not None
|
||||
if mesh.mesh_dim_names is None:
|
||||
raise AssertionError("Expected mesh_dim_names to not be None")
|
||||
shard_dim_name = mesh.mesh_dim_names[-1]
|
||||
|
||||
root_mesh = _mesh_resources.get_root_mesh(mesh)
|
||||
@ -860,9 +891,10 @@ class FSDPParam:
|
||||
shard_dim = self.fsdp_placement.dim
|
||||
length = local_tensor.size(shard_dim) if local_tensor.numel() > 0 else 0
|
||||
if local_tensor.size() != padded_sharded_size and not same_local_tensor:
|
||||
assert shard_dim == 0, (
|
||||
f"Shard({shard_dim}) requires even sharding: {local_tensor.size()=}"
|
||||
)
|
||||
if shard_dim != 0:
|
||||
raise AssertionError(
|
||||
f"Shard({shard_dim}) requires even sharding: {local_tensor.size()=}"
|
||||
)
|
||||
padded_local_tensor = local_tensor.new_zeros(padded_sharded_size)
|
||||
padded_local_tensor.narrow(dim=shard_dim, start=0, length=length).copy_(
|
||||
local_tensor
|
||||
@ -874,13 +906,17 @@ class FSDPParam:
|
||||
updated_local_tensor = True
|
||||
if not same_local_tensor:
|
||||
self._sharded_param_data = local_tensor.view(-1)
|
||||
assert isinstance(self.sharded_param, DTensor) # mypy
|
||||
if not isinstance(self.sharded_param, DTensor):
|
||||
raise AssertionError(f"Expected DTensor, got {type(self.sharded_param)}")
|
||||
if updated_local_tensor:
|
||||
# Only change the local tensor object if needed
|
||||
self.sharded_param._local_tensor = local_tensor.narrow(
|
||||
dim=shard_dim, start=0, length=length
|
||||
)
|
||||
assert self.sharded_param._local_tensor.is_contiguous()
|
||||
if not self.sharded_param._local_tensor.is_contiguous():
|
||||
raise AssertionError(
|
||||
"Expected sharded_param._local_tensor to be contiguous"
|
||||
)
|
||||
self._sharding_spec = self.sharded_param._spec
|
||||
|
||||
def __repr__(self):
|
||||
|
@@ -273,25 +273,27 @@ class FSDPParamGroup:
Whether to (try to) use the ProcessGroup's allocate_tensor method for
the staging buffers for collective comms.
"""
assert isinstance(
self._all_gather_comm, (DefaultAllGather | ProcessGroupAllocAllGather)
), (
"cannot call set_allocate_memory_from_process_group() "
f"when all gather comm is custom: {self._all_gather_comm.__class__.__name__}"
)
if not isinstance(
self._all_gather_comm, (DefaultAllGather | ProcessGroupAllocAllGather)
):
raise AssertionError(
"cannot call set_allocate_memory_from_process_group() "
f"when all gather comm is custom: {self._all_gather_comm.__class__.__name__}"
)
self._all_gather_comm = (
ProcessGroupAllocAllGather(self._all_gather_process_group)
if enable
else DefaultAllGather()
)

assert isinstance(
self._reduce_scatter_comm,
(DefaultReduceScatter | ProcessGroupAllocReduceScatter),
), (
"cannot call set_allocate_memory_from_process_group() "
f"when reduce scatter comm is custom: {self._reduce_scatter_comm.__class__.__name__}"
)
if not isinstance(
self._reduce_scatter_comm,
(DefaultReduceScatter | ProcessGroupAllocReduceScatter),
):
raise AssertionError(
"cannot call set_allocate_memory_from_process_group() "
f"when reduce scatter comm is custom: {self._reduce_scatter_comm.__class__.__name__}"
)
self._reduce_scatter_comm = (
ProcessGroupAllocReduceScatter(self._reduce_scatter_process_group)
if enable
@ -536,9 +538,10 @@ class FSDPParamGroup:
|
||||
if all_reduce_pg is None and self._all_reduce_hook_stream is not None:
|
||||
# this means the native HSDP is not enabled,
|
||||
# but user may want to have a custom HSDP setup
|
||||
assert self._all_reduce_hook is not None, (
|
||||
"all reduce hook stream is specified but hook itself is missing."
|
||||
)
|
||||
if self._all_reduce_hook is None:
|
||||
raise AssertionError(
|
||||
"all reduce hook stream is specified but hook itself is missing."
|
||||
)
|
||||
all_reduce_stream = self._all_reduce_hook_stream
|
||||
else:
|
||||
all_reduce_stream = self.comm_ctx.all_reduce_stream
|
||||
@ -573,7 +576,10 @@ class FSDPParamGroup:
|
||||
)
|
||||
if all_reduce_input is not None:
|
||||
if self.device.type != "cpu":
|
||||
assert all_reduce_event is not None
|
||||
if all_reduce_event is None:
|
||||
raise AssertionError(
|
||||
"Expected all_reduce_event to be set for non-CPU device"
|
||||
)
|
||||
self._all_reduce_state = AllReduceState(
|
||||
all_reduce_input, all_reduce_event
|
||||
)
|
||||
@ -712,9 +718,10 @@ class FSDPParamGroup:
|
||||
def _register_state_dict_hooks(self) -> None:
|
||||
num_pre_save_hooks = len(self._module_to_pre_save_state_dict_hook_handle)
|
||||
num_pre_load_hooks = len(self._module_to_pre_load_state_dict_hook_handle)
|
||||
assert num_pre_save_hooks == num_pre_load_hooks, (
|
||||
f"Pre-save: {num_pre_save_hooks} pre-load: {num_pre_load_hooks}"
|
||||
)
|
||||
if num_pre_save_hooks != num_pre_load_hooks:
|
||||
raise AssertionError(
|
||||
f"Pre-save: {num_pre_save_hooks} pre-load: {num_pre_load_hooks}"
|
||||
)
|
||||
if num_pre_save_hooks > 0:
|
||||
return # already registered
|
||||
modules_with_fsdp_params: set[nn.Module] = {
|
||||
@ -755,17 +762,26 @@ class FSDPParamGroup:
|
||||
if self.is_sharded_post_forward
|
||||
else self.mesh_info
|
||||
)
|
||||
assert isinstance(mesh_info, FSDPMeshInfo)
|
||||
if not isinstance(mesh_info, FSDPMeshInfo):
|
||||
raise AssertionError(
|
||||
f"Expected mesh_info to be FSDPMeshInfo, got {type(mesh_info)}"
|
||||
)
|
||||
return mesh_info.shard_process_group
|
||||
|
||||
@property
|
||||
def _reduce_scatter_process_group(self) -> dist.ProcessGroup:
|
||||
assert isinstance(self.mesh_info, FSDPMeshInfo)
|
||||
if not isinstance(self.mesh_info, FSDPMeshInfo):
|
||||
raise AssertionError(
|
||||
f"Expected mesh_info to be FSDPMeshInfo, got {type(self.mesh_info)}"
|
||||
)
|
||||
return self.mesh_info.shard_process_group
|
||||
|
||||
@property
|
||||
def _all_reduce_process_group(self) -> dist.ProcessGroup:
|
||||
assert isinstance(self.mesh_info, HSDPMeshInfo)
|
||||
if not isinstance(self.mesh_info, HSDPMeshInfo):
|
||||
raise AssertionError(
|
||||
f"Expected mesh_info to be HSDPMeshInfo, got {type(self.mesh_info)}"
|
||||
)
|
||||
return self.mesh_info.replicate_process_group
|
||||
|
||||
def _with_fqn(self, label: str) -> str:
|
||||
@ -834,7 +850,7 @@ def _get_param_module_infos(
|
||||
param_name
|
||||
)
|
||||
if len(param_to_module_info) != len(params):
|
||||
raise AssertionError(f"Some parameters are not in the module tree of {module}")
|
||||
raise AssertionError(f"Some parameters are not in the module tree of {modules}")
|
||||
return [param_to_module_info[param] for param in params]
|
||||
|
||||
|
||||
|
@ -203,7 +203,8 @@ class FSDPState(_State):
|
||||
|
||||
def _init_fqns(self) -> None:
|
||||
"""Sets module and parameter FQN attributes for debugging."""
|
||||
assert self._is_root
|
||||
if not self._is_root:
|
||||
raise AssertionError("Expected _is_root to be True")
|
||||
root_module = self._modules[0]
|
||||
param_to_fsdp_param: dict[nn.Parameter, FSDPParam] = {}
|
||||
module_to_fsdp_param_group: dict[nn.Module, FSDPParamGroup] = {}
|
||||
@ -222,7 +223,10 @@ class FSDPState(_State):
|
||||
if module_fqn is None:
|
||||
module_to_fsdp_param_group[module]._module_fqn = module_name
|
||||
else:
|
||||
assert isinstance(module_fqn, str), f"{module_fqn}"
|
||||
if not isinstance(module_fqn, str):
|
||||
raise AssertionError(
|
||||
f"Expected module_fqn to be str, got {type(module_fqn)}: {module_fqn}"
|
||||
)
|
||||
module_fqn += f", {module_name}"
|
||||
module_to_fsdp_param_group[module]._module_fqn = module_fqn
|
||||
|
||||
|
@ -243,9 +243,10 @@ def _init_inter_node_process_group(
|
||||
if local_rank == my_local_rank:
|
||||
inter_node_pg = grp
|
||||
|
||||
assert inter_node_pg is not None, (
|
||||
f"{my_local_rank} expected to assign inter-node pg, but did not"
|
||||
)
|
||||
if inter_node_pg is None:
|
||||
raise AssertionError(
|
||||
f"{my_local_rank} expected to assign inter-node pg, but did not"
|
||||
)
|
||||
return inter_node_pg
|
||||
|
||||
|
||||
@ -548,7 +549,8 @@ def _verify_managed_params(module: nn.Module, params: list[nn.Parameter]) -> Non
|
||||
if param is param_:
|
||||
param_name = name
|
||||
break
|
||||
assert param_name
|
||||
if not param_name:
|
||||
raise AssertionError("Expected param_name to be set")
|
||||
raise ValueError(
|
||||
"FSDP doesn't support scalar parameters. "
|
||||
f"Change {param_name} to a 1D tensor with numel equal to 1."
|
||||
@ -646,7 +648,8 @@ def _init_param_handle_from_params(
|
||||
fsdp_extension=state._fsdp_extension,
|
||||
)
|
||||
handle.shard()
|
||||
assert not state._handle
|
||||
if state._handle:
|
||||
raise AssertionError("Expected state._handle to be None")
|
||||
state.params.append(handle.flat_param)
|
||||
state._handle = handle
|
||||
state._fully_sharded_module_to_handle[handle._fully_sharded_module] = handle
|
||||
@ -707,7 +710,10 @@ def _get_ignored_modules(
|
||||
for submodule in root_module.modules():
|
||||
optional_fsdp_state = _get_module_fsdp_state(submodule)
|
||||
if optional_fsdp_state is not None:
|
||||
assert hasattr(optional_fsdp_state, "_ignored_modules")
|
||||
if not hasattr(optional_fsdp_state, "_ignored_modules"):
|
||||
raise AssertionError(
|
||||
"Expected optional_fsdp_state to have _ignored_modules attribute"
|
||||
)
|
||||
ignored_modules.update(optional_fsdp_state._ignored_modules)
|
||||
return ignored_modules
|
||||
|
||||
@ -740,7 +746,10 @@ def _get_ignored_params(
|
||||
for submodule in root_module.modules():
|
||||
optional_fsdp_state = _get_module_fsdp_state(submodule)
|
||||
if optional_fsdp_state is not None:
|
||||
assert hasattr(optional_fsdp_state, "_ignored_params")
|
||||
if not hasattr(optional_fsdp_state, "_ignored_params"):
|
||||
raise AssertionError(
|
||||
"Expected optional_fsdp_state to have _ignored_params attribute"
|
||||
)
|
||||
all_ignored_params.update(optional_fsdp_state._ignored_params)
|
||||
|
||||
return all_ignored_params
|
||||
@ -769,7 +778,10 @@ def _get_ignored_buffer_names(
|
||||
for submodule in root_module.modules():
|
||||
optional_fsdp_state = _get_module_fsdp_state(submodule)
|
||||
if optional_fsdp_state is not None:
|
||||
assert hasattr(optional_fsdp_state, "_ignored_buffer_names")
|
||||
if not hasattr(optional_fsdp_state, "_ignored_buffer_names"):
|
||||
raise AssertionError(
|
||||
"Expected optional_fsdp_state to have _ignored_buffer_names attribute"
|
||||
)
|
||||
all_ignored_buffer_names.update(optional_fsdp_state._ignored_buffer_names)
|
||||
|
||||
return all_ignored_buffer_names
|
||||
|
@ -146,9 +146,8 @@ def _unflatten_optim_state(
|
||||
dict will need to map these entries using the proper unflattened
|
||||
parameter IDs.
|
||||
"""
|
||||
assert not shard_state or to_save, (
|
||||
"If ``shard_state`` is True, ``to_save`` has to be True."
|
||||
)
|
||||
if shard_state and not to_save:
|
||||
raise AssertionError("If ``shard_state`` is True, ``to_save`` has to be True.")
|
||||
consolidated_state = _communicate_optim_state(
|
||||
fsdp_param_info,
|
||||
flat_param_state,
|
||||
@ -219,9 +218,8 @@ def _communicate_optim_state(
|
||||
):
|
||||
tensor_state[state_name] = value
|
||||
continue
|
||||
assert fsdp_state.compute_device is not None, (
|
||||
"compute_device has not been initialized"
|
||||
)
|
||||
if fsdp_state.compute_device is None:
|
||||
raise AssertionError("compute_device has not been initialized")
|
||||
if value.device.type != fsdp_state.compute_device.type:
|
||||
value = value.to(fsdp_state.compute_device)
|
||||
# Assume that positive-dimension tensor optimizer state
|
||||
@ -294,7 +292,10 @@ def _unflatten_communicated_optim_state(
|
||||
if shard_state:
|
||||
osd_config = fsdp_state._optim_state_dict_config
|
||||
if getattr(osd_config, "_use_dtensor", False):
|
||||
assert fsdp_state._device_mesh is not None
|
||||
if fsdp_state._device_mesh is None:
|
||||
raise AssertionError(
|
||||
f"Expected _device_mesh to be not None, got {fsdp_state._device_mesh}"
|
||||
)
|
||||
optim_state = _ext_chunk_dtensor(
|
||||
optim_state,
|
||||
fsdp_state.rank,
|
||||
@ -302,7 +303,10 @@ def _unflatten_communicated_optim_state(
|
||||
fsdp_state._fsdp_extension,
|
||||
)
|
||||
else:
|
||||
assert fsdp_state.process_group is not None
|
||||
if fsdp_state.process_group is None:
|
||||
raise AssertionError(
|
||||
f"Expected process_group to be not None, got {fsdp_state.process_group}"
|
||||
)
|
||||
optim_state = _ext_chunk_tensor(
|
||||
optim_state,
|
||||
fsdp_state.rank,
|
||||
@ -349,10 +353,11 @@ def _broadcast_state(
|
||||
tensor = state.to(fsdp_state.compute_device)
|
||||
else:
|
||||
if isinstance(state, torch.Tensor):
|
||||
assert state.dim() == 0, (
|
||||
"For non-zero ranks, a tensor state should have zero dimension, "
|
||||
f"but got the state with shape {state.shape}."
|
||||
)
|
||||
if state.dim() != 0:
|
||||
raise AssertionError(
|
||||
"For non-zero ranks, a tensor state should have zero dimension, "
|
||||
f"but got the state with shape {state.shape}."
|
||||
)
|
||||
return state
|
||||
elif not isinstance(state, _PosDimTensorInfo):
|
||||
return state
|
||||
@ -491,9 +496,10 @@ def _flatten_optim_state_dict(
|
||||
if flat_state:
|
||||
flat_osd_state[key] = flat_state
|
||||
elif use_orig_params:
|
||||
assert len(fqns) == 1, (
|
||||
f"use_orig_params is True but there are multiple FQNs, {fqns}."
|
||||
)
|
||||
if len(fqns) != 1:
|
||||
raise AssertionError(
|
||||
f"use_orig_params is True but there are multiple FQNs, {fqns}."
|
||||
)
|
||||
if optim is not None: # NamedOptimizer or KeyedOptimizer case.
|
||||
state = optim.state.get(param, None) # type: ignore[call-overload]
|
||||
if state is not None:
|
||||
@ -509,7 +515,8 @@ def _flatten_optim_state_dict(
|
||||
"use_orig_params=True."
|
||||
)
|
||||
else: # do not flatten non-FSDP parameters' states
|
||||
assert len(fqns) == 1
|
||||
if len(fqns) != 1:
|
||||
raise AssertionError(f"Expected len(fqns) == 1, got {len(fqns)}")
|
||||
key = _OptimStateKey(tuple(fqns), False)
|
||||
flat_osd_state[key] = copy.copy(unflat_osd_state[fqn])
|
||||
|
||||
@ -571,14 +578,16 @@ def _flatten_optim_state(
|
||||
handle = fsdp_param_info.handle
|
||||
flat_param = handle.flat_param
|
||||
num_unflat_params = len(unflat_param_names)
|
||||
assert num_unflat_params > 0, (
|
||||
"Expects at least one unflattened parameter corresponding to the flat parameter"
|
||||
)
|
||||
if num_unflat_params <= 0:
|
||||
raise AssertionError(
|
||||
"Expects at least one unflattened parameter corresponding to the flat parameter"
|
||||
)
|
||||
unflat_param_shapes = flat_param._shapes
|
||||
num_unflat_param_shapes = len(unflat_param_shapes)
|
||||
assert num_unflat_params == num_unflat_param_shapes, (
|
||||
f"Expects {num_unflat_params} shapes but got {num_unflat_param_shapes}"
|
||||
)
|
||||
if num_unflat_params != num_unflat_param_shapes:
|
||||
raise AssertionError(
|
||||
f"Expects {num_unflat_params} shapes but got {num_unflat_param_shapes}"
|
||||
)
|
||||
|
||||
# Check if these unflattened parameters have any optimizer state
|
||||
has_state = [
|
||||
@ -615,7 +624,8 @@ def _flatten_optim_state(
|
||||
"Differing optimizer state names for the unflattened "
|
||||
f"parameters: {unflat_param_names}"
|
||||
)
|
||||
assert state_names is not None
|
||||
if state_names is None:
|
||||
raise AssertionError(f"Expected state_names to be not None, got {state_names}")
|
||||
|
||||
# Flatten the state
|
||||
flat_state: dict[str, Optional[torch.Tensor]] = {}
|
||||
@ -672,7 +682,10 @@ def _flatten_optim_state(
|
||||
unflat_param_names,
|
||||
)
|
||||
else:
|
||||
assert are_non_tensors
|
||||
if not are_non_tensors:
|
||||
raise AssertionError(
|
||||
f"Expected are_non_tensors to be True, got {are_non_tensors}"
|
||||
)
|
||||
flat_state[state_name] = _flatten_non_tensor_optim_state(
|
||||
state_name,
|
||||
state_values,
|
||||
@ -760,9 +773,10 @@ def _flatten_tensor_optim_state(
|
||||
]
|
||||
flat_tensor = handle.flatten_tensors(tensors_to_flatten, handle._aligned_numel)
|
||||
flat_param_shape = flat_param._unpadded_unsharded_size # type: ignore[attr-defined]
|
||||
assert flat_tensor.shape == flat_param_shape, (
|
||||
f"tensor optim state: {flat_tensor.shape} flat parameter: {flat_param_shape}"
|
||||
)
|
||||
if flat_tensor.shape != flat_param_shape:
|
||||
raise AssertionError(
|
||||
f"tensor optim state: {flat_tensor.shape} flat parameter: {flat_param_shape}"
|
||||
)
|
||||
return flat_tensor
|
||||
|
||||
|
||||
@ -893,7 +907,10 @@ def _rekey_sharded_optim_state_dict(
|
||||
# All parameter keys in `param_to_param_key` should be in
|
||||
# `param_to_fqns` -- strict inequality follows when not all parameters are
|
||||
# passed to the optimizer
|
||||
assert len(param_to_param_key) <= len(param_to_fqns)
|
||||
if len(param_to_param_key) > len(param_to_fqns):
|
||||
raise AssertionError(
|
||||
f"Expected len(param_to_param_key) <= len(param_to_fqns), got {len(param_to_param_key)} > {len(param_to_fqns)}"
|
||||
)
|
||||
|
||||
unflat_param_names_to_flat_param_key: dict[
|
||||
tuple[str, ...], Union[int, str]
|
||||
@ -1002,14 +1019,15 @@ def _get_param_id_to_param_from_optim_input(
|
||||
raise TypeError("Optimizer input should be an iterable of Tensors or dicts")
|
||||
if all_tensors:
|
||||
return dict(enumerate(params))
|
||||
assert all_dicts
|
||||
if not all_dicts:
|
||||
raise AssertionError(f"Expected all_dicts to be True, got {all_dicts}")
|
||||
param_id_to_param: list[nn.Parameter] = []
|
||||
for param_group in params:
|
||||
has_params_key = "params" in param_group # type: ignore[operator]
|
||||
assert has_params_key, (
|
||||
'A parameter group should map "params" to a list of the '
|
||||
"parameters in the group"
|
||||
)
|
||||
if not has_params_key:
|
||||
raise AssertionError(
|
||||
'A parameter group should map "params" to a list of the parameters in the group'
|
||||
)
|
||||
# Implicitly map `flat_param_id` (current length of the list) to
|
||||
# `param`
|
||||
param_id_to_param.extend(param_group["params"]) # type: ignore[index]
|
||||
@ -1068,10 +1086,12 @@ def _get_param_key_to_param(
|
||||
"""
|
||||
clean_fqn_to_curr_fqn: dict[str, str] = {}
|
||||
if is_named_optimizer:
|
||||
assert param_to_fqns is not None and flat_param_to_fqn is not None, (
|
||||
"The optimizer is a NamedOptimizer, `param_to_fqns` must not be None."
|
||||
)
|
||||
assert model is not None
|
||||
if param_to_fqns is None or flat_param_to_fqn is None:
|
||||
raise AssertionError(
|
||||
"The optimizer is a NamedOptimizer, `param_to_fqns` must not be None."
|
||||
)
|
||||
if model is None:
|
||||
raise AssertionError(f"Expected model to be not None, got {model}")
|
||||
for key, _ in _named_parameters_with_duplicates(model):
|
||||
clean_fqn_to_curr_fqn[clean_tensor_name(key)] = key
|
||||
|
||||
@ -1080,14 +1100,23 @@ def _get_param_key_to_param(
|
||||
for param_group in optim.param_groups:
|
||||
if is_named_optimizer:
|
||||
for param in param_group["params"]:
|
||||
assert flat_param_to_fqn is not None
|
||||
if flat_param_to_fqn is None:
|
||||
raise AssertionError(
|
||||
f"Expected flat_param_to_fqn to be not None, got {flat_param_to_fqn}"
|
||||
)
|
||||
if param in flat_param_to_fqn:
|
||||
# FlatParameter case
|
||||
key = flat_param_to_fqn[param]
|
||||
else:
|
||||
assert param_to_fqns is not None
|
||||
if param_to_fqns is None:
|
||||
raise AssertionError(
|
||||
f"Expected param_to_fqns to be not None, got {param_to_fqns}"
|
||||
)
|
||||
# use_orig_params case
|
||||
assert len(param_to_fqns[param]) == 1
|
||||
if len(param_to_fqns[param]) != 1:
|
||||
raise AssertionError(
|
||||
f"Expected len(param_to_fqns[param]) == 1, got {len(param_to_fqns[param])}"
|
||||
)
|
||||
key = param_to_fqns[param][0]
|
||||
try:
|
||||
key = clean_fqn_to_curr_fqn[key]
|
||||
@ -1153,9 +1182,8 @@ def _check_missing_keys_on_rank(
|
||||
continue
|
||||
param_key = optim_state_key_to_param_key[r0_optim_state_key]
|
||||
if isinstance(param_key, int):
|
||||
assert param_key >= 0 and param_key < len(param_key_to_param), (
|
||||
"Check the `param_key_to_param` construction"
|
||||
)
|
||||
if not (param_key >= 0 and param_key < len(param_key_to_param)):
|
||||
raise AssertionError("Check the `param_key_to_param` construction")
|
||||
# We cannot use FSDPState.compute_device as this API is a global view.
|
||||
device = _get_pg_default_device(group)
|
||||
num_missing = torch.tensor([len(missing_keys)], dtype=torch.int32, device=device)
|
||||
@ -1204,10 +1232,10 @@ def _map_param_key_to_optim_keys(
|
||||
fqns = param_to_fqns[param]
|
||||
is_fsdp_managed = isinstance(param, FlatParameter)
|
||||
if is_fsdp_managed:
|
||||
assert fqns[0] in fqn_to_fsdp_param_info, (
|
||||
fqns[0],
|
||||
list(fqn_to_fsdp_param_info.keys()),
|
||||
)
|
||||
if fqns[0] not in fqn_to_fsdp_param_info:
|
||||
raise AssertionError(
|
||||
f"Expected {fqns[0]} to be in fqn_to_fsdp_param_info, got keys: {list(fqn_to_fsdp_param_info.keys())}"
|
||||
)
|
||||
is_fsdp_managed = fqns[0] in fqn_to_fsdp_param_info
|
||||
optim_state_key = _OptimStateKey(
|
||||
unflat_param_names=tuple(fqns),
|
||||
@ -1229,7 +1257,10 @@ def _map_param_key_to_optim_keys(
|
||||
[all_optim_state_keys] if rank == 0 else [None]
|
||||
)
|
||||
dist.broadcast_object_list(key_obj_list, src=0, group=group)
|
||||
assert key_obj_list[0] is not None
|
||||
if key_obj_list[0] is None:
|
||||
raise AssertionError(
|
||||
f"Expected key_obj_list[0] to be not None, got {key_obj_list[0]}"
|
||||
)
|
||||
all_optim_state_keys = key_obj_list[0]
|
||||
_check_missing_keys_on_rank(
|
||||
all_optim_state_keys,
|
||||
@ -1362,11 +1393,17 @@ def _convert_all_state_info(
|
||||
if not dtype:
|
||||
dtype = info.dtype
|
||||
else:
|
||||
assert dtype == info.dtype
|
||||
if dtype != info.dtype:
|
||||
raise AssertionError(
|
||||
f"Expected dtype == info.dtype, got {dtype} != {info.dtype}"
|
||||
)
|
||||
if numels[-1] == 0:
|
||||
_empty_ranks.add(rank)
|
||||
|
||||
assert not empty_ranks or empty_ranks == _empty_ranks
|
||||
if not (not empty_ranks or empty_ranks == _empty_ranks):
|
||||
raise AssertionError(
|
||||
f"Expected empty_ranks to be empty or equal to _empty_ranks, got {empty_ranks} vs {_empty_ranks}"
|
||||
)
|
||||
empty_ranks = _empty_ranks
|
||||
if state_name not in state_buffers:
|
||||
state_buffers[state_name] = [
|
||||
@@ -1388,23 +1425,26 @@
continue
for name, non_tensor_value in object_state.non_tensors.items():
curr_non_tensor_value = gathered_state.get(name, None)
assert (
curr_non_tensor_value is None
or curr_non_tensor_value == non_tensor_value
), (
f"Rank {rank} has different values for {name}: {non_tensor_value}."
+ f" Other ranks: {curr_non_tensor_value}"
)
if not (
curr_non_tensor_value is None
or curr_non_tensor_value == non_tensor_value
):
raise AssertionError(
f"Rank {rank} has different values for {name}: {non_tensor_value}."
+ f" Other ranks: {curr_non_tensor_value}"
)
gathered_state[name] = non_tensor_value

for name, scalar_tensor_value in object_state.scalar_tensors.items():
curr_scalar_tensor_value = gathered_state.get(name, None)
assert curr_scalar_tensor_value is None or torch.equal(
scalar_tensor_value, curr_scalar_tensor_value
), (
f"Rank {rank} has different values for {name}: {scalar_tensor_value}."
+ f" Other ranks: {curr_scalar_tensor_value}"
)
if not (
curr_scalar_tensor_value is None
or torch.equal(scalar_tensor_value, curr_scalar_tensor_value)
):
raise AssertionError(
f"Rank {rank} has different values for {name}: {scalar_tensor_value}."
+ f" Other ranks: {curr_scalar_tensor_value}"
)
gathered_state[name] = scalar_tensor_value

return dtype, state_buffers # type: ignore[possibly-undefined]
@ -1455,7 +1495,10 @@ def _unflatten_orig_param_states(
|
||||
if shard_state:
|
||||
osd_config = fsdp_state._optim_state_dict_config
|
||||
if getattr(osd_config, "_use_dtensor", False):
|
||||
assert fsdp_state._device_mesh is not None
|
||||
if fsdp_state._device_mesh is None:
|
||||
raise AssertionError(
|
||||
f"Expected _device_mesh to be not None, got {fsdp_state._device_mesh}"
|
||||
)
|
||||
value = _ext_chunk_dtensor(
|
||||
value,
|
||||
fsdp_state.rank,
|
||||
@ -1463,7 +1506,10 @@ def _unflatten_orig_param_states(
|
||||
fsdp_state._fsdp_extension,
|
||||
)
|
||||
else:
|
||||
assert fsdp_state.process_group is not None
|
||||
if fsdp_state.process_group is None:
|
||||
raise AssertionError(
|
||||
f"Expected process_group to be not None, got {fsdp_state.process_group}"
|
||||
)
|
||||
value = _ext_chunk_tensor(
|
||||
value,
|
||||
fsdp_state.rank,
|
||||
@ -1598,24 +1644,26 @@ def _allgather_orig_param_states(
|
||||
sum(t.numel() for t in local_buffers)
|
||||
)
|
||||
|
||||
assert flat_param._shard_numel_padded == shard_numel_padded, (
|
||||
"Manually calculated _sharded_numel_padded is incorrect. "
|
||||
f"_shard_numel_padded={flat_param._shard_numel_padded}, "
|
||||
f"shard_numel_padded={shard_numel_padded}, "
|
||||
f"_sharded_size.numel={flat_param._sharded_size.numel()}, "
|
||||
f"_numels_with_padding={flat_param._numels_with_padding}, "
|
||||
f"begin={begin}, end={end},"
|
||||
)
|
||||
if flat_param._shard_numel_padded != shard_numel_padded:
|
||||
raise AssertionError(
|
||||
"Manually calculated _sharded_numel_padded is incorrect. "
|
||||
f"_shard_numel_padded={flat_param._shard_numel_padded}, "
|
||||
f"shard_numel_padded={shard_numel_padded}, "
|
||||
f"_sharded_size.numel={flat_param._sharded_size.numel()}, "
|
||||
f"_numels_with_padding={flat_param._numels_with_padding}, "
|
||||
f"begin={begin}, end={end},"
|
||||
)
|
||||
if shard_numel_padded > 0:
|
||||
# Add right-handed padding.
|
||||
local_buffers.append(empty_func(shard_numel_padded))
|
||||
local_shard = torch.cat(local_buffers)
|
||||
assert local_shard.numel() * fsdp_state.world_size == gathered_tensor.numel(), (
|
||||
"The size of local shard times the world size should equal to the "
|
||||
"gathered tensor size. The inconsistency may be from a bug of "
|
||||
"FlatParameter's metadata or the reconstruction logic in optimizer "
|
||||
"state dict."
|
||||
)
|
||||
if local_shard.numel() * fsdp_state.world_size != gathered_tensor.numel():
|
||||
raise AssertionError(
|
||||
"The size of local shard times the world size should equal to the "
|
||||
"gathered tensor size. The inconsistency may be from a bug of "
|
||||
"FlatParameter's metadata or the reconstruction logic in optimizer "
|
||||
"state dict."
|
||||
)
|
||||
fsdp_state._device_handle.synchronize()
|
||||
with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER):
|
||||
dist.all_gather_into_tensor(
|
||||
@ -1627,11 +1675,12 @@ def _allgather_orig_param_states(
|
||||
unpadded_tensor = gathered_tensor[: flat_param._unpadded_unsharded_size.numel()]
|
||||
flat_param_handle = fsdp_param_info.handle
|
||||
orig_states = flat_param_handle._get_unflat_views_aligned(unpadded_tensor)
|
||||
assert len(orig_states) == len(fsdp_param_info.param_indices), (
|
||||
"The number of parameters from FlatParameter is not consistent to "
|
||||
"the number of states used by optimizer state dict reconstruction "
|
||||
"logic."
|
||||
)
|
||||
if len(orig_states) != len(fsdp_param_info.param_indices):
|
||||
raise AssertionError(
|
||||
"The number of parameters from FlatParameter is not consistent to "
|
||||
"the number of states used by optimizer state dict reconstruction "
|
||||
"logic."
|
||||
)
|
||||
for fqn, idx in fsdp_param_info.param_indices.items():
|
||||
if fsdp_param_info.param_requires_grad[idx] or fqn in output_states:
|
||||
output_states[fqn][state_name] = orig_states[idx]
|
||||
@ -1741,7 +1790,10 @@ def _convert_state_with_orig_params(
|
||||
all_states[id(fsdp_param_info)][fqn] = state
|
||||
|
||||
elif to_save:
|
||||
assert len(optim_state_key.unflat_param_names) == 1
|
||||
if len(optim_state_key.unflat_param_names) != 1:
|
||||
raise AssertionError(
|
||||
f"Expected len(optim_state_key.unflat_param_names) == 1, got {len(optim_state_key.unflat_param_names)}"
|
||||
)
|
||||
unflat_param_name = optim_state_key.unflat_param_names[0]
|
||||
with SimpleProfiler.profile("none_fsdp_managed_copy"):
|
||||
param_key = cast(Union[str, int], param_key)
|
||||
@ -1761,10 +1813,11 @@ def _convert_state_with_orig_params(
|
||||
for _all_states in all_states.values():
|
||||
fqn = next(iter(_all_states.keys()))
|
||||
fsdp_param_info = fqn_to_fsdp_param_info[fqn]
|
||||
assert len(fsdp_param_info.param_requires_grad) > 0, (
|
||||
"With use_orig_params, FSDPParamInfo should have requires_grad "
|
||||
"information. However, the length is zero."
|
||||
)
|
||||
if len(fsdp_param_info.param_requires_grad) <= 0:
|
||||
raise AssertionError(
|
||||
"With use_orig_params, FSDPParamInfo should have requires_grad "
|
||||
"information. However, the length is zero."
|
||||
)
|
||||
for key, idx in fsdp_param_info.param_indices.items():
|
||||
if key in _all_states:
|
||||
continue
|
||||
@ -1807,10 +1860,11 @@ def _convert_state_with_flat_params(
|
||||
optim_state_key
|
||||
)
|
||||
|
||||
assert param_key is not None, (
|
||||
"If use_orig_params is False, we must be able to find the "
|
||||
f"corresponding param id. {optim_state_key} {param_key}"
|
||||
)
|
||||
if param_key is None:
|
||||
raise AssertionError(
|
||||
"If use_orig_params is False, we must be able to find the "
|
||||
f"corresponding param id. {optim_state_key} {param_key}"
|
||||
)
|
||||
|
||||
if optim_state_key.is_fsdp_managed:
|
||||
# If there are multiple unflat_param_names (not use_orig_params),
|
||||
@ -1826,7 +1880,11 @@ def _convert_state_with_flat_params(
|
||||
cpu_offload,
|
||||
)
|
||||
if to_save:
|
||||
assert len(unflat_state) == len(optim_state_key.unflat_param_names)
|
||||
if len(unflat_state) != len(optim_state_key.unflat_param_names):
|
||||
raise AssertionError(
|
||||
f"Expected len(unflat_state) == len(optim_state_key.unflat_param_names), "
|
||||
f"got {len(unflat_state)} != {len(optim_state_key.unflat_param_names)}"
|
||||
)
|
||||
fsdp_osd_state.update(
|
||||
zip(
|
||||
optim_state_key.unflat_param_names,
|
||||
@ -1834,7 +1892,10 @@ def _convert_state_with_flat_params(
|
||||
)
|
||||
)
|
||||
elif to_save:
|
||||
assert len(optim_state_key.unflat_param_names) == 1
|
||||
if len(optim_state_key.unflat_param_names) != 1:
|
||||
raise AssertionError(
|
||||
f"Expected len(optim_state_key.unflat_param_names) == 1, got {len(optim_state_key.unflat_param_names)}"
|
||||
)
|
||||
unflat_param_name = optim_state_key.unflat_param_names[0]
|
||||
fsdp_osd_state[unflat_param_name] = copy.copy(optim_state_dict[param_key])
|
||||
if cpu_offload:
|
||||
@ -2030,7 +2091,10 @@ def _get_fqn_to_fsdp_param_info(model: nn.Module) -> dict[str, FSDPParamInfo]:
|
||||
for idx, local_fqn in enumerate(flat_param._fqns):
|
||||
fqn = clean_tensor_name(prefix + local_fqn)
|
||||
if fqn in fqn_to_param_info:
|
||||
assert fqn_to_param_info[fqn].handle.flat_param is flat_param, fqn
|
||||
if fqn_to_param_info[fqn].handle.flat_param is not flat_param:
|
||||
raise AssertionError(
|
||||
f"Expected fqn_to_param_info[fqn].handle.flat_param is flat_param for {fqn}"
|
||||
)
|
||||
fqn_to_param_info[fqn] = fsdp_param_info
|
||||
fsdp_param_info.param_indices[fqn] = idx
|
||||
if flat_param._params is not None:
|
||||
|
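For orientation, every hunk in this diff applies the same transformation: a bare assert becomes an explicit if check that raises AssertionError. A minimal sketch of the pattern on a hypothetical helper (the name and message below are made up, not taken from the diff); the if/raise form keeps firing even when CPython strips assert statements under the -O flag:

def _check_shard_size(shard_numel: int, world_size: int, total_numel: int) -> None:
    # Old style, compiled away when __debug__ is False (python -O):
    #     assert shard_numel * world_size == total_numel, "size mismatch"
    # New style, always enforced:
    if shard_numel * world_size != total_numel:
        raise AssertionError(
            f"size mismatch: {shard_numel} * {world_size} != {total_numel}"
        )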
@ -103,7 +103,8 @@ def _is_fsdp_root(state: _FSDPState, module: nn.Module) -> bool:
|
||||
"""
|
||||
# Force a lazy initialization to determine the FSDP root
|
||||
_lazy_init(state, module)
|
||||
assert state._is_root is not None # mypy
|
||||
if state._is_root is None:
|
||||
raise AssertionError("Expected _is_root to be set after lazy init")
|
||||
return state._is_root
|
||||
|
||||
|
||||
@ -240,8 +241,10 @@ def _init_streams(
|
||||
Initializes CUDA streams for overlapping communication, computation, and
|
||||
data transfers. The streams should be shared across FSDP instances.
|
||||
"""
|
||||
assert state._is_root
|
||||
assert state._device_handle.is_available()
|
||||
if not state._is_root:
|
||||
raise AssertionError("Expected state to be root")
|
||||
if not state._device_handle.is_available():
|
||||
raise AssertionError("Expected device handle to be available")
|
||||
uses_hybrid_sharding = any(
|
||||
fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES
|
||||
for fsdp_state in state._all_fsdp_states
|
||||
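The _init_streams hunk above concerns device streams that FSDP shares across instances to overlap communication, computation, and data transfers. As a rough illustration of the primitive involved (not FSDP's actual stream setup), a side stream is created once, work is queued on it, and the default stream waits on it before consuming the result:

import torch

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()  # standalone example, not FSDP internals
    host_buf = torch.empty(1024, pin_memory=True)
    with torch.cuda.stream(side_stream):
        device_buf = host_buf.to("cuda", non_blocking=True)  # queued on side_stream
    torch.cuda.current_stream().wait_stream(side_stream)  # order before use
    result = device_buf * 2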
@ -1459,7 +1462,8 @@ def _register_post_backward_hook(
|
||||
"register the post-backward hook",
|
||||
)
|
||||
acc_grad = temp_flat_param.grad_fn.next_functions[0][0] # type: ignore[union-attr]
|
||||
assert acc_grad is not None
|
||||
if acc_grad is None:
|
||||
raise AssertionError("Expected acc_grad to be set")
|
||||
hook_handle = acc_grad.register_hook(
|
||||
functools.partial(_post_backward_hook, state, handle)
|
||||
)
|
||||
@ -1501,7 +1505,8 @@ def _register_post_backward_reshard_only_hook(
|
||||
inp_tensors = [
|
||||
obj for obj in args_flat if torch.is_tensor(obj) and obj.requires_grad
|
||||
]
|
||||
assert inp_tensors is not None # mypy
|
||||
if inp_tensors is None:
|
||||
raise AssertionError("Expected inp_tensors to be set")
|
||||
hook_handle = register_multi_grad_hook(
|
||||
inp_tensors, functools.partial(_post_backward_reshard_only_hook, state, handle)
|
||||
)
|
||||
@ -1599,7 +1604,10 @@ def _get_buffers_and_dtypes_for_computation(
|
||||
continue
|
||||
buffers.append(buffer)
|
||||
buffer_dtypes.append(fsdp_state.mixed_precision.buffer_dtype)
|
||||
assert len(buffers) == len(buffer_dtypes), f"{len(buffers)} {len(buffer_dtypes)}"
|
||||
if len(buffers) != len(buffer_dtypes):
|
||||
raise AssertionError(
|
||||
f"Expected buffers and buffer_dtypes to have the same length, got {len(buffers)} and {len(buffer_dtypes)}"
|
||||
)
|
||||
return buffers, buffer_dtypes
|
||||
|
||||
|
||||
|
@ -68,7 +68,11 @@ def _create_chunk_sharded_tensor(
|
||||
)
|
||||
for r in range(len(chunk_sizes))
|
||||
]
|
||||
assert len(chunk_sizes) == len(chunk_offsets) == len(placements)
|
||||
if len(chunk_sizes) != len(chunk_offsets) or len(chunk_sizes) != len(placements):
|
||||
raise AssertionError(
|
||||
f"Expected chunk_sizes, chunk_offsets, and placements to have the same length, "
|
||||
f"got {len(chunk_sizes)}, {len(chunk_offsets)}, {len(placements)}"
|
||||
)
|
||||
shard_metadata = [
|
||||
ShardMetadata(offset, size, placement)
|
||||
for offset, size, placement in zip(chunk_offsets, chunk_sizes, placements)
|
||||
@ -121,9 +125,8 @@ def _all_gather_dtensor(
|
||||
"""
|
||||
All gather a DTensor in its sharded dimension and return the local tensor.
|
||||
"""
|
||||
assert root_mesh == tensor.device_mesh, (
|
||||
"The device mesh of a tensor should be a root mesh."
|
||||
)
|
||||
if root_mesh != tensor.device_mesh:
|
||||
raise AssertionError("The device mesh of a tensor should be a root mesh.")
|
||||
|
||||
placements = list(copy.deepcopy(tensor.placements))
|
||||
# FSDP placements: [Shard(0)] -> [Replicate()]
|
||||
|
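The _all_gather_dtensor hunk above recovers a full tensor by moving a DTensor from Shard(0) to Replicate(). A hedged sketch with the public DTensor API (assumes torch.distributed is initialized; shapes and mesh size are arbitrary):

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

mesh = init_device_mesh("cuda", (torch.cuda.device_count(),))
sharded = distribute_tensor(torch.randn(16, 8), mesh, placements=[Shard(0)])
# Switching the placement to Replicate() all-gathers along the sharded dim.
full_local = sharded.redistribute(mesh, placements=[Replicate()]).to_local()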
@ -110,10 +110,11 @@ def _enter_unshard_params_ctx(
|
||||
requires to enter the context in the pre-hook but leave the context in the
|
||||
post-hook. This API enters the context of ``_unshard_fsdp_state_params``.
|
||||
"""
|
||||
assert module not in fsdp_state._unshard_params_ctx, (
|
||||
"Entering the ``_unshard_fsdp_state_params`` context but _unshard_params_ctx[module] "
|
||||
"is not None."
|
||||
)
|
||||
if module in fsdp_state._unshard_params_ctx:
|
||||
raise AssertionError(
|
||||
"Entering the ``_unshard_fsdp_state_params`` context but _unshard_params_ctx[module] "
|
||||
"is not None."
|
||||
)
|
||||
fsdp_state._unshard_params_ctx[module] = _unshard_fsdp_state_params(
|
||||
module,
|
||||
fsdp_state,
|
||||
@ -219,12 +220,13 @@ def _common_unshard_post_state_dict_hook(
|
||||
if no_fsdp_return:
|
||||
state_dict.pop(fqn)
|
||||
continue
|
||||
assert fqn in state_dict, (
|
||||
f"FSDP assumes {fqn} is in the state_dict but the state_dict only "
|
||||
f"has {state_dict.keys()}. "
|
||||
f"prefix={prefix}, module_name={module_name}, "
|
||||
f"param_name={param_name} rank={fsdp_state.rank}."
|
||||
)
|
||||
if fqn not in state_dict:
|
||||
raise AssertionError(
|
||||
f"FSDP assumes {fqn} is in the state_dict but the state_dict only "
|
||||
f"has {state_dict.keys()}. "
|
||||
f"prefix={prefix}, module_name={module_name}, "
|
||||
f"param_name={param_name} rank={fsdp_state.rank}."
|
||||
)
|
||||
|
||||
param_hook(state_dict, prefix, fqn)
|
||||
|
||||
@ -410,7 +412,8 @@ def _local_post_state_dict_hook(
|
||||
# value as the flat_param but it is a pure Tensor because
|
||||
# nn.Module.state_dict() will detach the parameter. Therefore, we need
|
||||
# to get flat_param to get the metadata.
|
||||
assert _module_handle(fsdp_state, module), "Should have returned early"
|
||||
if not _module_handle(fsdp_state, module):
|
||||
raise AssertionError("Should have returned early")
|
||||
flat_param = _module_handle(fsdp_state, module).flat_param
|
||||
# Constructs a ShardedTensor from the flat_param "without" padding.
|
||||
# Removing the padding allows users to change the number of ranks
|
||||
@ -460,32 +463,37 @@ def _local_pre_load_state_dict_hook(
|
||||
_replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_PREFIX}")
|
||||
fqn = f"{prefix}{FSDP_PREFIX}{FLAT_PARAM}"
|
||||
if fqn not in state_dict:
|
||||
assert not _has_fsdp_params(fsdp_state, module), (
|
||||
"No `FlatParameter` in `state_dict` for this FSDP instance "
|
||||
"but it has parameters"
|
||||
)
|
||||
if _has_fsdp_params(fsdp_state, module):
|
||||
raise AssertionError(
|
||||
"No `FlatParameter` in `state_dict` for this FSDP instance "
|
||||
"but it has parameters"
|
||||
)
|
||||
return
|
||||
load_tensor = state_dict[fqn]
|
||||
assert isinstance(load_tensor, ShardedTensor), (
|
||||
"Tensors in local_state_dict should be ShardedTensor."
|
||||
)
|
||||
if not isinstance(load_tensor, ShardedTensor):
|
||||
raise AssertionError("Tensors in local_state_dict should be ShardedTensor.")
|
||||
|
||||
# Convert the ShardedTensor to a Tensor.
|
||||
flat_param = _module_handle(fsdp_state, module).flat_param
|
||||
assert flat_param is not None
|
||||
if flat_param is None:
|
||||
raise AssertionError("Expected flat_param to be set")
|
||||
valid_data_size = flat_param.numel() - flat_param._shard_numel_padded
|
||||
shards = load_tensor.local_shards()
|
||||
if valid_data_size > 0:
|
||||
assert len(shards), "load_local_state_dict assume one shard per ShardedTensor."
|
||||
if not len(shards):
|
||||
raise AssertionError(
|
||||
"load_local_state_dict assume one shard per ShardedTensor."
|
||||
)
|
||||
load_tensor = shards[0].tensor
|
||||
|
||||
# Get the metadata of the flat_param to decide whether to pad the loaded
|
||||
# tensor.
|
||||
if flat_param._shard_numel_padded > 0:
|
||||
assert load_tensor.numel() < flat_param.numel(), (
|
||||
f"Local shard size = {flat_param.numel()} and the tensor in "
|
||||
f"the state_dict is {load_tensor.numel()}."
|
||||
)
|
||||
if load_tensor.numel() >= flat_param.numel():
|
||||
raise AssertionError(
|
||||
f"Local shard size = {flat_param.numel()} and the tensor in "
|
||||
f"the state_dict is {load_tensor.numel()}."
|
||||
)
|
||||
load_tensor = F.pad(load_tensor, [0, flat_param._shard_numel_padded])
|
||||
else:
|
||||
load_tensor = flat_param
|
||||
@ -618,10 +626,11 @@ def _sharded_pre_load_state_dict_hook(
|
||||
param, fsdp_state._fsdp_extension
|
||||
)
|
||||
|
||||
assert len(shards) < 2, (
|
||||
"Expects 0 or 1 shard per rank "
|
||||
f"but got {len(shards)} shards on rank {fsdp_state.rank}."
|
||||
)
|
||||
if len(shards) >= 2:
|
||||
raise AssertionError(
|
||||
"Expects 0 or 1 shard per rank "
|
||||
f"but got {len(shards)} shards on rank {fsdp_state.rank}."
|
||||
)
|
||||
param_numel = param.size().numel()
|
||||
dim_0_size = param.size()[0]
|
||||
chunk_size = (
|
||||
|
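The _enter_unshard_params_ctx docstring earlier in this file notes that the unshard context has to be entered in a pre-hook and left in a separate post-hook. A generic sketch of that split-context pattern using contextlib (the hooks and the unshard_params manager here are illustrative, not FSDP internals):

from contextlib import ExitStack, contextmanager

@contextmanager
def unshard_params():
    print("unshard")  # stand-in for all-gathering parameters
    try:
        yield
    finally:
        print("reshard")  # stand-in for freeing the unsharded storage

_open_ctxs: dict[int, ExitStack] = {}

def pre_hook(module_id: int) -> None:
    stack = ExitStack()
    stack.enter_context(unshard_params())  # enter now, exit later
    _open_ctxs[module_id] = stack

def post_hook(module_id: int) -> None:
    _open_ctxs.pop(module_id).close()  # runs the manager's cleanup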
@ -144,9 +144,10 @@ class _ExecOrderTracer:
|
||||
named_params = list(module.named_parameters())
|
||||
curr_module = exec_info.curr_module
|
||||
if named_params:
|
||||
assert curr_module in exec_info.module_to_param_usage_infos, (
|
||||
"The current module should have already been processed by a patched `call_module`"
|
||||
)
|
||||
if curr_module not in exec_info.module_to_param_usage_infos:
|
||||
raise AssertionError(
|
||||
"The current module should have already been processed by a patched `call_module`"
|
||||
)
|
||||
exec_info.module_to_param_usage_infos[exec_info.curr_module].append(
|
||||
_ParamUsageInfo(module, named_params)
|
||||
)
|
||||
|
@ -66,7 +66,8 @@ def _writeback_to_local_shard(
|
||||
if writeback_grad:
|
||||
existing_grad = handle.sharded_grad
|
||||
if existing_grad is not None:
|
||||
assert handle.flat_param.grad is not None
|
||||
if handle.flat_param.grad is None:
|
||||
raise AssertionError("Expected handle.flat_param.grad to not be None")
|
||||
grad_shard = _get_shard(handle.flat_param.grad)
|
||||
existing_grad[: grad_shard.numel()].copy_(grad_shard)
|
||||
|
||||
@ -185,9 +186,10 @@ def _unshard_fsdp_state_params(
|
||||
yield
|
||||
return
|
||||
|
||||
assert handle._training_state == HandleTrainingState.IDLE, (
|
||||
f"Expects the handle training to be IDLE but got {handle._training_state}"
|
||||
)
|
||||
if handle._training_state != HandleTrainingState.IDLE:
|
||||
raise AssertionError(
|
||||
f"Expects the handle training to be IDLE but got {handle._training_state}"
|
||||
)
|
||||
|
||||
handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS
|
||||
|
||||
|
@ -718,24 +718,29 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
|
||||
if prev_state_dict_type is None:
|
||||
prev_state_dict_type = submodule._state_dict_type
|
||||
else:
|
||||
assert prev_state_dict_type == submodule._state_dict_type, (
|
||||
"All FSDP modules should have the same state_dict_type."
|
||||
)
|
||||
if prev_state_dict_type != submodule._state_dict_type:
|
||||
raise AssertionError(
|
||||
"All FSDP modules should have the same state_dict_type."
|
||||
)
|
||||
if prev_state_dict_config is None:
|
||||
prev_state_dict_config = submodule._state_dict_config
|
||||
else:
|
||||
assert isinstance(
|
||||
if not isinstance(
|
||||
submodule._state_dict_config, type(prev_state_dict_config)
|
||||
), "All FSDP modules must have the same type of state_dict_config."
|
||||
):
|
||||
raise AssertionError(
|
||||
"All FSDP modules must have the same type of state_dict_config."
|
||||
)
|
||||
if prev_optim_state_dict_config is None:
|
||||
prev_optim_state_dict_config = submodule._optim_state_dict_config
|
||||
else:
|
||||
assert isinstance(
|
||||
if not isinstance(
|
||||
submodule._optim_state_dict_config,
|
||||
type(prev_optim_state_dict_config),
|
||||
), (
|
||||
"All FSDP modules must have the same type of optim_state_dict_config."
|
||||
)
|
||||
):
|
||||
raise AssertionError(
|
||||
"All FSDP modules must have the same type of optim_state_dict_config."
|
||||
)
|
||||
|
||||
submodule._state_dict_type = state_dict_type
|
||||
submodule._state_dict_config = state_dict_config
|
||||
@ -774,10 +779,11 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
|
||||
submodule._state_dict_config,
|
||||
submodule._optim_state_dict_config,
|
||||
)
|
||||
assert state_dict_settings == submodule_settings, (
|
||||
"All FSDP modules must have the same state dict settings."
|
||||
f"Got {submodule_settings} and {state_dict_settings}."
|
||||
)
|
||||
if state_dict_settings != submodule_settings:
|
||||
raise AssertionError(
|
||||
"All FSDP modules must have the same state dict settings."
|
||||
f"Got {submodule_settings} and {state_dict_settings}."
|
||||
)
|
||||
_set_optim_use_dtensor(submodule, submodule_settings)
|
||||
return state_dict_settings
|
||||
|
||||
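The two hunks above require every FSDP submodule to carry identical state_dict_type and config classes; the public context manager is what establishes that invariant. A hedged usage sketch (assumes model is already FSDP-wrapped and the process group is initialized):

from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardedStateDictConfig,
    StateDictType,
)

with FSDP.state_dict_type(
    model,
    StateDictType.SHARDED_STATE_DICT,
    ShardedStateDictConfig(offload_to_cpu=True),
):
    sharded_sd = model.state_dict()  # one shard per rank, optionally on CPU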
@ -1054,10 +1060,11 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
|
||||
yield
|
||||
finally:
|
||||
for m, old_flag in old_flags:
|
||||
assert not m._sync_gradients, (
|
||||
"`_sync_gradients` was incorrectly set to "
|
||||
"`True` while in the `no_sync()` context manager"
|
||||
)
|
||||
if m._sync_gradients:
|
||||
raise AssertionError(
|
||||
"`_sync_gradients` was incorrectly set to "
|
||||
"`True` while in the `no_sync()` context manager"
|
||||
)
|
||||
m._sync_gradients = old_flag
|
||||
|
||||
@torch.no_grad()
|
||||
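The no_sync hunk above restores each module's _sync_gradients flag when the context exits. For reference, the gradient-accumulation loop this context is meant for looks roughly like the following (model, optimizer, and microbatches are assumed to exist):

from contextlib import nullcontext

for i, batch in enumerate(microbatches):
    # Skip gradient all-reduce for every microbatch except the last one.
    ctx = model.no_sync() if i < len(microbatches) - 1 else nullcontext()
    with ctx:
        loss = model(batch).sum()
        loss.backward()
optimizer.step()
optimizer.zero_grad()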
@ -1275,15 +1282,22 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
|
||||
)
|
||||
else:
|
||||
using_optim_input = False
|
||||
assert optim_input is None and not rank0_only
|
||||
if optim_input is not None or rank0_only:
|
||||
raise AssertionError(
|
||||
f"Expected optim_input to be None and rank0_only to be False, "
|
||||
f"got optim_input={optim_input}, rank0_only={rank0_only}"
|
||||
)
|
||||
|
||||
use_orig_params = FullyShardedDataParallel.fsdp_modules(model)[
|
||||
0
|
||||
]._use_orig_params
|
||||
assert all(
|
||||
if not all(
|
||||
use_orig_params == m._use_orig_params
|
||||
for m in FullyShardedDataParallel.fsdp_modules(model)
|
||||
), "Not all FSDP modules have the same _use_orig_params value"
|
||||
):
|
||||
raise AssertionError(
|
||||
"Not all FSDP modules have the same _use_orig_params value"
|
||||
)
|
||||
|
||||
return _optim_state_dict(
|
||||
model=model,
|
||||
@ -1329,15 +1343,22 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
|
||||
)
|
||||
else:
|
||||
using_optim_input = False
|
||||
assert optim_input is None and not rank0_only
|
||||
if optim_input is not None or rank0_only:
|
||||
raise AssertionError(
|
||||
f"Expected optim_input to be None and rank0_only to be False, "
|
||||
f"got optim_input={optim_input}, rank0_only={rank0_only}"
|
||||
)
|
||||
|
||||
use_orig_params = FullyShardedDataParallel.fsdp_modules(model)[
|
||||
0
|
||||
]._use_orig_params
|
||||
assert all(
|
||||
if not all(
|
||||
use_orig_params == m._use_orig_params
|
||||
for m in FullyShardedDataParallel.fsdp_modules(model)
|
||||
), "Not all FSDP modules have the same _use_orig_params value"
|
||||
):
|
||||
raise AssertionError(
|
||||
"Not all FSDP modules have the same _use_orig_params value"
|
||||
)
|
||||
|
||||
if rank0_only and dist.get_rank(group) > 0:
|
||||
optim_state_dict = {}
|
||||
@ -1719,10 +1740,13 @@ class FullyShardedDataParallel(nn.Module, _FSDPState):
|
||||
optim_input,
|
||||
optim,
|
||||
)
|
||||
assert optim_state_key_type in (
|
||||
if optim_state_key_type not in (
|
||||
OptimStateKeyType.PARAM_NAME,
|
||||
OptimStateKeyType.PARAM_ID,
|
||||
)
|
||||
):
|
||||
raise AssertionError(
|
||||
f"Expected optim_state_key_type to be PARAM_NAME or PARAM_ID, got {optim_state_key_type}"
|
||||
)
|
||||
osd = optim_state_dict # alias
|
||||
# Validate that the existing parameter keys are uniformly typed
|
||||
uses_param_name_mask = [type(param_key) is str for param_key in osd["state"]]
|
||||
@ -2150,9 +2174,10 @@ def _get_param_to_fqn(
|
||||
"""
|
||||
param_to_param_names = _get_param_to_fqns(model)
|
||||
for param_names in param_to_param_names.values():
|
||||
assert len(param_names) > 0, (
|
||||
"`_get_param_to_fqns()` should not construct empty lists"
|
||||
)
|
||||
if len(param_names) == 0:
|
||||
raise AssertionError(
|
||||
"`_get_param_to_fqns()` should not construct empty lists"
|
||||
)
|
||||
if len(param_names) > 1:
|
||||
raise RuntimeError(
|
||||
"Each parameter should only map to one parameter name but got "
|
||||
|
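The optimizer state-dict hunks above back the checkpoint path exposed as FSDP.optim_state_dict. A hedged save/load sketch (assumes an FSDP-wrapped model and its optimizer; argument order follows recent releases):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Save: gather/convert optimizer state keyed by parameter FQNs.
osd = FSDP.optim_state_dict(model, optim)
# Load: re-shard/re-key a saved state dict for the current wrapping, then load it.
flattened = FSDP.optim_state_dict_to_load(model, optim, osd)
optim.load_state_dict(flattened)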
@ -35,7 +35,10 @@ class _GeneralMultiDeviceReplicator(_MultiDeviceReplicator):
|
||||
"""
|
||||
|
||||
def __init__(self, master_tensor: torch.Tensor) -> None:
|
||||
assert _is_supported_device(master_tensor)
|
||||
if not _is_supported_device(master_tensor):
|
||||
raise AssertionError(
|
||||
f"Expected supported device, got {master_tensor.device}"
|
||||
)
|
||||
self.master = master_tensor
|
||||
self._per_device_tensors: dict[torch.device, torch.Tensor] = {}
|
||||
|
||||
@ -130,10 +133,12 @@ class ShardedGradScaler(GradScaler):
|
||||
return outputs
|
||||
|
||||
if isinstance(outputs, torch.Tensor):
|
||||
assert _is_supported_device(outputs)
|
||||
if not _is_supported_device(outputs):
|
||||
raise AssertionError(f"Expected supported device, got {outputs.device}")
|
||||
if self._scale is None:
|
||||
self._lazy_init_scale_growth_tracker(outputs.device)
|
||||
assert self._scale is not None
|
||||
if self._scale is None:
|
||||
raise AssertionError("Expected _scale to be initialized, got None")
|
||||
scaled_output = outputs * self._scale.to(
|
||||
device=outputs.device, non_blocking=True
|
||||
)
|
||||
@ -146,11 +151,15 @@ class ShardedGradScaler(GradScaler):
|
||||
|
||||
def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]):
|
||||
if isinstance(val, torch.Tensor):
|
||||
assert _is_supported_device(val)
|
||||
if not _is_supported_device(val):
|
||||
raise AssertionError(f"Expected supported device, got {val.device}")
|
||||
if len(stash) == 0:
|
||||
if self._scale is None:
|
||||
self._lazy_init_scale_growth_tracker(val.device)
|
||||
assert self._scale is not None
|
||||
if self._scale is None:
|
||||
raise AssertionError(
|
||||
"Expected _scale to be initialized, got None"
|
||||
)
|
||||
stash.append(_GeneralMultiDeviceReplicator(self._scale))
|
||||
scaled_val = val * stash[0].get(val.device)
|
||||
# Here we ensure the return dtype is the same as the outputs dtype.
|
||||
@ -218,7 +227,8 @@ class ShardedGradScaler(GradScaler):
|
||||
# ranks may have no (non-zero sized) parameter shards, necessitating the
|
||||
# initialization of `per_device_found_inf._per_device_tensors` here
|
||||
if not per_device_found_inf._per_device_tensors:
|
||||
assert self._scale is not None
|
||||
if self._scale is None:
|
||||
raise AssertionError("Expected _scale to be initialized, got None")
|
||||
per_device_found_inf.get(self._scale.device)
|
||||
return per_device_found_inf._per_device_tensors
|
||||
|
||||
@ -238,7 +248,8 @@ class ShardedGradScaler(GradScaler):
|
||||
raise RuntimeError("unscale_() is being called after step().")
|
||||
|
||||
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
|
||||
assert self._scale is not None
|
||||
if self._scale is None:
|
||||
raise AssertionError("Expected _scale to be initialized, got None")
|
||||
inv_scale = self._scale.double().reciprocal().float()
|
||||
found_inf = torch.full(
|
||||
(1,), 0.0, dtype=torch.float32, device=self._scale.device
|
||||
@ -279,7 +290,10 @@ class ShardedGradScaler(GradScaler):
|
||||
If found_inf is 1.0 (True), then scale is multiplied by backoff_factor and growth_tracker is set to zero.
|
||||
Otherwise, scale is multiplied by the growth factor when the growth interval is reached.
|
||||
"""
|
||||
assert self._scale is not None and self._growth_tracker is not None
|
||||
if self._scale is None or self._growth_tracker is None:
|
||||
raise AssertionError(
|
||||
"Expected _scale and _growth_tracker to be initialized, got None"
|
||||
)
|
||||
|
||||
if found_inf.item() >= 1.0:
|
||||
self._scale *= self._backoff_factor
|
||||
@ -323,9 +337,12 @@ class ShardedGradScaler(GradScaler):
|
||||
"new_scale should be a float or a 1-element torch.cuda.FloatTensor or "
|
||||
"torch.FloatTensor with requires_grad=False."
|
||||
)
|
||||
assert new_scale.device.type == self._device, reason
|
||||
assert new_scale.numel() == 1, reason
|
||||
assert new_scale.requires_grad is False, reason
|
||||
if new_scale.device.type != self._device:
|
||||
raise AssertionError(reason)
|
||||
if new_scale.numel() != 1:
|
||||
raise AssertionError(reason)
|
||||
if new_scale.requires_grad is not False:
|
||||
raise AssertionError(reason)
|
||||
self._scale.copy_(new_scale) # type: ignore[union-attr]
|
||||
else:
|
||||
# Consume shared inf/nan data collected from optimizers to update the scale.
|
||||
@ -336,7 +353,8 @@ class ShardedGradScaler(GradScaler):
|
||||
for found_inf in state["found_inf_per_device"].values()
|
||||
]
|
||||
|
||||
assert len(found_infs) > 0, "No inf checks were recorded prior to update."
|
||||
if len(found_infs) == 0:
|
||||
raise AssertionError("No inf checks were recorded prior to update.")
|
||||
|
||||
found_inf_combined = found_infs[0]
|
||||
if len(found_infs) > 1:
|
||||
|
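The update() docstring above spells out the scale schedule: back off when inf/nan gradients are found, grow after growth_interval clean steps. ShardedGradScaler is driven like the standard GradScaler; a minimal loop sketch (model, optim, and loader assumed to exist):

from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

scaler = ShardedGradScaler()
for batch in loader:
    optim.zero_grad()
    loss = model(batch).sum()
    scaler.scale(loss).backward()  # scale up to limit fp16 underflow
    scaler.step(optim)             # unscales; skips the step on inf/nan
    scaler.update()                # applies the backoff/growth rule above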
@ -53,17 +53,20 @@ def _post_order_apply(
|
||||
_post_order_apply_inner(child_module, child_module_name, module)
|
||||
optional_module = fn(module)
|
||||
if optional_module is not None:
|
||||
assert isinstance(parent_module, nn.Module), (
|
||||
"Non-root modules should have their parent module set but got "
|
||||
f"{parent_module} for {module}"
|
||||
)
|
||||
assert module_name, (
|
||||
"Non-root modules should have their module name set but got "
|
||||
f"an empty module name for {module}"
|
||||
)
|
||||
assert isinstance(optional_module, nn.Module), (
|
||||
f"fn should return None or an nn.Module but got {optional_module}"
|
||||
)
|
||||
if not isinstance(parent_module, nn.Module):
|
||||
raise AssertionError(
|
||||
"Non-root modules should have their parent module set but got "
|
||||
f"{parent_module} for {module}"
|
||||
)
|
||||
if not module_name:
|
||||
raise AssertionError(
|
||||
"Non-root modules should have their module name set but got "
|
||||
f"an empty module name for {module}"
|
||||
)
|
||||
if not isinstance(optional_module, nn.Module):
|
||||
raise AssertionError(
|
||||
f"fn should return None or an nn.Module but got {optional_module}"
|
||||
)
|
||||
setattr(parent_module, module_name, optional_module)
|
||||
|
||||
_post_order_apply_inner(root_module, "", None)
|
||||
@ -456,7 +459,8 @@ def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
|
||||
the values provided by the :func:`enable_wrap` context
|
||||
"""
|
||||
if _ConfigAutoWrap.in_autowrap_context:
|
||||
assert _ConfigAutoWrap.wrapper_cls is not None
|
||||
if _ConfigAutoWrap.wrapper_cls is None:
|
||||
raise AssertionError("Expected _ConfigAutoWrap.wrapper_cls to be set")
|
||||
|
||||
wrap_overrides = {**_ConfigAutoWrap.kwargs, **wrap_overrides}
|
||||
return _wrap(
|
||||
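As the wrap() docstring above notes, wrap() only takes effect inside enable_wrap, which supplies the wrapper class and default kwargs; outside the context it returns the module unchanged. A short sketch (assumes torch.distributed is initialized so FSDP construction succeeds):

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import enable_wrap, wrap

with enable_wrap(wrapper_cls=FSDP):
    block = wrap(nn.Linear(1024, 1024))  # -> FSDP(nn.Linear(...))

plain = wrap(nn.Linear(16, 16))          # no active context: returned as-is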
@ -468,7 +472,8 @@ def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
|
||||
|
||||
|
||||
def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module:
|
||||
assert wrapper_cls is not None
|
||||
if wrapper_cls is None:
|
||||
raise AssertionError("Expected wrapper_cls to be set")
|
||||
if hasattr(module, "_wrap_overrides"):
|
||||
# If module has a _wrap_overrides attribute, we force overriding the
|
||||
# FSDP config with these attributes for this module. Currently this
|
||||
@ -506,14 +511,19 @@ def _recursive_wrap(
|
||||
(nn.Module, int):
|
||||
``module`` after wrapping and the numel recursively wrapped.
|
||||
"""
|
||||
assert auto_wrap_policy is not None, "Must specify auto_wrap_policy."
|
||||
assert wrapper_cls is not None, "Must specify wrapper_cls"
|
||||
if auto_wrap_policy is None:
|
||||
raise AssertionError("Must specify auto_wrap_policy.")
|
||||
if wrapper_cls is None:
|
||||
raise AssertionError("Must specify wrapper_cls")
|
||||
# Make sure no child is already wrapped.
|
||||
for _, child in module.named_modules():
|
||||
if child in ignored_modules:
|
||||
continue
|
||||
try:
|
||||
assert not isinstance(child, cast(type, wrapper_cls))
|
||||
if isinstance(child, cast(type, wrapper_cls)):
|
||||
raise AssertionError(
|
||||
f"Child module {child} is already wrapped by {wrapper_cls}"
|
||||
)
|
||||
except TypeError:
|
||||
# wrapper_cls is a function as opposed to a class type, just bypass above check.
|
||||
pass
|
||||
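The _recursive_wrap checks above insist on an auto_wrap_policy and a wrapper_cls; at the user level both usually arrive through the FSDP constructor. A hedged example with the size-based policy (the parameter-count threshold is arbitrary; assumes a distributed process group and an existing model):

import functools
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

policy = functools.partial(size_based_auto_wrap_policy, min_num_params=1_000_000)
sharded_model = FSDP(model, auto_wrap_policy=policy)  # large submodules wrapped separately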
@ -523,7 +533,8 @@ def _recursive_wrap(
|
||||
p.numel() for p in module.parameters() if p not in ignored_params
|
||||
)
|
||||
|
||||
assert auto_wrap_policy is not None
|
||||
if auto_wrap_policy is None:
|
||||
raise AssertionError("Expected auto_wrap_policy to be set")
|
||||
if auto_wrap_policy(module=module, recurse=True, nonwrapped_numel=nonwrapped_numel):
|
||||
total_wrapped_numel = 0
|
||||
# Iterate through the children, recursively wrap if necessary
|
||||
@ -575,9 +586,10 @@ class _ConfigAutoWrap:
|
||||
)
|
||||
_ConfigAutoWrap.in_autowrap_context = True
|
||||
# Get and save the wrapper cls for the context.
|
||||
assert "wrapper_cls" in kwargs.keys(), (
|
||||
"Expected to pass in wrapper_cls arg into _ConfigAutoWrap."
|
||||
)
|
||||
if "wrapper_cls" not in kwargs.keys():
|
||||
raise AssertionError(
|
||||
"Expected to pass in wrapper_cls arg into _ConfigAutoWrap."
|
||||
)
|
||||
_ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"])
|
||||
del kwargs["wrapper_cls"]
|
||||
# Save the rest.
|
||||
|