Summary:

This implements staging in a way that doesn't break checkpointing semantics. We want to stay close to torch.save/load semantics, but when async checkpointing is used, the current approach breaks shared storages and doesn't handle custom objects or tensors well. For example, a user passes a state_dict containing a CUDA tensor; it is deep-cloned, so the staging tensor is also created on the GPU. This can cause OOMs and is hard to debug.

This diff hooks into deepcopy of storages to move them to CPU using the cached storages created for async checkpoint staging. This allows reusing the storages created for staging, avoiding recreating them on each checkpoint, while remaining flexible enough to handle any changes: old storages are cleaned up and new ones are created as needed.

The lifetime of a staging storage is tied to the original storage object: when the original storage is gc-ed, we delete the corresponding staging storage from the cache, possibly causing it to be gc-ed as well if there are no other references. I am using the storage's data_ptr to keep track of this. Please share thoughts on this. The alternative is to use FQNs instead of the storage id and verify that the underlying storage object has the same shape/size, etc., to make the caching logic work. The current implementation is much simpler and cleaner.

The API:

```
# construct a stager once per job in checkpointing.
stager = StateDictStager(pin_memory=pin_memory, share_memory=share_memory)

# do this on every checkpoint:
with staging_context(stager):
    cpu_state_dict = copy.deepcopy(state_dict)
```

This also adds support for pinned memory.

One problem this implementation does not address is that we lose the original device. The only alternative here is to pickle synchronously, like torch.save, but with special handling for storages. It is valuable to keep the state_dict throughout the checkpointing process so users can manipulate and debug it as needed, which means we would have to unpickle in the background process. I think that approach is flexible, but it is not performant, not very different from the current solution, and needs more code. One idea, if we really want to address this, is to stash the original device in a variable on the storage and use it to recover the device on the load side. I think we do not need this for now and can be explicit that async checkpointing loses the device type.

Update:

Note: Due to reservations about hooking into deepcopy to customize it, the PR is now updated to use deepcopy-like logic to clone the state_dict. There are some caveats to this solution:

1. Deepcopy code is duplicated so we can hook into it for tensors. There is a risk of this code getting outdated as Python versions change. It is needed to handle several different types such as NamedTuples, frozen dataclasses, and nested dataclasses; the deepcopy logic relies on __reduce_ex__ to obtain a function with which these objects can be reconstructed.
2. Since we are bypassing deepcopy and adding custom logic to clone a tensor, we are missing some of the functionality deepcopy has for torch.Tensor, such as _clear_non_serializable_cached_data(). I would like thoughts on which of this logic, if not all of it, should be copied.
3. If an object implements __deepcopy__, we will not be able to handle any tensors in its attributes with this logic, because it will likely just call copy.deepcopy on the attributes instead of going through our deepcopy-like logic. We handle subclasses of torch.Tensor to work around this.

The new API:

```
# construct a stager once per job in checkpointing.
stager = StateDictStager(pin_memory=pin_memory, share_memory=share_memory)

# do this on every checkpoint:
cpu_state_dict = stager.stage(state_dict)
```

Test Plan: unit tests

Differential Revision: D75993324

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155192
Approved by: https://github.com/mikaylagawarecki, https://github.com/pradeepfn
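To illustrate the data_ptr-keyed caching and lifetime scheme described above, here is a minimal sketch. It is not the code from this PR: for simplicity it works on a flat dict of tensors rather than on storages, and the class and method names are made up. The idea is the same, though: staging buffers are cached by the source storage's data_ptr, and a weakref finalizer evicts a cache entry when the source object is collected.

```
import weakref
from typing import Any, Dict

import torch


class StagingCacheSketch:
    """Illustrative only: cache reusable CPU staging tensors keyed by the
    source storage's data_ptr(), and evict an entry when the source tensor
    is garbage-collected (the real StateDictStager works at the storage
    level and handles arbitrary nested containers)."""

    def __init__(self, pin_memory: bool = False) -> None:
        self.pin_memory = pin_memory
        self._cache: Dict[int, torch.Tensor] = {}

    def _stage_tensor(self, src: torch.Tensor) -> torch.Tensor:
        key = src.untyped_storage().data_ptr()
        staged = self._cache.get(key)
        if staged is None or staged.shape != src.shape or staged.dtype != src.dtype:
            # Allocate a fresh CPU staging buffer (optionally pinned) and cache it.
            staged = torch.empty_like(src, device="cpu", pin_memory=self.pin_memory)
            self._cache[key] = staged
            # Tie the cache entry's lifetime to the source tensor: when the
            # source is gc-ed, drop the staging buffer from the cache so it
            # can be freed too (if nothing else references it).
            weakref.finalize(src, self._cache.pop, key, None)
        staged.copy_(src, non_blocking=True)
        return staged

    def stage(self, state_dict: Dict[str, Any]) -> Dict[str, Any]:
        # Flat dict of tensors only; non-tensor values pass through unchanged.
        return {
            k: self._stage_tensor(v) if isinstance(v, torch.Tensor) else v
            for k, v in state_dict.items()
        }
```

The real StateDictStager additionally recurses through arbitrary containers with its deepcopy-like logic and can allocate shared or pinned memory for the staging storages.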
```
import torch


def pin_memory(data_ptr: int, size: int) -> None:
    # Register an existing host allocation as page-locked (pinned) so CUDA
    # can perform asynchronous DMA transfers to/from it.
    cudart = torch.cuda.cudart()
    succ = int(
        cudart.cudaHostRegister(
            data_ptr,
            size,
            1,  # lines up with 'cudaHostRegisterPortable'
        )
    )

    if succ != 0:
        raise RuntimeError(
            f"Registering memory failed with cudaError: {succ}."
            " It's possible that this is an asynchronous error raised from a previous cuda operation."
            " Consider launching with CUDA_LAUNCH_BLOCKING=1 to debug."
        )


def unpin_memory(data_ptr: int) -> None:
    # Unregister a host allocation previously pinned via pin_memory().
    succ = int(torch.cuda.cudart().cudaHostUnregister(data_ptr))
    assert succ == 0, f"Unpinning shared memory failed with error-code: {succ}"
```
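A hedged usage sketch for the two helpers above (the variable names are illustrative, and CUDA must already be initialized for torch.cuda.cudart() to be usable): register the backing storage of a CPU staging buffer as pinned so device-to-host copies can be asynchronous, then unregister it when the buffer is retired.

```
import torch

# Illustrative usage only (assumes a CUDA-capable build with CUDA initialized).
staging_tensor = torch.empty(1024, dtype=torch.float32)  # CPU staging buffer
storage = staging_tensor.untyped_storage()

pin_memory(storage.data_ptr(), storage.nbytes())
try:
    # ... asynchronous device-to-host copies into `staging_tensor` go here ...
    pass
finally:
    unpin_memory(storage.data_ptr())
```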