[DCP] Adds support for non-primitives in async_save by deep copying during cpu offloading (#123941)

Adds support for non-primitives in async_save by deep copying during cpu offloading.

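If users are not type checking (async_save now offloads with type_check=False), the expectation in the async case is likely that the object is copied rather than shared by reference, so unrecognized value types are deep-copied instead of raising an error.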

Differential Revision: [D56065237](https://our.internmc.facebook.com/intern/diff/D56065237/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123941
Approved by: https://github.com/fegin
This commit is contained in:
Lucas Pasqualin
2024-04-16 10:22:34 -07:00
committed by PyTorch MergeBot
parent 14b2273b0c
commit 46a25cc0db
2 changed files with 5 additions and 2 deletions

View File

@ -1,3 +1,4 @@
import copy
import io
import math
from typing import Any, Callable, Dict, Optional, Tuple, TYPE_CHECKING
@ -166,7 +167,7 @@ def _iterate_state_dict(
if isinstance(iter_object, tuple):
ret = tuple(ret)
elif not type_check:
ret = iter_object
ret = copy.deepcopy(iter_object)
else:
raise ValueError(f"Unexpected value type {type(iter_object)}")

View File

@ -216,7 +216,9 @@ def async_save(
torch.device("cpu") in pg._device_types # type: ignore[attr-defined]
), "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"
cpu_state_dict = _offload_state_dict_to_cpu(_stateful_to_state_dict(state_dict))
cpu_state_dict = _offload_state_dict_to_cpu(
_stateful_to_state_dict(state_dict), type_check=False
)
executor = ThreadPoolExecutor(max_workers=1)
f: Future = executor.submit(