fix: cpu ram efficient loading for nd or hsdp parallelisms (#3740)

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
Author: Mehant Kammakomati
Date: 2025-08-21 17:10:06 +05:30
Committed by: GitHub
Parent: 7c25f696b8
Commit: 979d81e4a9


@@ -506,7 +506,7 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
         for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()):
             device_mesh = sharded_param.device_mesh
             full_param = full_param.detach().to(device_mesh.device_type)
-            dist.broadcast(full_param, src=0, group=device_mesh.get_group())
+            dist.broadcast(full_param, src=0, group=dist.group.WORLD)
             sharded_tensor = distribute_tensor(full_param, device_mesh, sharded_param.placements)
             to_contiguous, casting_dtype = _infer_parameter_dtype(
                 model,
@@ -520,7 +520,7 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
         for param_name, sharded_param in meta_sharded_sd.items():
             device_mesh = sharded_param.device_mesh
             full_tensor = torch.empty(sharded_param.size(), device=device_mesh.device_type, dtype=sharded_param.dtype)
-            dist.broadcast(full_tensor, src=0, group=device_mesh.get_group())
+            dist.broadcast(full_tensor, src=0, group=dist.group.WORLD)
            sharded_tensor = distribute_tensor(full_tensor, device_mesh, sharded_param.placements)
             to_contiguous, casting_dtype = _infer_parameter_dtype(
                 model,
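
For context, here is a minimal sketch (not the accelerate implementation) of the broadcast-then-shard pattern this change fixes. The full state dict only exists on global rank 0, so the broadcast has to run over the global process group. With a plain 1-D FSDP mesh, device_mesh.get_group() happens to coincide with the world group, but under HSDP or other ND parallelisms the mesh has several dimensions: get_group() then covers at most one sub-group, which may not even contain global rank 0 as src=0 expects, so some ranks would never receive the data. Broadcasting over dist.group.WORLD sidesteps the mesh topology entirely. The helper name broadcast_and_shard below is hypothetical; the torch.distributed and DTensor calls are real APIs (assuming a recent PyTorch with torch.distributed.tensor).

# A minimal sketch of the broadcast-then-shard pattern, assuming PyTorch >= 2.5
# DTensor APIs; `broadcast_and_shard` is a hypothetical helper, not accelerate code.
import torch
import torch.distributed as dist
from torch.distributed.tensor import distribute_tensor


def broadcast_and_shard(full_sd, meta_sharded_sd):
    # `full_sd` holds real CPU tensors only on global rank 0; `meta_sharded_sd`
    # is the model's sharded (DTensor) state dict without real storage.
    sharded_sd = {}
    for param_name, sharded_param in meta_sharded_sd.items():
        device_mesh = sharded_param.device_mesh
        if dist.get_rank() == 0:
            # Rank 0 has the checkpoint data; move it to the mesh's device type.
            full_tensor = full_sd[param_name].detach().to(device_mesh.device_type)
        else:
            # Every other rank only allocates a receive buffer of the right shape.
            full_tensor = torch.empty(
                sharded_param.size(), device=device_mesh.device_type, dtype=sharded_param.dtype
            )
        # Broadcast over the global WORLD group: with an ND mesh (e.g. HSDP),
        # device_mesh.get_group() covers at most one mesh dimension, so a
        # sub-group broadcast would leave some ranks without the data.
        dist.broadcast(full_tensor, src=0, group=dist.group.WORLD)
        # Each rank keeps only its local shard of the broadcast tensor.
        sharded_sd[param_name] = distribute_tensor(full_tensor, device_mesh, sharded_param.placements)
    return sharded_sd

The "cpu ram efficient" part of the commit title comes from only rank 0 ever holding the full checkpoint in host memory; the other ranks allocate just a per-parameter receive buffer and keep only their shard after distribute_tensor.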