[distributed] Replace assert statements in distributed checkpoint with explicit checks (#165256)

Fixes partially #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165256
Approved by: https://github.com/albanD
This commit is contained in:
Rohit Singh Rathaur
2025-10-17 20:14:32 +00:00
committed by PyTorch MergeBot
parent 75e2a9fae3
commit 2bcd892c86
22 changed files with 218 additions and 129 deletions

View File

@ -109,7 +109,8 @@ class _AsyncCheckpointProcess:
# Wait for the checkpoint background process to initialize.
# Using default GLOO init timeout.
response = self._wait_for_response(timeout=1800)
assert response == _CheckpointSaveProcessControlOpts.INIT_COMPLETE
if not response == _CheckpointSaveProcessControlOpts.INIT_COMPLETE:
raise AssertionError(f"Expected INIT_COMPLETE response, got {response}")
def __del__(self) -> None:
if self._save_process.is_alive():
@ -175,7 +176,8 @@ class _AsyncCheckpointProcess:
)
self._send(async_cp_request)
result = self._wait_for_response()
assert isinstance(result, Metadata)
if not isinstance(result, Metadata):
raise AssertionError(f"Expected Metadata response, got {type(result)}")
return result
@staticmethod
@ -245,7 +247,10 @@ class _AsyncCheckpointProcess:
):
logger.info("Terminating the checkpoint background process.")
return
assert isinstance(obj, _AsyncCheckpointRequest)
if not isinstance(obj, _AsyncCheckpointRequest):
raise AssertionError(
f"Expected _AsyncCheckpointRequest, got {type(obj)}"
)
logger.info(
f"Received async checkpoint request with id={obj.checkpoint_request_id.checkpoint_id}" # noqa: G004
)
@ -296,7 +301,10 @@ class _ProcessBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor):
) -> Metadata:
global _CHECKPOINT_PROCESS
if _CHECKPOINT_PROCESS is None:
assert pg_init_info is not None
if pg_init_info is None:
raise AssertionError(
"pg_init_info must not be None when _CHECKPOINT_PROCESS is None"
)
ckpt_kwargs = {}
if (ckpt_id := getattr(storage_writer, "checkpoint_id", None)) is not None:
ckpt_kwargs["checkpoint_id"] = ckpt_id
@ -310,7 +318,10 @@ class _ProcessBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor):
create_checkpoint_daemon_process()
assert _CHECKPOINT_PROCESS is not None
if _CHECKPOINT_PROCESS is None:
raise AssertionError(
"_CHECKPOINT_PROCESS must not be None after initialization"
)
staged_state_dict = (
staging_future_or_state_dict.result()
if isinstance(staging_future_or_state_dict, Future)

View File

@ -89,7 +89,8 @@ class _Checkpointer:
process_group=self.process_group,
planner=self.save_planner,
)
assert isinstance(response, Future)
if not isinstance(response, Future):
raise AssertionError("response should be a Future instance")
return response
def load(self, state_dict: dict[str, Any]) -> None:

View File

