Fix FSDP offload pin_memory bug (#157147)

Fixes #157146

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157147
Approved by: https://github.com/weifengpy
commit 0629dfb860
parent 67f8270516
Author: Edenzzzz
Date:   2025-06-28 21:09:08 +00:00
Committed by: PyTorch MergeBot
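
For context: Tensor.pin_memory() pins page-locked CPU memory, and in recent PyTorch its optional device argument is deprecated, so passing device=self.device (typically a CUDA device) hit the deprecated overload. A minimal sketch of the corrected pattern, independent of FSDP:

    import torch

    # Offload to CPU first, then pin; pin_memory() takes no device argument.
    # Pinned (page-locked) memory enables fast, async host-to-device copies.
    param = torch.randn(1024, device="cpu")
    pinned = param.pin_memory()
    assert pinned.is_pinned()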


@@ -376,9 +376,7 @@ class FSDPParam:
         if self.offload_to_cpu and not padded_sharded_param.is_meta:
             padded_sharded_param = padded_sharded_param.cpu()
             if self.pin_memory:
-                padded_sharded_param = padded_sharded_param.pin_memory(
-                    device=self.device
-                )
+                padded_sharded_param = padded_sharded_param.pin_memory()
         self._sharded_param_data = padded_sharded_param.view(-1)
         length = sharded_param.size(shard_dim) if sharded_param.numel() > 0 else 0
         sharded_param = padded_sharded_param.narrow(
@@ -848,7 +846,7 @@ class FSDPParam:
             local_tensor = padded_local_tensor
             updated_local_tensor = True
         if self.pin_memory and not local_tensor.is_pinned():
-            local_tensor = local_tensor.cpu().pin_memory(device=self.device)
+            local_tensor = local_tensor.cpu().pin_memory()
             updated_local_tensor = True
         self._sharded_param_data = local_tensor.view(-1)
         assert isinstance(self.sharded_param, DTensor)  # mypy
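
For reference, a minimal sketch of how this code path is exercised, assuming the FSDP2 fully_shard API with CPUOffloadPolicy (whose pin_memory flag defaults to True) and an already-initialized process group:

    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard

    # Requires torchrun / an initialized default process group.
    dist.init_process_group("nccl")
    model = torch.nn.Linear(16, 16)
    # pin_memory=True routes the sharded parameters through the
    # pin_memory() calls patched above.
    fully_shard(model, offload_policy=CPUOffloadPolicy(pin_memory=True))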