Mirror of https://github.com/huggingface/accelerate.git (synced 2025-10-20 18:13:46 +08:00)
fix: cpu ram efficient loading for nd or hsdp parallelisms (#3740)
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
committed by GitHub
parent 7c25f696b8
commit 979d81e4a9
@@ -506,7 +506,7 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):
         for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()):
             device_mesh = sharded_param.device_mesh
             full_param = full_param.detach().to(device_mesh.device_type)
-            dist.broadcast(full_param, src=0, group=device_mesh.get_group())
+            dist.broadcast(full_param, src=0, group=dist.group.WORLD)
             sharded_tensor = distribute_tensor(full_param, device_mesh, sharded_param.placements)
             to_contiguous, casting_dtype = _infer_parameter_dtype(
                 model,
@@ -520,7 +520,7 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):
         for param_name, sharded_param in meta_sharded_sd.items():
             device_mesh = sharded_param.device_mesh
             full_tensor = torch.empty(sharded_param.size(), device=device_mesh.device_type, dtype=sharded_param.dtype)
-            dist.broadcast(full_tensor, src=0, group=device_mesh.get_group())
+            dist.broadcast(full_tensor, src=0, group=dist.group.WORLD)
             sharded_tensor = distribute_tensor(full_tensor, device_mesh, sharded_param.placements)
             to_contiguous, casting_dtype = _infer_parameter_dtype(
                 model,
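In both hunks the broadcast of the full (rank-0) tensor now uses the global world group instead of the device mesh's group: with nD or hybrid-sharded (HSDP) parallelism the mesh has more than one dimension, so `device_mesh.get_group()` only addresses a single mesh dimension and ranks outside that subgroup never receive the full parameter. Below is a minimal sketch of the pattern, assuming a recent PyTorch with DTensor support; the helper name `load_full_param` is illustrative and not part of accelerate's API.

# Minimal sketch (assumed names, not accelerate's exact code) of CPU-RAM-efficient loading:
# only rank 0 holds the real full state dict; every other rank allocates an empty buffer,
# receives the data via a broadcast over the global world group, and shards it onto its mesh.
import torch
import torch.distributed as dist
from torch.distributed.tensor import distribute_tensor  # older PyTorch: torch.distributed._tensor


def load_full_param(full_param, sharded_param):
    """Broadcast one full parameter from rank 0, then shard it like `sharded_param`."""
    device_mesh = sharded_param.device_mesh
    # Move (or allocate) the full tensor on the mesh's device type before the collective.
    full_param = full_param.detach().to(device_mesh.device_type)
    # Broadcast over the world group: with an nD/HSDP mesh, device_mesh.get_group() covers
    # only one mesh dimension, so it would not deliver the data to every rank.
    dist.broadcast(full_param, src=0, group=dist.group.WORLD)
    # Shard the replicated tensor according to the target parameter's placements.
    return distribute_tensor(full_param, device_mesh, sharded_param.placements)

On non-zero ranks the same broadcast is issued with an empty tensor of matching size and dtype (the second hunk), so CPU RAM for the full state dict is only needed on rank 0.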