@ -54,7 +54,8 @@ def dedup_save_plans(
for plan_idx in plan_indices - {select_plan_idx}:
plan_to_item_indices[plan_idx].discard(write_item_idx)
# Sanity check
assert len(all_plans) == len(plan_to_item_indices)
if len(all_plans) != len(plan_to_item_indices):
raise AssertionError("len(all_plans) != len(plan_to_item_indices)")
# Create new plans with the updated write items post deduplication
return [
dataclasses.replace(

View File

@ -150,9 +150,8 @@ class DistBarrier(Barrier):
Raises:
AssertionError: If the distributed process group is not initialized.
"""
assert dist.is_initialized(), (
"DistBarrier requires an initialized process group."
)
if not dist.is_initialized():
raise AssertionError("DistBarrier requires an initialized process group.")
def execute_barrier(self) -> None:
"""

View File

@ -135,7 +135,8 @@ class CheckpointProcess:
)
# wait for the timeout or a response from subprocess
assert self._parent_end is not None, "Parent end of pipe should be initialized"
if self._parent_end is None:
raise AssertionError("Parent end of pipe should be initialized")
if not self._parent_end.poll(timeout=config.subprocess_init_timeout_secs):
msg = f"Timed out after {config.subprocess_init_timeout_secs}s waiting for checkpoint subprocess to initialize"
logger.error(msg)
@ -161,7 +162,8 @@ class CheckpointProcess:
os.getpid(),
)
assert sub_rank == 0, "We need only one checkpointer per parent training"
if sub_rank != 0:
raise AssertionError("We need only one checkpointer per parent training")
request = WorkerRequest(request_type=RequestType.PING, payload={})
try:
@ -226,9 +228,8 @@ class CheckpointProcess:
def _send(self, request_type: RequestType, payload: dict[str, Any]) -> None:
try:
assert self._parent_end is not None, (
"Parent end of pipe should be initialized"
)
if self._parent_end is None:
raise AssertionError("Parent end of pipe should be initialized")
self._parent_end.send(
WorkerRequest(
request_type=request_type,
@ -244,9 +245,8 @@ class CheckpointProcess:
def _recv(self) -> Optional[dict[str, Any]]:
try:
assert self._parent_end is not None, (
"Parent end of pipe should be initialized"
)
if self._parent_end is None:
raise AssertionError("Parent end of pipe should be initialized")
response = self._parent_end.recv()
if response.success is False:
error_msg = (

View File

@ -134,11 +134,12 @@ class CheckpointReader:
tensor_offset = source.untyped_storage()._checkpoint_offset
assert tensor_offset is not None, (
"checkpoint_offset for tensor in torch serialized file is not set. This could"
"happen if the checkpoint was saved with a older version of Pytorch."
"Please make sure that the checkpoint was saved with Pytorch 2.7 or later."
)
if tensor_offset is None:
raise AssertionError(
"checkpoint_offset for tensor in torch serialized file is not set. This could "
"happen if the checkpoint was saved with a older version of Pytorch. "
"Please make sure that the checkpoint was saved with Pytorch 2.7 or later."
)
tensor_len = source.nelement() * source.element_size()
file.seek(

View File

@ -158,9 +158,10 @@ class DefaultStager(CheckpointStager):
self._staging_stream = torch.Stream()
if self._config.use_non_blocking_copy:
assert torch.accelerator.is_available(), (
"Non-blocking copy requires that the current accelerator is available."
)
if not torch.accelerator.is_available():
raise AssertionError(
"Non-blocking copy requires that the current accelerator is available."
)
def stage(
self,
@ -168,9 +169,10 @@ class DefaultStager(CheckpointStager):
**kwargs: Any,
) -> Union[STATE_DICT, Future[STATE_DICT]]:
if self._config.use_async_staging:
assert self._staging_executor is not None, (
"Staging executor should be initialized for async staging"
)
if self._staging_executor is None:
raise AssertionError(
"Staging executor should be initialized for async staging"
)
return self._staging_executor.submit(
self._stage,
state_dict,
@ -185,9 +187,10 @@ class DefaultStager(CheckpointStager):
)
if self._config.use_non_blocking_copy:
assert self._staging_stream or not self._config.use_async_staging, (
"Non-blocking copy in a background thread for async staging needs staging_stream to be initialized."
)
if not (self._staging_stream or not self._config.use_async_staging):
raise AssertionError(
"Non-blocking copy in a background thread for async staging needs staging_stream to be initialized."
)
# waits for the enqued copy operations to finish.
self._staging_stream.synchronize() if self._staging_stream else torch.accelerator.synchronize()

View File

@ -37,7 +37,8 @@ class FileSystem(FileSystemBase):
def create_stream(
self, path: Union[str, os.PathLike], mode: str
) -> Generator[io.IOBase, None, None]:
assert self.fs is not None
if self.fs is None:
raise AssertionError("fs should not be None")
path = os.fspath(path)
# fsspec does not support concurrent transactions, and not all

View File

@ -193,12 +193,12 @@ def _cast_tensor(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
caveat that the cast tensor may be larger than the original tensor due to
the differences in striding.
"""
assert type(tensor) is torch.Tensor, (
f"can only cast standard tensors not {type(tensor)}"
)
if type(tensor) is not torch.Tensor:
raise AssertionError(f"can only cast standard tensors not {type(tensor)}")
storage = tensor.untyped_storage()
ret = torch.tensor(storage, dtype=dtype, device=tensor.device)
assert ret.untyped_storage() is storage, "storage should be the same"
if ret.untyped_storage() is not storage:
raise AssertionError("storage should be the same")
return ret
@ -317,9 +317,8 @@ class PGTransport:
if isinstance(inplace, DTensor):
inplace = inplace._local_tensor
t = _cast_tensor(inplace, torch.uint8)
assert t.nbytes == v.nbytes, (
"inplace tensor storage must be the same size"
)
if t.nbytes != v.nbytes:
raise AssertionError("inplace tensor storage must be the same size")
else:
t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device)

View File

@ -123,12 +123,13 @@ class StateDictStager:
# Check if we've already cached this storage
if storage in self._cached_storage_mapping:
cached_storage = self._cached_storage_mapping[storage]
assert cached_storage.size() == storage.size(), (
"For async checkpointing, We cache storages in DRAM and reuse them."
"Cached storage size does not match original storage size."
"This should never happen as we track the original storage weakref "
"and clean up the cache storage. Please report this to PyTorch Distributed Checkpointing."
)
if cached_storage.size() != storage.size():
raise AssertionError(
"For async checkpointing, We cache storages in DRAM and reuse them. "
"Cached storage size does not match original storage size. "
"This should never happen as we track the original storage weakref "
"and clean up the cache storage. Please report this to PyTorch Distributed Checkpointing."
)
# Reuse cached storage but update with new data
cached_storage.copy_(storage, non_blocking=non_blocking)
return cached_storage

View File

@ -313,7 +313,8 @@ class DefaultLoadPlanner(LoadPlanner):
self.is_coordinator = is_coordinator
def create_local_plan(self) -> LoadPlan:
assert self.metadata is not None
if self.metadata is None:
raise AssertionError("self.metadata is not None")
if self.flatten_state_dict:
# To support checkpoints that are saved before v2.4, we have to
# differentiate if the missing keys are due to old checkpoints.
@ -432,8 +433,10 @@ class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
metadata: Optional[Metadata] = None,
is_coordinator: bool = False,
) -> None:
assert not state_dict
assert metadata is not None
if state_dict:
raise AssertionError("not state_dict")
if metadata is None:
raise AssertionError("metadata is not None")
# rebuild the state dict from the metadata
for k, v in metadata.state_dict_metadata.items():
@ -549,13 +552,15 @@ def create_default_global_save_plan(
new_items = []
for item in plan.items:
if item.type != WriteItemType.SHARD:
assert item.index.fqn not in md
if item.index.fqn in md:
raise AssertionError("item.index.fqn not in md")
if item.type == WriteItemType.BYTE_IO:
md[item.index.fqn] = BytesStorageMetadata()
new_items.append(item)
else:
assert item.tensor_data is not None
if item.tensor_data is None:
raise AssertionError("item.tensor_data is not None")
tensor_md = cast(
TensorStorageMetadata,
md.setdefault(
@ -575,10 +580,11 @@ def create_default_global_save_plan(
new_item = dataclasses.replace(item, index=new_index)
new_items.append(new_item)
assert item.tensor_data.chunk is not None, f"""
if item.tensor_data.chunk is None:
raise AssertionError(f"""
Cannot create MD for tensor without bounds.
FQN: {item.index.fqn}
"""
""")
tensor_md.chunks.append(item.tensor_data.chunk)
new_plans.append(dataclasses.replace(plan, items=new_items))
return (new_plans, Metadata(md))

View File

@ -109,7 +109,8 @@ def run(rank, world_size):
if epoch % SAVE_PERIOD == 0:
if f is not None:
assert isinstance(f, Future)
if not isinstance(f, Future):
raise AssertionError("f should be a Future instance")
f.result()
f = dcp.state_dict_saver.async_save(
state_dict, checkpoint_id=CHECKPOINT_DIR
@ -126,7 +127,8 @@ def run(rank, world_size):
_print("Reloading model from last checkpoint!")
if f is not None:
assert isinstance(f, Future)
if not isinstance(f, Future):
raise AssertionError("f should be a Future instance") from None
f.result()
dcp.load(state_dict)

View File

@ -201,7 +201,8 @@ class _OverlappingCpuLoader(_TensorLoader):
self.in_flight_data += tensor.numel() * tensor.element_size()
def _finish(self) -> Iterable[tuple[torch.Tensor, object]]:
assert self._done
if not self._done:
raise AssertionError("_finish called before all items were processed")
if len(self.current_items) > 0:
self.stream.synchronize()
return self.current_items
@ -281,7 +282,8 @@ class _StorageWriterTransforms:
def _item_size(item: WriteItem) -> int:
size = 1
assert item.tensor_data is not None
if item.tensor_data is None:
raise AssertionError("WriteItem tensor_data must not be None")
# can't use math.prod as PT needs to support older python
for s in item.tensor_data.size:
size *= s
@ -329,11 +331,16 @@ def _write_item(
)
if write_item.type == WriteItemType.BYTE_IO:
assert isinstance(data, io.BytesIO)
if not isinstance(data, io.BytesIO):
raise AssertionError("Data must be io.BytesIO for BYTE_IO write items")
transform_to.write(data.getbuffer())
else:
assert isinstance(data, torch.Tensor)
assert data.device == torch.device("cpu")
if not isinstance(data, torch.Tensor):
raise AssertionError(
"Data must be torch.Tensor for non-BYTE_IO write items"
)
if data.device != torch.device("cpu"):
raise AssertionError("Tensor must be on CPU device")
if serialization_format == SerializationFormat.TORCH_SAVE:
torch.save(data, transform_to)
@ -428,7 +435,8 @@ def _write_files_from_queue(
tensor_dict = {}
metadata_dict = {}
for tensor, write_item in loader.values():
assert tensor.is_cpu
if not tensor.is_cpu:
raise AssertionError("Tensor must be on CPU")
write_results.append(
_write_item(
transforms,
@ -903,9 +911,10 @@ class FileSystemReader(StorageReader):
)
target_tensor = planner.resolve_tensor(req).detach()
assert target_tensor.size() == tensor.size(), (
f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
)
if target_tensor.size() != tensor.size():
raise AssertionError(
f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
)
target_tensor.copy_(tensor)
planner.commit_tensor(req, target_tensor)
@ -936,7 +945,8 @@ class FileSystemReader(StorageReader):
self.storage_data = metadata.storage_data
self.rank = kwargs.get("rank")
self.use_collectives = kwargs.get("use_collectives", True)
assert self.storage_data is not None
if self.storage_data is None:
raise AssertionError("storage_data must not be None in metadata")
def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
return plan

View File

@ -84,7 +84,8 @@ class BroadcastingTorchSaveReader(StorageReader):
# the entire checkpoint on each rank, hopefully preventing OOM issues
# TODO: read on each host, instead of only the coordinator
if self.is_coordinator:
assert self.checkpoint_id is not None
if self.checkpoint_id is None:
raise AssertionError("checkpoint_id must be set before reading data")
torch_state_dict = torch.load(
self.checkpoint_id, map_location="cpu", weights_only=False
)
@ -112,10 +113,11 @@ class BroadcastingTorchSaveReader(StorageReader):
tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths)
target_tensor = planner.resolve_tensor(req).detach()
assert target_tensor.size() == tensor.size(), (
f"req {req.storage_index} mismatch sizes, "
f"{target_tensor.size()} vs {tensor.size()}"
)
if not target_tensor.size() == tensor.size():
raise AssertionError(
f"req {req.storage_index} mismatch sizes, "
f"{target_tensor.size()} vs {tensor.size()}"
)
target_tensor.copy_(tensor)
planner.commit_tensor(req, target_tensor)
@ -128,9 +130,16 @@ class BroadcastingTorchSaveReader(StorageReader):
"""Implementation of the StorageReader method"""
self.is_coordinator = is_coordinator
if self.is_coordinator:
assert dist.get_rank() == self.coordinator_rank
if not dist.get_rank() == self.coordinator_rank:
raise AssertionError(
f"Coordinator rank mismatch: expected {self.coordinator_rank}, "
f"got {dist.get_rank()}"
)
assert self.checkpoint_id is not None
if self.checkpoint_id is None:
raise AssertionError(
"checkpoint_id must be set before setting up storage reader"
)
def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
"""Implementation of the StorageReader method"""

View File

@ -226,9 +226,10 @@ class HuggingFaceStorageReader(FileSystemReader):
tensor = f.get_slice(req.storage_index.fqn)[slices]
target_tensor = planner.resolve_tensor(req).detach()
assert target_tensor.size() == tensor.size(), (
f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
)
if target_tensor.size() != tensor.size():
raise AssertionError(
f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
)
target_tensor.copy_(tensor)
planner.commit_tensor(req, target_tensor)
@ -299,9 +300,10 @@ class HuggingFaceStorageReader(FileSystemReader):
except queue.Empty:
pass
assert processed_count == len(per_file), (
f"Not all files were processed: {processed_count} out of {len(per_file)}"
)
if processed_count != len(per_file):
raise AssertionError(
f"Not all files were processed: {processed_count} out of {len(per_file)}"
)
fut: Future = Future()
fut.set_result(None)

View File

@ -137,12 +137,10 @@ def _get_state_dict_2d_layout(
for key, value in state_dict.items():
specs[key] = (None, value.size())
if _is_nested_tensor(value):
assert len(value.local_shards()) == 1, (
"Cannot handle ST with multiple shards"
)
assert isinstance(value, ShardedTensor), (
"Can only handle nested ShardedTensor"
)
if not len(value.local_shards()) == 1:
raise AssertionError("Cannot handle ST with multiple shards")
if not isinstance(value, ShardedTensor):
raise AssertionError("Can only handle nested ShardedTensor")
shard = value.local_shards()[0]
specs[key] = (
shard.metadata.shard_offsets,
@ -184,7 +182,8 @@ class _ReaderWithOffset(DefaultLoadPlanner):
offset = self.fqn_to_offset[fqn]
assert len(obj.local_shards()) == 1
if not len(obj.local_shards()) == 1:
raise AssertionError("Expected exactly one local shard")
original_shard = obj.local_shards()[0]
local_chunks = [
ChunkStorageMetadata(
@ -201,7 +200,8 @@ class _ReaderWithOffset(DefaultLoadPlanner):
# TODO: The ReadItems will have a displaced MetadataIndex, fix it.
# TODO: we should change _create_sharded_read_items to have more ergonomic API
for ri in reqs:
assert ri.dest_index.offset is not None
if ri.dest_index.offset is None:
raise AssertionError("dest_index.offset must not be None")
original_offset = _element_wise_sub(ri.dest_index.offset, offset)
original_index = dataclasses.replace(
ri.dest_index, offset=torch.Size(original_offset)

View File

@ -107,9 +107,10 @@ class QuantizedHuggingFaceStorageReader(HuggingFaceStorageReader):
target_tensor = planner.resolve_tensor(req).detach()
assert target_tensor.size() == tensor.size(), (
f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
)
if target_tensor.size() != tensor.size():
raise AssertionError(
f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
)
target_tensor.copy_(tensor)
planner.commit_tensor(req, target_tensor)

View File

@ -193,9 +193,10 @@ class DefaultStager(AsyncStager):
self._staging_stream = torch.Stream()
if self._config.use_non_blocking_copy:
assert torch.accelerator.is_available(), (
"Non-blocking copy requires that the current accelerator is available."
)
if not torch.accelerator.is_available():
raise AssertionError(
"Non-blocking copy requires that the current accelerator is available."
)
self._staging_future: Optional[Future[STATE_DICT_TYPE]] = None
@ -215,7 +216,10 @@ class DefaultStager(AsyncStager):
state_dict (STATE_DICT_TYPE): The state_dict to be staged.
"""
if self._config.use_async_staging:
assert self._staging_executor is not None
if self._staging_executor is None:
raise AssertionError(
"staging_executor should not be None for async staging"
)
self._staging_future = self._staging_executor.submit(
self._stage,
state_dict,
@ -227,9 +231,10 @@ class DefaultStager(AsyncStager):
def _stage(self, state_dict: STATE_DICT_TYPE, **kwargs: Any) -> STATE_DICT_TYPE:
if self._config.use_non_blocking_copy:
assert self._staging_stream or not self._config.use_async_staging, (
"Non-blocking copy in a background thread for async staging needs staging_stream to be initialized."
)
if not (self._staging_stream or not self._config.use_async_staging):
raise AssertionError(
"Non-blocking copy in a background thread for async staging needs staging_stream to be initialized."
)
with (
self._staging_stream
if self._staging_stream is not None

View File

@ -186,7 +186,8 @@ def _get_fqns(
curr_obj = model
for i, curr_obj_name in enumerate(obj_names):
if isinstance(curr_obj, DDP):
assert curr_obj_name == "module"
if curr_obj_name != "module":
raise AssertionError(f"Expected 'module', got '{curr_obj_name}'")
curr_obj = curr_obj.module
if not skip_ddp_prefix:
fqn_obj_names.append(curr_obj_name)
@ -203,7 +204,8 @@ def _get_fqns(
fqn_obj_names.append(curr_obj_name)
curr_obj = getattr(curr_obj, curr_obj_name)
elif isinstance(curr_obj, torch._dynamo.eval_frame.OptimizedModule):
assert curr_obj_name == "_orig_mod"
if curr_obj_name != "_orig_mod":
raise AssertionError(f"Expected '_orig_mod', got '{curr_obj_name}'")
curr_obj = curr_obj._orig_mod
if not skip_compiler_prefix:
fqn_obj_names.append(curr_obj_name)
@ -329,7 +331,8 @@ def _verify_options(
if module not in submodules:
continue
fqns = _get_fqns(model, name)
assert len(fqns) == 1, "Submodule FQN should only have 1 instance"
if len(fqns) != 1:
raise AssertionError("Submodule FQN should only have 1 instance")
submodule_prefixes.update(f"{fqn}." for fqn in fqns)
if options.broadcast_from_rank0 and not options.full_state_dict:
@ -408,7 +411,8 @@ def _verify_state_dict(
) -> None:
for module in info.fsdp_modules:
fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
assert fsdp_state is not None, "Expected a fsdp_state with a fsdp module."
if fsdp_state is None:
raise AssertionError("Expected a fsdp_state with a fsdp module.")
# Verify if the model_state_dict and optim_state_dict are valid. This API
# should give the users an explicit error message to debug or report.
@ -483,7 +487,10 @@ def _get_model_state_dict(
for key in list(state_dict.keys()):
fqns = _get_fqns(model, key)
assert len(fqns) == 1, (key, fqns)
if len(fqns) != 1:
raise AssertionError(
f"Expected 1 FQN for key '{key}', got {len(fqns)}: {fqns}"
)
fqn = next(iter(fqns))
if fqn != key:
# As we only support FSDP, DDP, and TP, the only cases are
@ -746,7 +753,8 @@ def _unflatten_optim_state_dict(
continue
params = pg_state[-1][_PARAMS]
assert isinstance(params, list) # typing
if not isinstance(params, list):
raise AssertionError(f"Expected list, got {type(params)}")
params.append(fqn)
if not param.requires_grad:
continue
@ -808,7 +816,10 @@ def _get_optim_state_dict(
fqn_pid_mapping = {}
for key, param in model.named_parameters():
fqns = _get_fqns(model, key)
assert len(fqns) == 1
if len(fqns) != 1:
raise AssertionError(
f"Expected 1 FQN for key '{key}', got {len(fqns)}"
)
fqn = next(iter(fqns))
if param not in param_pid_mapping:
continue
@ -886,7 +897,8 @@ def _split_optim_state_dict(
continue
params = pg_state[-1][_PARAMS]
assert isinstance(params, list)
if not isinstance(params, list):
raise AssertionError(f"Expected list, got {type(params)}")
params.append(fqn)
if param.requires_grad:
state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn]
@ -965,7 +977,10 @@ def _load_optim_state_dict(
if fqns == fqns_with_compiler:
continue
assert len(fqns) == 1
if len(fqns) != 1:
raise AssertionError(
f"Expected 1 FQN for '{original_fqn}', got {len(fqns)}"
)
fqn = fqns.pop()
fqn_with_compiler = fqns_with_compiler.pop()
for g in optim_state_dict[_PG]:
@ -999,7 +1014,8 @@ def _load_optim_state_dict(
return t
_ = tree_map_only(torch.Tensor, _device, local_state_dict)
assert device is not None
if device is None:
raise AssertionError("Expected device to be set")
flatten_osd, osd_mapping = _flatten_state_dict(optim_state_dict)
flatten_local_osd, local_osd_mapping = _flatten_state_dict(local_state_dict)
if info.broadcast_from_rank0:
@ -1012,7 +1028,10 @@ def _load_optim_state_dict(
# having additional parameters ultimately.
for optim_key in flatten_osd.keys():
if optim_key not in flatten_local_osd:
assert optim_key in osd_mapping
if optim_key not in osd_mapping:
raise AssertionError(
f"Expected key '{optim_key}' in osd_mapping"
)
flatten_local_osd[optim_key] = flatten_osd[optim_key]
local_osd_mapping[optim_key] = osd_mapping[optim_key]
optim_state_dict = _unflatten_state_dict(
@ -1225,7 +1244,10 @@ def _unflatten_model_state_dict(
continue
fqns = _get_fqns(model, name)
assert len(fqns) == 1, "FQNs for a submodule should only have 1 element"
if len(fqns) != 1:
raise AssertionError(
"FQNs for a submodule should only have 1 element"
)
prefix = f"{next(iter(fqns))}."
new_state_dict.update(
{prefix + subfqn: value for subfqn, value in sub_state_dict.items()}

View File

@ -246,8 +246,10 @@ def _load_state_dict(
except Exception:
logger.info("Rank local metadata is not found.")
assert planner is not None
assert metadata is not None
if planner is None:
raise AssertionError("planner is None")
if metadata is None:
raise AssertionError("metadata is None")
planner.set_up_planner(state_dict, metadata, distW.is_coordinator)
if (
@ -269,7 +271,8 @@ def _load_state_dict(
@_dcp_method_logger(**ckpt_kwargs)
def global_step(all_local_plans):
assert planner is not None
if planner is None:
raise AssertionError("planner is None")
all_local_plans = planner.create_global_plan(all_local_plans)
all_local_plans = storage_reader.prepare_global_plan(all_local_plans)
return all_local_plans
@ -284,8 +287,10 @@ def _load_state_dict(
@_dcp_method_logger(**ckpt_kwargs)
def read_data():
assert planner is not None
assert central_plan is not None
if planner is None:
raise AssertionError("planner is None")
if central_plan is None:
raise AssertionError("central_plan is None")
final_local_plan = planner.finish_plan(central_plan)
all_reads = storage_reader.read_data(final_local_plan, planner)

View File

@ -292,11 +292,10 @@ def async_save(
if dist.is_available() and dist.is_initialized():
pg = process_group or _get_default_group()
assert (
torch.device("cpu") in pg._device_types # type: ignore[attr-defined]
), (
"A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"
)
if torch.device("cpu") not in pg._device_types:
raise AssertionError(
"A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"
)
if async_stager is None:
if storage_writer is not None and isinstance(storage_writer, AsyncStager):
@ -396,7 +395,8 @@ def _save_state_dict(
distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
if planner is None:
planner = DefaultSavePlanner()
assert planner is not None
if planner is None:
raise AssertionError("planner is None")
global_metadata = None
@ -407,7 +407,8 @@ def _save_state_dict(
@_dcp_method_logger(**ckpt_kwargs)
def local_step():
assert planner is not None
if planner is None:
raise AssertionError("planner is None")
storage_meta = storage_writer.storage_meta()
if "storage_meta" not in inspect.signature(planner.set_up_planner).parameters:
warnings.warn(
@ -443,7 +444,8 @@ def _save_state_dict(
def global_step(all_local_plans):
nonlocal global_metadata
assert planner is not None
if planner is None:
raise AssertionError("planner is None")
all_local_plans, global_metadata = planner.create_global_plan(all_local_plans)
all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
return all_local_plans
@ -458,8 +460,10 @@ def _save_state_dict(
@_dcp_method_logger(**ckpt_kwargs)
def write_data():
assert planner is not None
assert central_plan is not None
if planner is None:
raise AssertionError("planner is None")
if central_plan is None:
raise AssertionError("central_plan is None")
final_local_plan = planner.finish_plan(central_plan)
all_writes = storage_writer.write_data(final_local_plan, planner)
@ -468,7 +472,8 @@ def _save_state_dict(
@_dcp_method_logger(**ckpt_kwargs)
def finish_checkpoint(all_results):
assert global_metadata is not None
if global_metadata is None:
raise AssertionError("global_metadata is None")
storage_writer.finish(metadata=global_metadata, results=all_results)
return global_metadata

View File

@ -168,7 +168,8 @@ class _DistWrapper:
local_reply = gather_result[0]
else:
assert object_list is not None
if object_list is None:
raise AssertionError("object_list is None")
local_reply = object_list[0]
return local_reply
@ -196,7 +197,8 @@ class _DistWrapper:
all_data = self.gather_object(local_data)
all_results: Optional[list[Union[R, CheckpointException]]] = None
if self.is_coordinator:
assert all_data is not None
if all_data is None:
raise AssertionError("all_data is None")
node_failures = _get_failure_dict(all_data)
if len(node_failures) == 0:
@ -243,7 +245,8 @@ class _DistWrapper:
all_data = self.gather_object(local_data)
result: Optional[Union[R, CheckpointException]] = None
if self.is_coordinator:
assert all_data is not None
if all_data is None:
raise AssertionError("all_data is None")
node_failures = _get_failure_dict(all_data)
if len(node_failures) == 0:
try:
@ -465,10 +468,12 @@ def _api_bc_check(func):
p.name for p in sig.parameters.values() if p.kind == p.KEYWORD_ONLY
]
if "storage_writer" in kwonlyargs:
assert "storage_writer" not in kwargs, (args, kwargs)
if "storage_writer" in kwargs:
raise AssertionError(f"storage_writer in kwargs: {(args, kwargs)}")
kwargs["storage_writer"] = args[1]
elif "storage_reader" in kwonlyargs:
assert "storage_reader" not in kwargs, (args, kwargs)
if "storage_reader" in kwargs:
raise AssertionError(f"storage_reader in kwargs: {(args, kwargs)}")
kwargs["storage_reader"] = args[1]
else:
raise RuntimeError(f"Unexpected kwonlyargs = {kwonlyargs}")