Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
[distributed] Replace assert statements in distributed checkpoint with explicit checks (#165256)

Partially fixes #164878
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165256
Approved by: https://github.com/albanD

Committed by: PyTorch MergeBot
Parent: 75e2a9fae3
Commit: 2bcd892c86
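Every hunk below applies the same mechanical pattern: an `assert cond, msg` becomes an explicit `if not cond: raise AssertionError(msg)`. One common motivation for this kind of change is that `assert` statements are compiled away when Python runs with the -O flag, so an explicit check keeps validating in optimized mode while still raising the same exception type. A minimal illustrative sketch of the pattern (the names below are invented for the example and are not taken from the diff):

    # Illustrative only: invented names, not code from this PR.
    from typing import Any


    def validate_response(response: Any, expected: Any) -> None:
        # Old style (the kind of line this PR removes):
        #     assert response == expected, f"Expected {expected!r}, got {response!r}"
        # `python -O` strips assert statements, so that guard would silently vanish.
        # New style: an explicit check that survives -O and raises the same type.
        if response != expected:
            raise AssertionError(f"Expected {expected!r} response, got {response!r}")


    validate_response("INIT_COMPLETE", "INIT_COMPLETE")  # passes silently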
@@ -109,7 +109,8 @@ class _AsyncCheckpointProcess:
         # Wait for the checkpoint background process to initialize.
         # Using default GLOO init timeout.
         response = self._wait_for_response(timeout=1800)
-        assert response == _CheckpointSaveProcessControlOpts.INIT_COMPLETE
+        if not response == _CheckpointSaveProcessControlOpts.INIT_COMPLETE:
+            raise AssertionError(f"Expected INIT_COMPLETE response, got {response}")

     def __del__(self) -> None:
         if self._save_process.is_alive():
@@ -175,7 +176,8 @@ class _AsyncCheckpointProcess:
         )
         self._send(async_cp_request)
         result = self._wait_for_response()
-        assert isinstance(result, Metadata)
+        if not isinstance(result, Metadata):
+            raise AssertionError(f"Expected Metadata response, got {type(result)}")
         return result

     @staticmethod
@@ -245,7 +247,10 @@ class _AsyncCheckpointProcess:
             ):
                 logger.info("Terminating the checkpoint background process.")
                 return
-            assert isinstance(obj, _AsyncCheckpointRequest)
+            if not isinstance(obj, _AsyncCheckpointRequest):
+                raise AssertionError(
+                    f"Expected _AsyncCheckpointRequest, got {type(obj)}"
+                )
             logger.info(
                 f"Received async checkpoint request with id={obj.checkpoint_request_id.checkpoint_id}"  # noqa: G004
             )
@@ -296,7 +301,10 @@ class _ProcessBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor):
     ) -> Metadata:
         global _CHECKPOINT_PROCESS
         if _CHECKPOINT_PROCESS is None:
-            assert pg_init_info is not None
+            if pg_init_info is None:
+                raise AssertionError(
+                    "pg_init_info must not be None when _CHECKPOINT_PROCESS is None"
+                )
             ckpt_kwargs = {}
             if (ckpt_id := getattr(storage_writer, "checkpoint_id", None)) is not None:
                 ckpt_kwargs["checkpoint_id"] = ckpt_id
@@ -310,7 +318,10 @@ class _ProcessBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor):

             create_checkpoint_daemon_process()

-        assert _CHECKPOINT_PROCESS is not None
+        if _CHECKPOINT_PROCESS is None:
+            raise AssertionError(
+                "_CHECKPOINT_PROCESS must not be None after initialization"
+            )
         staged_state_dict = (
             staging_future_or_state_dict.result()
             if isinstance(staging_future_or_state_dict, Future)
@@ -89,7 +89,8 @@ class _Checkpointer:
             process_group=self.process_group,
             planner=self.save_planner,
         )
-        assert isinstance(response, Future)
+        if not isinstance(response, Future):
+            raise AssertionError("response should be a Future instance")
         return response

     def load(self, state_dict: dict[str, Any]) -> None:
@@ -54,7 +54,8 @@ def dedup_save_plans(
         for plan_idx in plan_indices - {select_plan_idx}:
             plan_to_item_indices[plan_idx].discard(write_item_idx)
     # Sanity check
-    assert len(all_plans) == len(plan_to_item_indices)
+    if len(all_plans) != len(plan_to_item_indices):
+        raise AssertionError("len(all_plans) != len(plan_to_item_indices)")
     # Create new plans with the updated write items post deduplication
     return [
         dataclasses.replace(
@@ -150,9 +150,8 @@ class DistBarrier(Barrier):
         Raises:
             AssertionError: If the distributed process group is not initialized.
         """
-        assert dist.is_initialized(), (
-            "DistBarrier requires an initialized process group."
-        )
+        if not dist.is_initialized():
+            raise AssertionError("DistBarrier requires an initialized process group.")

     def execute_barrier(self) -> None:
         """
@@ -135,7 +135,8 @@ class CheckpointProcess:
         )

         # wait for the timeout or a response from subprocess
-        assert self._parent_end is not None, "Parent end of pipe should be initialized"
+        if self._parent_end is None:
+            raise AssertionError("Parent end of pipe should be initialized")
         if not self._parent_end.poll(timeout=config.subprocess_init_timeout_secs):
             msg = f"Timed out after {config.subprocess_init_timeout_secs}s waiting for checkpoint subprocess to initialize"
             logger.error(msg)
@@ -161,7 +162,8 @@ class CheckpointProcess:
             os.getpid(),
         )

-        assert sub_rank == 0, "We need only one checkpointer per parent training"
+        if sub_rank != 0:
+            raise AssertionError("We need only one checkpointer per parent training")
         request = WorkerRequest(request_type=RequestType.PING, payload={})

         try:
@@ -226,9 +228,8 @@ class CheckpointProcess:

     def _send(self, request_type: RequestType, payload: dict[str, Any]) -> None:
         try:
-            assert self._parent_end is not None, (
-                "Parent end of pipe should be initialized"
-            )
+            if self._parent_end is None:
+                raise AssertionError("Parent end of pipe should be initialized")
             self._parent_end.send(
                 WorkerRequest(
                     request_type=request_type,
@@ -244,9 +245,8 @@ class CheckpointProcess:

     def _recv(self) -> Optional[dict[str, Any]]:
         try:
-            assert self._parent_end is not None, (
-                "Parent end of pipe should be initialized"
-            )
+            if self._parent_end is None:
+                raise AssertionError("Parent end of pipe should be initialized")
             response = self._parent_end.recv()
             if response.success is False:
                 error_msg = (
@@ -134,11 +134,12 @@ class CheckpointReader:

             tensor_offset = source.untyped_storage()._checkpoint_offset

-            assert tensor_offset is not None, (
-                "checkpoint_offset for tensor in torch serialized file is not set. This could"
-                "happen if the checkpoint was saved with a older version of Pytorch."
-                "Please make sure that the checkpoint was saved with Pytorch 2.7 or later."
-            )
+            if tensor_offset is None:
+                raise AssertionError(
+                    "checkpoint_offset for tensor in torch serialized file is not set. This could "
+                    "happen if the checkpoint was saved with a older version of Pytorch. "
+                    "Please make sure that the checkpoint was saved with Pytorch 2.7 or later."
+                )

             tensor_len = source.nelement() * source.element_size()
             file.seek(
@@ -158,9 +158,10 @@ class DefaultStager(CheckpointStager):
             self._staging_stream = torch.Stream()

         if self._config.use_non_blocking_copy:
-            assert torch.accelerator.is_available(), (
-                "Non-blocking copy requires that the current accelerator is available."
-            )
+            if not torch.accelerator.is_available():
+                raise AssertionError(
+                    "Non-blocking copy requires that the current accelerator is available."
+                )

     def stage(
         self,
@@ -168,9 +169,10 @@ class DefaultStager(CheckpointStager):
         **kwargs: Any,
     ) -> Union[STATE_DICT, Future[STATE_DICT]]:
         if self._config.use_async_staging:
-            assert self._staging_executor is not None, (
-                "Staging executor should be initialized for async staging"
-            )
+            if self._staging_executor is None:
+                raise AssertionError(
+                    "Staging executor should be initialized for async staging"
+                )
             return self._staging_executor.submit(
                 self._stage,
                 state_dict,
@@ -185,9 +187,10 @@ class DefaultStager(CheckpointStager):
             )

         if self._config.use_non_blocking_copy:
-            assert self._staging_stream or not self._config.use_async_staging, (
-                "Non-blocking copy in a background thread for async staging needs staging_stream to be initialized."
-            )
+            if not (self._staging_stream or not self._config.use_async_staging):
+                raise AssertionError(
+                    "Non-blocking copy in a background thread for async staging needs staging_stream to be initialized."
+                )

             # waits for the enqued copy operations to finish.
             self._staging_stream.synchronize() if self._staging_stream else torch.accelerator.synchronize()
@@ -37,7 +37,8 @@ class FileSystem(FileSystemBase):
     def create_stream(
         self, path: Union[str, os.PathLike], mode: str
     ) -> Generator[io.IOBase, None, None]:
-        assert self.fs is not None
+        if self.fs is None:
+            raise AssertionError("fs should not be None")
         path = os.fspath(path)

         # fsspec does not support concurrent transactions, and not all
@@ -193,12 +193,12 @@ def _cast_tensor(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
     caveat that the cast tensor may be larger than the original tensor due to
     the differences in striding.
     """
-    assert type(tensor) is torch.Tensor, (
-        f"can only cast standard tensors not {type(tensor)}"
-    )
+    if type(tensor) is not torch.Tensor:
+        raise AssertionError(f"can only cast standard tensors not {type(tensor)}")
     storage = tensor.untyped_storage()
     ret = torch.tensor(storage, dtype=dtype, device=tensor.device)
-    assert ret.untyped_storage() is storage, "storage should be the same"
+    if ret.untyped_storage() is not storage:
+        raise AssertionError("storage should be the same")
     return ret


@@ -317,9 +317,8 @@ class PGTransport:
                 if isinstance(inplace, DTensor):
                     inplace = inplace._local_tensor
                 t = _cast_tensor(inplace, torch.uint8)
-                assert t.nbytes == v.nbytes, (
-                    "inplace tensor storage must be the same size"
-                )
+                if t.nbytes != v.nbytes:
+                    raise AssertionError("inplace tensor storage must be the same size")
             else:
                 t = torch.empty(v.nbytes, dtype=torch.uint8, device=self._device)

@@ -123,12 +123,13 @@ class StateDictStager:
         # Check if we've already cached this storage
         if storage in self._cached_storage_mapping:
             cached_storage = self._cached_storage_mapping[storage]
-            assert cached_storage.size() == storage.size(), (
-                "For async checkpointing, We cache storages in DRAM and reuse them."
-                "Cached storage size does not match original storage size."
-                "This should never happen as we track the original storage weakref "
-                "and clean up the cache storage. Please report this to PyTorch Distributed Checkpointing."
-            )
+            if cached_storage.size() != storage.size():
+                raise AssertionError(
+                    "For async checkpointing, We cache storages in DRAM and reuse them. "
+                    "Cached storage size does not match original storage size. "
+                    "This should never happen as we track the original storage weakref "
+                    "and clean up the cache storage. Please report this to PyTorch Distributed Checkpointing."
+                )
             # Reuse cached storage but update with new data
             cached_storage.copy_(storage, non_blocking=non_blocking)
             return cached_storage
@@ -313,7 +313,8 @@ class DefaultLoadPlanner(LoadPlanner):
         self.is_coordinator = is_coordinator

     def create_local_plan(self) -> LoadPlan:
-        assert self.metadata is not None
+        if self.metadata is None:
+            raise AssertionError("self.metadata is not None")
         if self.flatten_state_dict:
             # To support checkpoints that are saved before v2.4, we have to
             # differentiate if the missing keys are due to old checkpoints.
@@ -432,8 +433,10 @@ class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
         metadata: Optional[Metadata] = None,
         is_coordinator: bool = False,
     ) -> None:
-        assert not state_dict
-        assert metadata is not None
+        if state_dict:
+            raise AssertionError("not state_dict")
+        if metadata is None:
+            raise AssertionError("metadata is not None")

         # rebuild the state dict from the metadata
         for k, v in metadata.state_dict_metadata.items():
@@ -549,13 +552,15 @@ def create_default_global_save_plan(
         new_items = []
         for item in plan.items:
             if item.type != WriteItemType.SHARD:
-                assert item.index.fqn not in md
+                if item.index.fqn in md:
+                    raise AssertionError("item.index.fqn not in md")

             if item.type == WriteItemType.BYTE_IO:
                 md[item.index.fqn] = BytesStorageMetadata()
                 new_items.append(item)
             else:
-                assert item.tensor_data is not None
+                if item.tensor_data is None:
+                    raise AssertionError("item.tensor_data is not None")
                 tensor_md = cast(
                     TensorStorageMetadata,
                     md.setdefault(
@@ -575,10 +580,11 @@ def create_default_global_save_plan(
                 new_item = dataclasses.replace(item, index=new_index)
                 new_items.append(new_item)

-                assert item.tensor_data.chunk is not None, f"""
+                if item.tensor_data.chunk is None:
+                    raise AssertionError(f"""
             Cannot create MD for tensor without bounds.
             FQN: {item.index.fqn}
-            """
+            """)
                 tensor_md.chunks.append(item.tensor_data.chunk)
         new_plans.append(dataclasses.replace(plan, items=new_items))
     return (new_plans, Metadata(md))
@@ -109,7 +109,8 @@ def run(rank, world_size):

         if epoch % SAVE_PERIOD == 0:
             if f is not None:
-                assert isinstance(f, Future)
+                if not isinstance(f, Future):
+                    raise AssertionError("f should be a Future instance")
                 f.result()
             f = dcp.state_dict_saver.async_save(
                 state_dict, checkpoint_id=CHECKPOINT_DIR
@@ -126,7 +127,8 @@ def run(rank, world_size):

     _print("Reloading model from last checkpoint!")
     if f is not None:
-        assert isinstance(f, Future)
+        if not isinstance(f, Future):
+            raise AssertionError("f should be a Future instance") from None
         f.result()
     dcp.load(state_dict)

@@ -201,7 +201,8 @@ class _OverlappingCpuLoader(_TensorLoader):
         self.in_flight_data += tensor.numel() * tensor.element_size()

     def _finish(self) -> Iterable[tuple[torch.Tensor, object]]:
-        assert self._done
+        if not self._done:
+            raise AssertionError("_finish called before all items were processed")
         if len(self.current_items) > 0:
             self.stream.synchronize()
         return self.current_items
@@ -281,7 +282,8 @@ class _StorageWriterTransforms:

 def _item_size(item: WriteItem) -> int:
     size = 1
-    assert item.tensor_data is not None
+    if item.tensor_data is None:
+        raise AssertionError("WriteItem tensor_data must not be None")
     # can't use math.prod as PT needs to support older python
     for s in item.tensor_data.size:
         size *= s
@@ -329,11 +331,16 @@ def _write_item(
     )

     if write_item.type == WriteItemType.BYTE_IO:
-        assert isinstance(data, io.BytesIO)
+        if not isinstance(data, io.BytesIO):
+            raise AssertionError("Data must be io.BytesIO for BYTE_IO write items")
         transform_to.write(data.getbuffer())
     else:
-        assert isinstance(data, torch.Tensor)
-        assert data.device == torch.device("cpu")
+        if not isinstance(data, torch.Tensor):
+            raise AssertionError(
+                "Data must be torch.Tensor for non-BYTE_IO write items"
+            )
+        if data.device != torch.device("cpu"):
+            raise AssertionError("Tensor must be on CPU device")
         if serialization_format == SerializationFormat.TORCH_SAVE:
             torch.save(data, transform_to)

@@ -428,7 +435,8 @@ def _write_files_from_queue(
             tensor_dict = {}
             metadata_dict = {}
             for tensor, write_item in loader.values():
-                assert tensor.is_cpu
+                if not tensor.is_cpu:
+                    raise AssertionError("Tensor must be on CPU")
                 write_results.append(
                     _write_item(
                         transforms,
@@ -903,9 +911,10 @@ class FileSystemReader(StorageReader):
                     )
                     target_tensor = planner.resolve_tensor(req).detach()

-                    assert target_tensor.size() == tensor.size(), (
-                        f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
-                    )
+                    if target_tensor.size() != tensor.size():
+                        raise AssertionError(
+                            f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
+                        )
                     target_tensor.copy_(tensor)
                     planner.commit_tensor(req, target_tensor)

@@ -936,7 +945,8 @@ class FileSystemReader(StorageReader):
         self.storage_data = metadata.storage_data
         self.rank = kwargs.get("rank")
         self.use_collectives = kwargs.get("use_collectives", True)
-        assert self.storage_data is not None
+        if self.storage_data is None:
+            raise AssertionError("storage_data must not be None in metadata")

     def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
         return plan
@@ -84,7 +84,8 @@ class BroadcastingTorchSaveReader(StorageReader):
         # the entire checkpoint on each rank, hopefully preventing OOM issues
         # TODO: read on each host, instead of only the coordinator
         if self.is_coordinator:
-            assert self.checkpoint_id is not None
+            if self.checkpoint_id is None:
+                raise AssertionError("checkpoint_id must be set before reading data")
             torch_state_dict = torch.load(
                 self.checkpoint_id, map_location="cpu", weights_only=False
             )
@@ -112,10 +113,11 @@ class BroadcastingTorchSaveReader(StorageReader):

             tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths)
             target_tensor = planner.resolve_tensor(req).detach()
-            assert target_tensor.size() == tensor.size(), (
-                f"req {req.storage_index} mismatch sizes, "
-                f"{target_tensor.size()} vs {tensor.size()}"
-            )
+            if not target_tensor.size() == tensor.size():
+                raise AssertionError(
+                    f"req {req.storage_index} mismatch sizes, "
+                    f"{target_tensor.size()} vs {tensor.size()}"
+                )
             target_tensor.copy_(tensor)
             planner.commit_tensor(req, target_tensor)

@@ -128,9 +130,16 @@ class BroadcastingTorchSaveReader(StorageReader):
         """Implementation of the StorageReader method"""
         self.is_coordinator = is_coordinator
         if self.is_coordinator:
-            assert dist.get_rank() == self.coordinator_rank
+            if not dist.get_rank() == self.coordinator_rank:
+                raise AssertionError(
+                    f"Coordinator rank mismatch: expected {self.coordinator_rank}, "
+                    f"got {dist.get_rank()}"
+                )

-        assert self.checkpoint_id is not None
+        if self.checkpoint_id is None:
+            raise AssertionError(
+                "checkpoint_id must be set before setting up storage reader"
+            )

     def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
         """Implementation of the StorageReader method"""
@@ -226,9 +226,10 @@ class HuggingFaceStorageReader(FileSystemReader):
                 tensor = f.get_slice(req.storage_index.fqn)[slices]
                 target_tensor = planner.resolve_tensor(req).detach()

-                assert target_tensor.size() == tensor.size(), (
-                    f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
-                )
+                if target_tensor.size() != tensor.size():
+                    raise AssertionError(
+                        f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
+                    )

                 target_tensor.copy_(tensor)
                 planner.commit_tensor(req, target_tensor)
@@ -299,9 +300,10 @@ class HuggingFaceStorageReader(FileSystemReader):
             except queue.Empty:
                 pass

-        assert processed_count == len(per_file), (
-            f"Not all files were processed: {processed_count} out of {len(per_file)}"
-        )
+        if processed_count != len(per_file):
+            raise AssertionError(
+                f"Not all files were processed: {processed_count} out of {len(per_file)}"
+            )

         fut: Future = Future()
         fut.set_result(None)
@@ -137,12 +137,10 @@ def _get_state_dict_2d_layout(
     for key, value in state_dict.items():
         specs[key] = (None, value.size())
         if _is_nested_tensor(value):
-            assert len(value.local_shards()) == 1, (
-                "Cannot handle ST with multiple shards"
-            )
-            assert isinstance(value, ShardedTensor), (
-                "Can only handle nested ShardedTensor"
-            )
+            if not len(value.local_shards()) == 1:
+                raise AssertionError("Cannot handle ST with multiple shards")
+            if not isinstance(value, ShardedTensor):
+                raise AssertionError("Can only handle nested ShardedTensor")
             shard = value.local_shards()[0]
             specs[key] = (
                 shard.metadata.shard_offsets,
@@ -184,7 +182,8 @@ class _ReaderWithOffset(DefaultLoadPlanner):

             offset = self.fqn_to_offset[fqn]

-            assert len(obj.local_shards()) == 1
+            if not len(obj.local_shards()) == 1:
+                raise AssertionError("Expected exactly one local shard")
             original_shard = obj.local_shards()[0]
             local_chunks = [
                 ChunkStorageMetadata(
@@ -201,7 +200,8 @@ class _ReaderWithOffset(DefaultLoadPlanner):
             # TODO: The ReadItems will have a displaced MetadataIndex, fix it.
             # TODO: we should change _create_sharded_read_items to have more ergonomic API
             for ri in reqs:
-                assert ri.dest_index.offset is not None
+                if ri.dest_index.offset is None:
+                    raise AssertionError("dest_index.offset must not be None")
                 original_offset = _element_wise_sub(ri.dest_index.offset, offset)
                 original_index = dataclasses.replace(
                     ri.dest_index, offset=torch.Size(original_offset)
@@ -107,9 +107,10 @@ class QuantizedHuggingFaceStorageReader(HuggingFaceStorageReader):

             target_tensor = planner.resolve_tensor(req).detach()

-            assert target_tensor.size() == tensor.size(), (
-                f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
-            )
+            if target_tensor.size() != tensor.size():
+                raise AssertionError(
+                    f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
+                )

             target_tensor.copy_(tensor)
             planner.commit_tensor(req, target_tensor)
@@ -193,9 +193,10 @@ class DefaultStager(AsyncStager):
             self._staging_stream = torch.Stream()

         if self._config.use_non_blocking_copy:
-            assert torch.accelerator.is_available(), (
-                "Non-blocking copy requires that the current accelerator is available."
-            )
+            if not torch.accelerator.is_available():
+                raise AssertionError(
+                    "Non-blocking copy requires that the current accelerator is available."
+                )

         self._staging_future: Optional[Future[STATE_DICT_TYPE]] = None

@@ -215,7 +216,10 @@ class DefaultStager(AsyncStager):
             state_dict (STATE_DICT_TYPE): The state_dict to be staged.
         """
         if self._config.use_async_staging:
-            assert self._staging_executor is not None
+            if self._staging_executor is None:
+                raise AssertionError(
+                    "staging_executor should not be None for async staging"
+                )
             self._staging_future = self._staging_executor.submit(
                 self._stage,
                 state_dict,
@@ -227,9 +231,10 @@ class DefaultStager(AsyncStager):

     def _stage(self, state_dict: STATE_DICT_TYPE, **kwargs: Any) -> STATE_DICT_TYPE:
         if self._config.use_non_blocking_copy:
-            assert self._staging_stream or not self._config.use_async_staging, (
-                "Non-blocking copy in a background thread for async staging needs staging_stream to be initialized."
-            )
+            if not (self._staging_stream or not self._config.use_async_staging):
+                raise AssertionError(
+                    "Non-blocking copy in a background thread for async staging needs staging_stream to be initialized."
+                )
             with (
                 self._staging_stream
                 if self._staging_stream is not None
@@ -186,7 +186,8 @@ def _get_fqns(
     curr_obj = model
     for i, curr_obj_name in enumerate(obj_names):
         if isinstance(curr_obj, DDP):
-            assert curr_obj_name == "module"
+            if curr_obj_name != "module":
+                raise AssertionError(f"Expected 'module', got '{curr_obj_name}'")
             curr_obj = curr_obj.module
             if not skip_ddp_prefix:
                 fqn_obj_names.append(curr_obj_name)
@@ -203,7 +204,8 @@ def _get_fqns(
                 fqn_obj_names.append(curr_obj_name)
                 curr_obj = getattr(curr_obj, curr_obj_name)
         elif isinstance(curr_obj, torch._dynamo.eval_frame.OptimizedModule):
-            assert curr_obj_name == "_orig_mod"
+            if curr_obj_name != "_orig_mod":
+                raise AssertionError(f"Expected '_orig_mod', got '{curr_obj_name}'")
             curr_obj = curr_obj._orig_mod
             if not skip_compiler_prefix:
                 fqn_obj_names.append(curr_obj_name)
@@ -329,7 +331,8 @@ def _verify_options(
             if module not in submodules:
                 continue
             fqns = _get_fqns(model, name)
-            assert len(fqns) == 1, "Submodule FQN should only have 1 instance"
+            if len(fqns) != 1:
+                raise AssertionError("Submodule FQN should only have 1 instance")
             submodule_prefixes.update(f"{fqn}." for fqn in fqns)

     if options.broadcast_from_rank0 and not options.full_state_dict:
@@ -408,7 +411,8 @@ def _verify_state_dict(
 ) -> None:
     for module in info.fsdp_modules:
         fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module)
-        assert fsdp_state is not None, "Expected a fsdp_state with a fsdp module."
+        if fsdp_state is None:
+            raise AssertionError("Expected a fsdp_state with a fsdp module.")

     # Verify if the model_state_dict and optim_state_dict are valid. This API
     # should give the users an explicit error message to debug or report.
@@ -483,7 +487,10 @@ def _get_model_state_dict(

     for key in list(state_dict.keys()):
         fqns = _get_fqns(model, key)
-        assert len(fqns) == 1, (key, fqns)
+        if len(fqns) != 1:
+            raise AssertionError(
+                f"Expected 1 FQN for key '{key}', got {len(fqns)}: {fqns}"
+            )
         fqn = next(iter(fqns))
         if fqn != key:
             # As we only support FSDP, DDP, and TP, the only cases are
@@ -746,7 +753,8 @@ def _unflatten_optim_state_dict(
                 continue

             params = pg_state[-1][_PARAMS]
-            assert isinstance(params, list)  # typing
+            if not isinstance(params, list):
+                raise AssertionError(f"Expected list, got {type(params)}")
             params.append(fqn)
             if not param.requires_grad:
                 continue
@@ -808,7 +816,10 @@ def _get_optim_state_dict(
         fqn_pid_mapping = {}
         for key, param in model.named_parameters():
             fqns = _get_fqns(model, key)
-            assert len(fqns) == 1
+            if len(fqns) != 1:
+                raise AssertionError(
+                    f"Expected 1 FQN for key '{key}', got {len(fqns)}"
+                )
             fqn = next(iter(fqns))
             if param not in param_pid_mapping:
                 continue
@@ -886,7 +897,8 @@ def _split_optim_state_dict(
                 continue

             params = pg_state[-1][_PARAMS]
-            assert isinstance(params, list)
+            if not isinstance(params, list):
+                raise AssertionError(f"Expected list, got {type(params)}")
             params.append(fqn)
             if param.requires_grad:
                 state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn]
@@ -965,7 +977,10 @@ def _load_optim_state_dict(
                 if fqns == fqns_with_compiler:
                     continue

-                assert len(fqns) == 1
+                if len(fqns) != 1:
+                    raise AssertionError(
+                        f"Expected 1 FQN for '{original_fqn}', got {len(fqns)}"
+                    )
                 fqn = fqns.pop()
                 fqn_with_compiler = fqns_with_compiler.pop()
                 for g in optim_state_dict[_PG]:
@@ -999,7 +1014,8 @@ def _load_optim_state_dict(
                 return t

             _ = tree_map_only(torch.Tensor, _device, local_state_dict)
-            assert device is not None
+            if device is None:
+                raise AssertionError("Expected device to be set")
             flatten_osd, osd_mapping = _flatten_state_dict(optim_state_dict)
             flatten_local_osd, local_osd_mapping = _flatten_state_dict(local_state_dict)
             if info.broadcast_from_rank0:
@@ -1012,7 +1028,10 @@ def _load_optim_state_dict(
                 # having additional parameters ultimately.
                 for optim_key in flatten_osd.keys():
                     if optim_key not in flatten_local_osd:
-                        assert optim_key in osd_mapping
+                        if optim_key not in osd_mapping:
+                            raise AssertionError(
+                                f"Expected key '{optim_key}' in osd_mapping"
+                            )
                         flatten_local_osd[optim_key] = flatten_osd[optim_key]
                         local_osd_mapping[optim_key] = osd_mapping[optim_key]
                 optim_state_dict = _unflatten_state_dict(
@@ -1225,7 +1244,10 @@ def _unflatten_model_state_dict(
                 continue

             fqns = _get_fqns(model, name)
-            assert len(fqns) == 1, "FQNs for a submodule should only have 1 element"
+            if len(fqns) != 1:
+                raise AssertionError(
+                    "FQNs for a submodule should only have 1 element"
+                )
             prefix = f"{next(iter(fqns))}."
             new_state_dict.update(
                 {prefix + subfqn: value for subfqn, value in sub_state_dict.items()}
@@ -246,8 +246,10 @@ def _load_state_dict(
         except Exception:
             logger.info("Rank local metadata is not found.")

-    assert planner is not None
-    assert metadata is not None
+    if planner is None:
+        raise AssertionError("planner is None")
+    if metadata is None:
+        raise AssertionError("metadata is None")
     planner.set_up_planner(state_dict, metadata, distW.is_coordinator)

     if (
@@ -269,7 +271,8 @@ def _load_state_dict(

     @_dcp_method_logger(**ckpt_kwargs)
     def global_step(all_local_plans):
-        assert planner is not None
+        if planner is None:
+            raise AssertionError("planner is None")
         all_local_plans = planner.create_global_plan(all_local_plans)
         all_local_plans = storage_reader.prepare_global_plan(all_local_plans)
         return all_local_plans
@@ -284,8 +287,10 @@ def _load_state_dict(

     @_dcp_method_logger(**ckpt_kwargs)
     def read_data():
-        assert planner is not None
-        assert central_plan is not None
+        if planner is None:
+            raise AssertionError("planner is None")
+        if central_plan is None:
+            raise AssertionError("central_plan is None")
         final_local_plan = planner.finish_plan(central_plan)
         all_reads = storage_reader.read_data(final_local_plan, planner)

@@ -292,11 +292,10 @@ def async_save(

     if dist.is_available() and dist.is_initialized():
         pg = process_group or _get_default_group()
-        assert (
-            torch.device("cpu") in pg._device_types  # type: ignore[attr-defined]
-        ), (
-            "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"
-        )
+        if torch.device("cpu") not in pg._device_types:
+            raise AssertionError(
+                "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"
+            )

     if async_stager is None:
         if storage_writer is not None and isinstance(storage_writer, AsyncStager):
@@ -396,7 +395,8 @@ def _save_state_dict(
     distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
     if planner is None:
         planner = DefaultSavePlanner()
-    assert planner is not None
+    if planner is None:
+        raise AssertionError("planner is None")

     global_metadata = None

@@ -407,7 +407,8 @@ def _save_state_dict(

     @_dcp_method_logger(**ckpt_kwargs)
     def local_step():
-        assert planner is not None
+        if planner is None:
+            raise AssertionError("planner is None")
         storage_meta = storage_writer.storage_meta()
         if "storage_meta" not in inspect.signature(planner.set_up_planner).parameters:
             warnings.warn(
@@ -443,7 +444,8 @@ def _save_state_dict(
     def global_step(all_local_plans):
         nonlocal global_metadata

-        assert planner is not None
+        if planner is None:
+            raise AssertionError("planner is None")
         all_local_plans, global_metadata = planner.create_global_plan(all_local_plans)
         all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
         return all_local_plans
@@ -458,8 +460,10 @@ def _save_state_dict(

     @_dcp_method_logger(**ckpt_kwargs)
     def write_data():
-        assert planner is not None
-        assert central_plan is not None
+        if planner is None:
+            raise AssertionError("planner is None")
+        if central_plan is None:
+            raise AssertionError("central_plan is None")
         final_local_plan = planner.finish_plan(central_plan)
         all_writes = storage_writer.write_data(final_local_plan, planner)

@@ -468,7 +472,8 @@ def _save_state_dict(

     @_dcp_method_logger(**ckpt_kwargs)
     def finish_checkpoint(all_results):
-        assert global_metadata is not None
+        if global_metadata is None:
+            raise AssertionError("global_metadata is None")
         storage_writer.finish(metadata=global_metadata, results=all_results)
         return global_metadata

@@ -168,7 +168,8 @@ class _DistWrapper:

             local_reply = gather_result[0]
         else:
-            assert object_list is not None
+            if object_list is None:
+                raise AssertionError("object_list is None")
             local_reply = object_list[0]
         return local_reply

@@ -196,7 +197,8 @@ class _DistWrapper:
         all_data = self.gather_object(local_data)
         all_results: Optional[list[Union[R, CheckpointException]]] = None
         if self.is_coordinator:
-            assert all_data is not None
+            if all_data is None:
+                raise AssertionError("all_data is None")
             node_failures = _get_failure_dict(all_data)

             if len(node_failures) == 0:
@@ -243,7 +245,8 @@ class _DistWrapper:
         all_data = self.gather_object(local_data)
         result: Optional[Union[R, CheckpointException]] = None
         if self.is_coordinator:
-            assert all_data is not None
+            if all_data is None:
+                raise AssertionError("all_data is None")
             node_failures = _get_failure_dict(all_data)
             if len(node_failures) == 0:
                 try:
@@ -465,10 +468,12 @@ def _api_bc_check(func):
             p.name for p in sig.parameters.values() if p.kind == p.KEYWORD_ONLY
         ]
         if "storage_writer" in kwonlyargs:
-            assert "storage_writer" not in kwargs, (args, kwargs)
+            if "storage_writer" in kwargs:
+                raise AssertionError(f"storage_writer in kwargs: {(args, kwargs)}")
             kwargs["storage_writer"] = args[1]
         elif "storage_reader" in kwonlyargs:
-            assert "storage_reader" not in kwargs, (args, kwargs)
+            if "storage_reader" in kwargs:
+                raise AssertionError(f"storage_reader in kwargs: {(args, kwargs)}")
             kwargs["storage_reader"] = args[1]
         else:
             raise RuntimeError(f"Unexpected kwonlyargs = {kwonlyargs}")