It is generally recommended to use `is`/`is not` to compare types. This series of changes applies that suggestion across the code base, with the aim of finally enabling the related linter checks.
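A minimal illustration of the recommended style (the specific lint rule involved, e.g. flake8's E721, is an assumption here):
```
def kind(x):
    # Exact type checks should use identity comparisons, not equality.
    if type(x) is int:        # preferred over `type(x) == int`
        return "int"
    if type(x) is not float:  # preferred over `type(x) != float`
        return "neither int nor float"
    return "float"
```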
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165037
Approved by: https://github.com/mlazos
Summary:
This implements staging in a way that does not break checkpointing semantics. We want to stay close to torch.save/torch.load semantics; when async checkpointing is used today, it mishandles shared storages and does not handle custom objects or tensors well. For example, a user passes a state_dict with a CUDA tensor inside a custom datatype; this is deep-cloned, causing the staging tensor to be created on GPU, which can cause OOMs that are hard to debug.
This diff hooks into the deepcopy of storages to move them to CPU, using the cached storages created for async checkpoint staging. This allows reusing the storages created for staging rather than recreating them on each checkpoint, while remaining flexible enough to handle any changes, cleaning up old storages or creating new ones as needed.
The lifetime of a staging storage is tied to the original storage object: when the original storage object is gc-ed, we delete the corresponding staging storage from the cache, possibly causing it to be gc-ed if there are no other references. I am using the data_ptr of the storage to keep track of this. Please share thoughts on this.
The alternative is to use FQNs instead of storage_id and verify that the underlying storage object has the same shape/size, etc., to make the caching logic work. The current implementation is much simpler and cleaner.
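A minimal sketch of this caching idea, assuming a hypothetical `StagingCache` class; the real implementation may differ in details:
```
import weakref
import torch

class StagingCache:
    """Caches CPU staging storages keyed by the source storage's data_ptr."""

    def __init__(self):
        self._cache = {}  # data_ptr -> cached CPU staging storage

    def stage(self, storage: torch.UntypedStorage) -> torch.UntypedStorage:
        key = storage.data_ptr()
        staged = self._cache.get(key)
        if staged is None or staged.nbytes() != storage.nbytes():
            staged = torch.UntypedStorage(storage.nbytes())  # CPU by default
            self._cache[key] = staged
            # Tie the cached copy's lifetime to the original storage: drop it
            # from the cache when the original storage is gc-ed.
            weakref.finalize(storage, self._cache.pop, key, None)
        staged.copy_(storage)
        return staged
```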
The API:
```
# construct a stager once per job in checkpointing.
stager = StateDictStager(pin_memory=pin_memory, share_memory=share_memory)
# do this on every checkpoint:
with staging_context(stager):
    cpu_state_dict = copy.deepcopy(state_dict)
```
This also adds support for pinned memory.
One problem this implementation does not address is that we lose the original device.
The only alternative here is to pickle synchronously, like torch.save, but with special handling for storages. It is valuable to keep a state_dict throughout the checkpointing process so users can manipulate and debug as needed, so we would need to unpickle in the background process. I think this is flexible but not performant, not very different from the current solution, and needs more code. One idea, if we really want to address this, is to stash the original device in a variable on the storage and use it to recover the device on the load side. I think we do not need this for now and can be explicit that the device type is lost for async checkpointing.
Update:
Note: Due to reservations about hooking into deepcopy to customize it, the PR has been updated to use deepcopy-like logic to clone the state_dict. There are some caveats to this solution:
1. Deepcopy code is duplicated so we can hook into it for tensors. There is a risk of this code getting outdated with Python version changes. This is needed to handle several different types such as NamedTuples, frozen dataclasses, and nested dataclasses; the deepcopy logic relies on `__reduce_ex__` to obtain a function with which these can be constructed (see the sketch after this list).
2. Since we are bypassing deepcopy and adding custom logic to clone a tensor, we are missing some of the functionality that exists in deepcopy for torch.Tensor, such as `_clear_non_serializable_cached_data()`. I would like thoughts on which pieces of that logic, or whether all of it, should be copied.
3. If an object implements `__deepcopy__`, we will not be able to handle tensors in its attributes with this logic, because it will likely just call copy.deepcopy on the attributes instead of going through this deepcopy-like logic. We handle subclasses of torch.Tensor specially to work around this.
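A rough sketch of the deepcopy-like traversal described in caveat 1; `_stage_obj` and `stage_tensor` are illustrative names, and listitems/dictitems/slot-state handling is omitted for brevity:
```
import torch

def _stage_obj(obj, memo, stage_tensor):
    """Deepcopy-like traversal that routes tensors through stage_tensor."""
    if isinstance(obj, (int, float, bool, str, bytes, type(None))):
        return obj
    if id(obj) in memo:
        return memo[id(obj)]
    if isinstance(obj, torch.Tensor):
        out = stage_tensor(obj)  # assumed to return a staged CPU copy
    elif type(obj) is dict:
        out = {k: _stage_obj(v, memo, stage_tensor) for k, v in obj.items()}
    elif type(obj) in (list, tuple, set):
        out = type(obj)(_stage_obj(v, memo, stage_tensor) for v in obj)
    else:
        # Same protocol deepcopy relies on: __reduce_ex__ describes how to
        # rebuild NamedTuples, frozen dataclasses, nested dataclasses, etc.
        rv = obj.__reduce_ex__(4)
        if isinstance(rv, str):  # e.g. module-level singletons
            out = obj
        else:
            func, args, *rest = rv
            state = rest[0] if rest else None
            out = func(*[_stage_obj(a, memo, stage_tensor) for a in args])
            if isinstance(state, dict):
                out.__dict__.update(_stage_obj(state, memo, stage_tensor))
            elif state is not None and hasattr(out, "__setstate__"):
                out.__setstate__(_stage_obj(state, memo, stage_tensor))
    memo[id(obj)] = out
    return out
```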
The new API:
```
# construct a stager once per job in checkpointing.
stager = StateDictStager(pin_memory=pin_memory, share_memory=share_memory)
# do this on every checkpoint:
cpu_state_dict = stager.stage(state_dict)
```
Test Plan:
unit tests
Differential Revision: D75993324
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155192
Approved by: https://github.com/mikaylagawarecki, https://github.com/pradeepfn
Summary:
### Diff Context
1. Sometimes a tensor might have non-zero size but 0 numel. In this case, pinning memory will fail, so we take a best guess at how to replicate the tensor so as to maintain symmetry in the returned state dict (see the sketch below).
2. ShardedTensor copying was not originally handled in the PyTorch state_dict copy APIs; it is handled in this diff.
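A minimal sketch of the zero-numel fallback in point 1; `_make_cpu_copy` is an illustrative name, not the actual helper:
```
import torch

def _make_cpu_copy(tensor: torch.Tensor, pin_memory: bool) -> torch.Tensor:
    # A tensor can have non-zero size (e.g. (0, 5)) but zero numel; pinning
    # such a tensor can fail, so fall back to a plain CPU tensor with the same
    # shape/dtype to keep the returned state dict symmetric with the input.
    if tensor.numel() == 0 or not pin_memory:
        return torch.empty(tensor.size(), dtype=tensor.dtype)
    return torch.empty(tensor.size(), dtype=tensor.dtype, pin_memory=True)
```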
Test Plan: CI
Differential Revision: D75553096
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156092
Approved by: https://github.com/pradeepfn
Fixes #138842
`device` is always the device of the `local_state_dict`, which may be CPU; CPU is not supported by the NCCL backend.
Instead, create the broadcast tensors on one of `pg._device_types` and then move them back if `local_state_dict`'s `device` is not supported by the `ProcessGroup`.
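A minimal sketch of that pattern, assuming `pg` is a `ProcessGroup` (its `_device_types` attribute is the internal property referenced above) and rank 0 is the broadcast source:
```
import torch
import torch.distributed as dist

def _broadcast_on_supported_device(tensor: torch.Tensor, pg) -> torch.Tensor:
    local_device = tensor.device
    # Pick a device type the ProcessGroup supports (e.g. cuda for NCCL).
    pg_device = pg._device_types[0] if pg._device_types else torch.device("cpu")
    buf = tensor.to(pg_device)
    dist.broadcast(buf, src=0, group=pg)
    # Move the result back if the local state dict lives on an unsupported device.
    return buf.to(local_device) if buf.device != local_device else buf
```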
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148865
Approved by: https://github.com/mori360
For the distributed state dict API [migration](https://github.com/pytorch/torchtune/pull/2138), make the following changes:
1. `load_from_full_model_state_dict` in TorchTune calls `set_model_state_dict` with options controlling whether to cpu_offload. Add cpu_offload to `_load_model_state_dict` to move tensors to CPU if the config is True.
2. Change the device check, as lora_finetune might have 2 device types; accept that as valid.
3. Some changes to optimize memory performance (see the sketch after this list):
3.1 use `.detach().clone()` instead of a view directly
3.2 if local_state is not meta, copy `full_tensor[slices]` into `ret.to_local()`
4. add related unit tests
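A rough sketch of points 3.1/3.2, with illustrative names (`full_tensor`, `ret`, `slices`) and simplified control flow, assuming `ret` is a DTensor whose local shard should receive this rank's slice:
```
import torch

def _fill_local_shard(full_tensor, ret, slices, cpu_offload: bool):
    local = ret.to_local()
    if local.is_meta:
        # detach().clone() (3.1) materializes the slice instead of keeping a
        # view that would hold the whole full_tensor alive.
        local = full_tensor[slices].detach().clone()
    else:
        # 3.2: copy this rank's slice of the full tensor into the existing shard.
        local.copy_(full_tensor[slices])
    return local.cpu() if cpu_offload else local
```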
Memory performance calling from TorchTune with llama2/7B_full:
1. cpu_offload = True
<img width="555" alt="Screenshot 2024-12-18 at 1 36 47 PM" src="https://github.com/user-attachments/assets/429261f5-1107-4592-b295-de3944a2614b" />
2. cpu_offload = False
<img width="555" alt="Screenshot 2024-12-18 at 1 36 52 PM" src="https://github.com/user-attachments/assets/40bf281a-236a-4218-826b-b1192a10c806" />
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142845
Approved by: https://github.com/fegin
Fix https://github.com/pytorch/pytorch/issues/134095
This fixes the distributed state dict full_state_dict option hanging during set_state_dict. We switch `_distribute_tensors` in _state_dict_utils.py to use `DTensor.from_local` instead of `distribute_tensor` to support the FSDP2+TP 2D strided sharding use case, as `distribute_tensor` cannot handle strided sharding yet. `distribute_tensor` incurs a scatter behind the scenes, while `DTensor.from_local` takes the local slice of the full tensor on each rank to create the DTensor (no collective). This means it is the user's responsibility to make sure the full_tensor from the full_state_dict is the same across all ranks.
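A minimal sketch of the difference, assuming `device_mesh`, `placements`, and a precomputed `local_slice` for this rank; this is not the exact `_distribute_tensors` code:
```
import torch
from torch.distributed.tensor import DTensor

def _to_dtensor(full_tensor: torch.Tensor, device_mesh, placements, local_slice):
    # distribute_tensor(full_tensor, device_mesh, placements) would scatter from
    # the source rank and cannot express 2D strided sharding yet.
    # DTensor.from_local wraps this rank's slice with no collective; full_tensor
    # must therefore already be identical on every rank.
    return DTensor.from_local(
        full_tensor[local_slice],
        device_mesh,
        placements,
        run_check=False,
    )
```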
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135725
Approved by: https://github.com/fegin
reland of https://github.com/pytorch/pytorch/pull/133113
I had to create a new PR because the previously reverted PR could not be rebased or imported successfully :(
----
Moving DTensor to be in the public namespace, to formally add the documentation page that includes all the public APIs. This includes:
* many path renames and path import fixes
* a dedicated doc page without too much content yet (to be added in the next PRs)
* To preserve BC for users still using `torch.distributed._tensor`, I added a shim script to redirect old path calls to the new module
The BC preservation is evidenced by the fact that all DTensor tests still pass without changing the public imports, so it is safe to land the changes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134203
Approved by: https://github.com/tianyu-l
Optimize the memory cost of [PR #129635](https://github.com/pytorch/pytorch/pull/129635).
There are 2 main parts to the optimization here:
1. optimize the tensor distribution part by postponing full_tensor generation, which avoids the memory overlap and saves around 50% peak memory in the 2-param test case.
2. apply `assign=True` to `load_state_dict`, which saves memory during state dict loading by assigning the input params rather than copying them, saving around 50% peak memory in the loading part (see the snippet below).
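A minimal illustration of point 2; the model here is a stand-in:
```
import torch.nn as nn

model = nn.Linear(4, 4)  # stand-in for the real model
cpu_state_dict = {k: v.clone() for k, v in model.state_dict().items()}
# assign=True swaps the loaded tensors into the module instead of copying them
# into freshly allocated parameters, keeping peak memory close to one copy.
model.load_state_dict(cpu_state_dict, assign=True)
```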
Future work:
Memory optimization for the optimizer will be done in the next PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134025
Approved by: https://github.com/fegin
Co-authored-by: Rachel Guo <guorachel@meta.com>
Moving DTensor to be in the public namespace, to formally add the documentation page that includes all the public APIs. This includes:
* many path renames and path import fixes
* a dedicated doc page without too much content yet (to be added in the next PRs)
* To preserve BC for users still using `torch.distributed._tensor`, I added a shim script to redirect old path calls to the new module
The BC preservation is evidenced by the fact that all DTensor tests still pass without changing the public imports, so it is safe to land the changes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133113
Approved by: https://github.com/XilunWu
ghstack dependencies: #133305, #133306
This was causing some terrible error messages, e.g.:
```
# printing directly: cudaError.???
# casting to int first: 712
Traceback (most recent call last):
File "/data/users/lpasqualin/fbsource/fbcode/scripts/lpasqualin/playground.py", line 15, in <module>
main()
File "/data/users/lpasqualin/fbsource/fbcode/scripts/lpasqualin/playground.py", line 11, in main
_create_cpu_state_dict(sd, share_memory=True, pin_memory=True)
File "/home/lpasqualin/pytorch/torch/distributed/_state_dict_utils.py", line 436, in _create_cpu_state_dict
ret = _iterate_state_dict(
^^^^^^^^^^^^^^^^^^^^
File "/home/lpasqualin/pytorch/torch/distributed/_state_dict_utils.py", line 143, in _iterate_state_dict
ret = {
^
File "/home/lpasqualin/pytorch/torch/distributed/_state_dict_utils.py", line 144, in <dictcomp>
key: _iterate_state_dict(
^^^^^^^^^^^^^^^^^^^^
File "/home/lpasqualin/pytorch/torch/distributed/_state_dict_utils.py", line 125, in _iterate_state_dict
ret = tensor_func(iter_object, pg, device, companion_obj)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lpasqualin/pytorch/torch/distributed/_state_dict_utils.py", line 428, in tensor_func
succ == 0
AssertionError: Pinning shared memory failed with error-code: cudaError.???
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132089
Approved by: https://github.com/Skylion007
When pin_memory and share_memory are both set to True in _create_cpu_state_dict, the memory is pinned using cudaHostRegister but is never unpinned. So once a tensor is created and freed, and a new tensor is created, the caching allocator allocates the same memory, which fails with the error below.
```
obj = <[RuntimeError('CUDA error: part or all of the requested memory range is already mapped\nCUDA kernel errors might be a...pile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n') raised in repr()] Tensor object at 0x7f0028a4d6c0> pg = None, device = None, _ = None
```
This PR fixes the issue by attaching a hook that unregisters the memory when the tensor is freed.
This is easily reproducible with the xlformers checkpointing unit tests, and the fix is verified with the same.
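A minimal sketch of the register/unregister pairing, assuming `t` is a shared-memory CPU tensor; the real code lives in _create_cpu_state_dict and may structure the hook differently:
```
import weakref
import torch

def _pin_shared_cpu_tensor(t: torch.Tensor) -> torch.Tensor:
    cudart = torch.cuda.cudart()
    data_ptr = t.data_ptr()
    succ = cudart.cudaHostRegister(data_ptr, t.numel() * t.element_size(), 0)
    assert int(succ) == 0, f"Pinning shared memory failed with error-code: {succ}"
    # Without this finalizer the pages stay registered after the tensor dies,
    # and reusing the same range later fails with "part or all of the requested
    # memory range is already mapped".
    weakref.finalize(t, cudart.cudaHostUnregister, data_ptr)
    return t
```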
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131270
Approved by: https://github.com/LucasLLC
Fixes #126950
`ptd_state_dict` with `broadcast_from_rank0=False` might miss 2 condition checks in `set_optimizer_state_dict`.
Here we add another condition, `full_state_dict=True`, with the corresponding tensor distribution done without broadcasting when `broadcast_from_rank0=False`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128004
Approved by: https://github.com/fegin
I found that returning the copy is actually useful in situations where you might do something like:
```
ret = _copy_state_dict(obj, cache)
ret.update(some_other_values)
```
and would like `cache` not to change structure as a result of `ret.update(some_other_values)`. Open to notes here; not returning a copy might force the user to do some additional copies for this case.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123567
Approved by: https://github.com/wz337