Fix FSDP offload pin_memory bug (#157147)
Fixes #157146

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157147
Approved by: https://github.com/weifengpy
committed by PyTorch MergeBot
parent 67f8270516
commit 0629dfb860
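Both hunks below make the same fix: the CPU copy of the sharded parameter was pinned with Tensor.pin_memory(device=self.device), and the change drops the device argument so the tensor is pinned with the default accelerator semantics. A minimal sketch of the two call patterns, using illustrative tensors rather than FSDPParam's real state and assuming a CUDA build:

import torch

# Illustrative stand-ins for FSDPParam's state; assumes a CUDA build.
cpu_param = torch.empty(1024)   # parameter shard already moved to CPU
device = torch.device("cuda")

# Old call site: pins with an explicit device argument.
pinned_old = cpu_param.pin_memory(device=device)

# New call site: pins against the current accelerator by default.
pinned_new = cpu_param.pin_memory()
assert pinned_new.is_pinned()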
@@ -376,9 +376,7 @@ class FSDPParam:
         if self.offload_to_cpu and not padded_sharded_param.is_meta:
             padded_sharded_param = padded_sharded_param.cpu()
             if self.pin_memory:
-                padded_sharded_param = padded_sharded_param.pin_memory(
-                    device=self.device
-                )
+                padded_sharded_param = padded_sharded_param.pin_memory()
         self._sharded_param_data = padded_sharded_param.view(-1)
         length = sharded_param.size(shard_dim) if sharded_param.numel() > 0 else 0
         sharded_param = padded_sharded_param.narrow(
@@ -848,7 +846,7 @@ class FSDPParam:
             local_tensor = padded_local_tensor
             updated_local_tensor = True
         if self.pin_memory and not local_tensor.is_pinned():
-            local_tensor = local_tensor.cpu().pin_memory(device=self.device)
+            local_tensor = local_tensor.cpu().pin_memory()
             updated_local_tensor = True
         self._sharded_param_data = local_tensor.view(-1)
         assert isinstance(self.sharded_param, DTensor)  # mypy
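For context, the changed code paths run when FSDP2 offloads sharded parameters to CPU with pinned memory. A minimal sketch of a setup that exercises them, assuming a CUDA build with torch.distributed already initialized (the model is illustrative; fully_shard and CPUOffloadPolicy are the public FSDP2 API):

import torch
from torch.distributed.fsdp import fully_shard, CPUOffloadPolicy

# pin_memory=True drives both pin_memory() call sites changed above.
model = torch.nn.Linear(8, 8)
fully_shard(model, offload_policy=CPUOffloadPolicy(pin_memory=True))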