Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
[DCP][PyTorch Staging APIs][2/x] Handle 0-elem case + ShardedTensor copy for staging (#156092)
Summary:

### Diff Context

1. Sometimes a tensor has non-zero size but 0 numel. In that case, pinning memory fails, so we take a best guess at how to replicate the tensor in order to maintain symmetry in the returned state dict.
2. ShardedTensor copying was not originally handled by the PyTorch state_dict copy APIs; this diff adds that handling.

Test Plan: CI

Differential Revision: D75553096

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156092
Approved by: https://github.com/pradeepfn
Committed by: PyTorch MergeBot
Parent: a5b4463d60
Commit: 9bfefda296
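For context, the staging flow this change feeds into pre-allocates a CPU copy of the state dict once and then copies into it on every checkpoint step. Below is a minimal sketch of that flow, assuming the private helpers `_create_cpu_state_dict` and `_copy_state_dict` in `torch.distributed.checkpoint._state_dict_utils` keep their current signatures; the toy state dict (including the zero-numel `empty` entry) is illustrative only.

```python
# Hedged sketch: private DCP helpers, signatures assumed from current code.
import torch
from torch.distributed.checkpoint._state_dict_utils import (
    _copy_state_dict,
    _create_cpu_state_dict,
)

state_dict = {
    "weight": torch.randn(4, 4),
    "empty": torch.empty(0, 8),  # non-zero size, zero numel (the case fixed here)
}

# Allocate CPU staging buffers once, shape/dtype symmetric with state_dict...
staging_buffers = _create_cpu_state_dict(state_dict, pin_memory=False, share_memory=False)

# ...then reuse them on every checkpoint step by copying in place.
_copy_state_dict(state_dict, staging_buffers, non_blocking=False)
```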
@@ -187,6 +187,12 @@ def _iterate_state_dict(
                     companion_obj._local_tensor.copy_(
                         ret._local_tensor, non_blocking=non_blocking
                     )
+                elif isinstance(companion_obj, ShardedTensor):
+                    assert isinstance(ret, ShardedTensor)
+                    for idx, shard in enumerate(companion_obj.local_shards()):
+                        shard.tensor.copy_(
+                            ret.local_shards()[idx].tensor, non_blocking=non_blocking
+                        )
                 else:
                     companion_obj.copy_(ret, non_blocking=non_blocking)
                 ret = companion_obj
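The new branch pairs local shards by index and copies each shard tensor into the pre-allocated companion ShardedTensor. The toy sketch below illustrates that pairing without a real ShardedTensor (which would need an initialized process group); `FakeShard` is a hypothetical stand-in for a local shard, not a DCP type.

```python
import torch
from dataclasses import dataclass

@dataclass
class FakeShard:
    # Hypothetical stand-in for a ShardedTensor local shard (shard.tensor).
    tensor: torch.Tensor

src_shards = [FakeShard(torch.randn(2, 4)), FakeShard(torch.randn(2, 4))]
staging_shards = [FakeShard(torch.empty(2, 4)), FakeShard(torch.empty(2, 4))]

# Same pairing as the new branch: shards matched by index, copied in place.
for idx, shard in enumerate(staging_shards):
    shard.tensor.copy_(src_shards[idx].tensor, non_blocking=False)

assert all(torch.equal(a.tensor, b.tensor) for a, b in zip(staging_shards, src_shards))
```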
@@ -402,6 +408,15 @@ def _create_cpu_state_dict(
         if len(obj.size()) == 0:
             return torch.tensor(0, dtype=obj.dtype)
 
+        # sometimes, a tensor might have non-zero size and 0 numel. In this case, pinning memory will fail
+        # so we take a best guess at how to replicate the tensor below to maintain symmetry in the returned
+        # state dict.
+        if obj.numel() == 0 or obj.data_ptr() == 0:
+            t = torch.zeros_like(obj, device="cpu")
+            if share_memory:
+                t = t.share_memory_()
+            return t
+
         if share_memory:
             t = torch.empty(*tuple(obj.size()), dtype=obj.dtype)
             t = t.share_memory_()
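To see why this branch is needed: a tensor such as `torch.empty(0, 128)` has a non-empty shape but zero elements, and (per the commit message) pinning its memory fails because there is no storage to register. The sketch below mirrors the fallback in the diff: replicate the shape with `zeros_like` on CPU instead of allocating and pinning.

```python
import torch

obj = torch.empty(0, 128)          # shape (0, 128) but obj.numel() == 0
print(obj.size(), obj.numel())     # torch.Size([0, 128]) 0

# Fallback mirroring the diff: keep shape/dtype symmetry without pinning.
if obj.numel() == 0 or obj.data_ptr() == 0:
    t = torch.zeros_like(obj, device="cpu")
else:
    t = torch.empty(*tuple(obj.size()), dtype=obj.dtype).pin_memory()

assert t.size() == obj.size() and t.dtype == obj.dtype
```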
@@ -446,9 +461,28 @@ def _create_cpu_state_dict(
         ret._local_tensor = tensor_func(ret._local_tensor, pg, device, None)
         return ret
 
+    def sharded_tensor_func(
+        obj: ShardedTensor,
+        pg: Optional[dist.ProcessGroup],
+        device: Optional[torch.device],
+        _: Any,
+    ) -> ShardedTensor:
+        if not obj.local_shards():
+            return obj
+
+        if obj.device != torch.device("cpu"):
+            ret = obj.to(device="cpu")
+        else:
+            ret = copy.deepcopy(obj)
+
+        for shards in ret.local_shards():
+            shards.tensor = tensor_func(shards.tensor, pg, device, None)
+
+        return ret
+
     ret = _iterate_state_dict(
         state_dict,
-        _identity_func,
+        sharded_tensor_func,
         dtensor_func,
         tensor_func,
         pg=None,
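Putting the pieces together, `_create_cpu_state_dict` walks the state dict and applies a per-type replication function, with ShardedTensors replicated shard by shard through the same plain-tensor function. The condensed sketch below is a hypothetical illustration of that dispatch pattern for plain tensors and containers only; it is not the DCP implementation and omits the DTensor/ShardedTensor branches and the pin/share memory options.

```python
import copy
import torch

def tensor_func(obj: torch.Tensor) -> torch.Tensor:
    # Same shape/dtype replica on CPU; zeros_like covers the zero-numel edge case.
    if obj.numel() == 0 or obj.data_ptr() == 0:
        return torch.zeros_like(obj, device="cpu")
    return torch.empty(*tuple(obj.size()), dtype=obj.dtype)

def replicate(obj):
    # Hypothetical, simplified walk over a state dict.
    if isinstance(obj, torch.Tensor):
        return tensor_func(obj)
    if isinstance(obj, dict):
        return {k: replicate(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(replicate(v) for v in obj)
    return copy.deepcopy(obj)

cpu_sd = replicate({"w": torch.randn(3, 3), "meta": {"step": 7}, "buf": torch.empty(0, 4)})
```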