[DCP][PyTorch Staging APIs][2/x] Handle 0-elem case + ShardedTensor copy for staging (#156092)

Summary:
### Diff Context

1. A tensor can have non-zero size but 0 numel (e.g. size `(4, 0)`). Pinning memory fails for
such tensors, so we take a best guess at how to replicate them, maintaining symmetry in the
returned state dict; see the sketch after this list.

2. ShardedTensor copying was not originally handled by the PyTorch state_dict copy APIs; this diff adds it.
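
A minimal sketch of case 1, using only public `torch` APIs: it constructs a tensor with non-zero size but zero elements, and mirrors the `zeros_like` fallback added in the diff below:

```python
import torch

t = torch.empty(4, 0)  # non-zero size, but zero elements
print(t.size())   # torch.Size([4, 0])
print(t.numel())  # 0

# Mirror of the fallback added in this diff: skip pin_memory() for such
# tensors and replicate them with zeros_like so the returned state dict
# keeps the same shapes as the input.
if t.numel() == 0 or t.data_ptr() == 0:
    cpu_t = torch.zeros_like(t, device="cpu")
print(cpu_t.size())  # torch.Size([4, 0]), symmetry preserved
```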

Test Plan: CI

Differential Revision: D75553096

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156092
Approved by: https://github.com/pradeepfn
Authored by: Meet Vadakkanchery
Date: 2025-06-18 22:41:20 +00:00
Committed by: PyTorch MergeBot
Parent: a5b4463d60
Commit: 9bfefda296
2 changed files with 96 additions and 3 deletions


@@ -187,6 +187,12 @@ def _iterate_state_dict(
                 companion_obj._local_tensor.copy_(
                     ret._local_tensor, non_blocking=non_blocking
                 )
+            elif isinstance(companion_obj, ShardedTensor):
+                assert isinstance(ret, ShardedTensor)
+                for idx, shard in enumerate(companion_obj.local_shards()):
+                    shard.tensor.copy_(
+                        ret.local_shards()[idx].tensor, non_blocking=non_blocking
+                    )
             else:
                 companion_obj.copy_(ret, non_blocking=non_blocking)
             ret = companion_obj
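
An illustrative, self-contained toy of the shard-wise copy added above. `FakeShard` and `FakeShardedTensor` are hypothetical stand-ins, not the real `ShardedTensor` API, which requires an initialized process group:

```python
from dataclasses import dataclass, field

import torch


@dataclass
class FakeShard:
    tensor: torch.Tensor


@dataclass
class FakeShardedTensor:
    shards: list = field(default_factory=list)

    def local_shards(self) -> list:
        return self.shards


src = FakeShardedTensor([FakeShard(torch.arange(4.0)), FakeShard(torch.arange(4.0, 8.0))])
dst = FakeShardedTensor([FakeShard(torch.zeros(4)), FakeShard(torch.zeros(4))])

# Same pattern as the new elif branch: copy positionally, shard by shard,
# into the preallocated companion object.
for idx, shard in enumerate(dst.local_shards()):
    shard.tensor.copy_(src.local_shards()[idx].tensor, non_blocking=False)

assert torch.equal(dst.local_shards()[1].tensor, torch.arange(4.0, 8.0))
```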
@@ -402,6 +408,15 @@ def _create_cpu_state_dict(
         if len(obj.size()) == 0:
             return torch.tensor(0, dtype=obj.dtype)
+        # sometimes, a tensor might have non-zero size and 0 numel. In this case,
+        # pinning memory will fail, so we take a best guess at how to replicate the
+        # tensor below to maintain symmetry in the returned state dict.
+        if obj.numel() == 0 or obj.data_ptr() == 0:
+            t = torch.zeros_like(obj, device="cpu")
+            if share_memory:
+                t = t.share_memory_()
+            return t
         if share_memory:
             t = torch.empty(*tuple(obj.size()), dtype=obj.dtype)
             t = t.share_memory_()
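
A hedged usage sketch of the new path. The private helper and its module path are assumed from the PyTorch source tree and may move:

```python
import torch

# Private helper; location assumed from torch/distributed/_state_dict_utils.py.
from torch.distributed._state_dict_utils import _create_cpu_state_dict

sd = {"w": torch.empty(4, 0)}  # non-zero size, 0 numel
cpu_sd = _create_cpu_state_dict(sd)

# The fallback keeps the original shape instead of failing in pin_memory().
print(cpu_sd["w"].size())  # torch.Size([4, 0])
```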
@@ -446,9 +461,28 @@ def _create_cpu_state_dict
             ret._local_tensor = tensor_func(ret._local_tensor, pg, device, None)
         return ret
 
+    def sharded_tensor_func(
+        obj: ShardedTensor,
+        pg: Optional[dist.ProcessGroup],
+        device: Optional[torch.device],
+        _: Any,
+    ) -> ShardedTensor:
+        if not obj.local_shards():
+            return obj
+        if obj.device != torch.device("cpu"):
+            ret = obj.to(device="cpu")
+        else:
+            ret = copy.deepcopy(obj)
+        for shard in ret.local_shards():
+            shard.tensor = tensor_func(shard.tensor, pg, device, None)
+        return ret
+
     ret = _iterate_state_dict(
         state_dict,
         _identity_func,
+        sharded_tensor_func,
         dtensor_func,
         tensor_func,
         pg=None,
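
For context, a hedged end-to-end sketch of the staging flow these helpers support: allocate the CPU copy once, then repeatedly copy into it. The private helper names and module path are assumed from the PyTorch source tree:

```python
import torch

# Private helpers; locations assumed from torch/distributed/_state_dict_utils.py.
from torch.distributed._state_dict_utils import (
    _copy_state_dict,
    _create_cpu_state_dict,
)

state_dict = {"layer.weight": torch.randn(8, 8), "empty": torch.empty(4, 0)}

# One-time allocation that mirrors the source state dict's structure on CPU.
cpu_staging = _create_cpu_state_dict(state_dict)

# Per-step copy into the preallocated staging buffers; ShardedTensor entries
# are now copied shard by shard, as added in this diff.
_copy_state_dict(state_dict, cpu_staging, non_blocking=False)
```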