Compare commits

...

2 Commits

2 changed files with 25 additions and 20 deletions

src/accelerate/accelerator.py

@@ -2903,7 +2903,7 @@ class Accelerator:
         for i, model in enumerate(self._models):
             if self.distributed_type == DistributedType.FSDP:
                 logger.info("Saving FSDP model")
-                save_fsdp_model(self.state.fsdp_plugin, self, model, output_dir, i)
+                save_fsdp_model(self.state.fsdp_plugin, model, output_dir, i)
                 logger.info(f"FSDP Model saved to output dir {output_dir}")
             elif self.distributed_type == DistributedType.DEEPSPEED:
                 logger.info("Saving DeepSpeed Model and Optimizer")
@@ -2922,7 +2922,7 @@ class Accelerator:
         if self.distributed_type == DistributedType.FSDP:
             for i, opt in enumerate(self._optimizers):
                 logger.info("Saving FSDP Optimizer")
-                save_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], output_dir, i)
+                save_fsdp_optimizer(self.state.fsdp_plugin, opt, self._models[i], output_dir, i)
                 logger.info(f"FSDP Optimizer saved to output dir {output_dir}")
         elif self.distributed_type not in [DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
             optimizers = self._optimizers
@@ -3047,7 +3047,7 @@ class Accelerator:
         for i, model in enumerate(self._models):
             if self.distributed_type == DistributedType.FSDP:
                 logger.info("Loading FSDP model")
-                load_fsdp_model(self.state.fsdp_plugin, self, model, input_dir, i)
+                load_fsdp_model(self.state.fsdp_plugin, model, input_dir, i)
                 logger.info(f"FSDP Model loaded from input dir {input_dir}")
            elif self.distributed_type == DistributedType.DEEPSPEED:
                 logger.info("Loading DeepSpeed Model and Optimizer")
@@ -3066,7 +3066,7 @@ class Accelerator:
         if self.distributed_type == DistributedType.FSDP:
             for i, opt in enumerate(self._optimizers):
                 logger.info("Loading FSDP Optimizer")
-                load_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], input_dir, i)
+                load_fsdp_optimizer(self.state.fsdp_plugin, opt, self._models[i], input_dir, i)
                 logger.info(f"FSDP Optimizer loaded from input dir {input_dir}")
         elif self.distributed_type not in [DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
             optimizers = self._optimizers
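
The net effect of the change above is that the FSDP checkpointing helpers no longer need the Accelerator threaded through every call, only the plugin and the objects being saved. A minimal sketch of the updated call pattern, assuming an FSDP-enabled launch and that the helpers are re-exported from accelerate.utils (the model, optimizer, and "ckpt" names are illustrative, not from the diff):

    import torch
    from accelerate import Accelerator
    from accelerate.utils import save_fsdp_model, save_fsdp_optimizer

    accelerator = Accelerator()  # assumes FSDP is configured for this launch
    model = accelerator.prepare(torch.nn.Linear(8, 8))
    # For FSDP the optimizer is built from the already-wrapped model's parameters
    optimizer = accelerator.prepare(torch.optim.SGD(model.parameters(), lr=1e-3))

    # After this change the helpers take the plugin directly; no accelerator argument
    save_fsdp_model(accelerator.state.fsdp_plugin, model, "ckpt", 0)
    save_fsdp_optimizer(accelerator.state.fsdp_plugin, optimizer, model, "ckpt", 0)

In normal training code accelerator.save_state("ckpt") still covers this path; the direct calls are shown only to make the new signatures concrete.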

src/accelerate/utils/fsdp_utils.py

@@ -16,6 +16,7 @@ import os
 import torch
 
 from ..logging import get_logger
+from ..state import PartialState
 from .constants import FSDP_MODEL_NAME, FSDP_PYTORCH_VERSION, OPTIMIZER_NAME
 from .imports import is_torch_distributed_available
 from .modeling import is_peft_model
@@ -51,13 +52,14 @@ def _set_model_state_dict(model, state_dict, adapter_only=False):
         return model.load_state_dict(state_dict)
 
 
-def save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, model_index=0, adapter_only=False):
+def save_fsdp_model(fsdp_plugin, model, output_dir, model_index=0, adapter_only=False):
+    state = PartialState()
     os.makedirs(output_dir, exist_ok=True)
 
     if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
         # FSDP raises error when single GPU is used with `offload_to_cpu=True` for FULL_STATE_DICT
         # so, only enable it when num_processes>1
-        is_multi_process = accelerator.num_processes > 1
+        is_multi_process = state.num_processes > 1
         fsdp_plugin.state_dict_config.offload_to_cpu = is_multi_process
         fsdp_plugin.state_dict_config.rank0_only = is_multi_process
 
@@ -68,15 +70,15 @@ def save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, model_index=0,
         if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
             weights_name = f"{FSDP_MODEL_NAME}.bin" if model_index == 0 else f"{FSDP_MODEL_NAME}_{model_index}.bin"
             output_model_file = os.path.join(output_dir, weights_name)
-            if accelerator.process_index == 0:
+            if state.process_index == 0:
                 logger.info(f"Saving model to {output_model_file}")
                 torch.save(state_dict, output_model_file)
                 logger.info(f"Model saved to {output_model_file}")
         elif fsdp_plugin.state_dict_type == StateDictType.LOCAL_STATE_DICT:
             weights_name = (
-                f"{FSDP_MODEL_NAME}_rank{accelerator.process_index}.bin"
+                f"{FSDP_MODEL_NAME}_rank{state.process_index}.bin"
                 if model_index == 0
-                else f"{FSDP_MODEL_NAME}_{model_index}_rank{accelerator.process_index}.bin"
+                else f"{FSDP_MODEL_NAME}_{model_index}_rank{state.process_index}.bin"
             )
             output_model_file = os.path.join(output_dir, weights_name)
             logger.info(f"Saving model to {output_model_file}")
@@ -96,19 +98,20 @@ def save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, model_index=0,
     logger.info(f"Model saved to {ckpt_dir}")
 
 
-def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0, adapter_only=False):
-    accelerator.wait_for_everyone()
+def load_fsdp_model(fsdp_plugin, model, input_dir, model_index=0, adapter_only=False):
+    state = PartialState()
+    state.wait_for_everyone()
     if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
         # FSDP raises error when single GPU is used with `offload_to_cpu=True` for FULL_STATE_DICT
         # so, only enable it when num_processes>1
-        is_multi_process = accelerator.num_processes > 1
+        is_multi_process = state.num_processes > 1
         fsdp_plugin.state_dict_config.offload_to_cpu = is_multi_process
         fsdp_plugin.state_dict_config.rank0_only = is_multi_process
     with FSDP.state_dict_type(
         model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
     ):
         if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
-            if type(model) != FSDP and accelerator.process_index != 0:
+            if type(model) != FSDP and state.process_index != 0:
                 if not fsdp_plugin.sync_module_states:
                     raise ValueError(
                         "Set the `sync_module_states` flag to `True` so that model states are synced across processes when "
@@ -122,9 +125,9 @@ def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0, adapter_only=False):
             logger.info(f"Model loaded from {input_model_file}")
         elif fsdp_plugin.state_dict_type == StateDictType.LOCAL_STATE_DICT:
             weights_name = (
-                f"{FSDP_MODEL_NAME}_rank{accelerator.process_index}.bin"
+                f"{FSDP_MODEL_NAME}_rank{state.process_index}.bin"
                 if model_index == 0
-                else f"{FSDP_MODEL_NAME}_{model_index}_rank{accelerator.process_index}.bin"
+                else f"{FSDP_MODEL_NAME}_{model_index}_rank{state.process_index}.bin"
             )
             input_model_file = os.path.join(input_dir, weights_name)
             logger.info(f"Loading model from {input_model_file}")
@@ -149,14 +152,15 @@ def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0, adapter_only=False):
     return load_result
 
 
-def save_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, output_dir, optimizer_index=0):
+def save_fsdp_optimizer(fsdp_plugin, optimizer, model, output_dir, optimizer_index=0):
+    state = PartialState()
     os.makedirs(output_dir, exist_ok=True)
     with FSDP.state_dict_type(
         model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
     ):
         optim_state = FSDP.optim_state_dict(model, optimizer)
         if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
-            if accelerator.process_index == 0:
+            if state.process_index == 0:
                 optim_state_name = (
                     f"{OPTIMIZER_NAME}.bin" if optimizer_index == 0 else f"{OPTIMIZER_NAME}_{optimizer_index}.bin"
                 )
@@ -176,14 +180,15 @@ def save_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, output_dir,
             logger.info(f"Optimizer state saved in {ckpt_dir}")
 
 
-def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, input_dir, optimizer_index=0, adapter_only=False):
-    accelerator.wait_for_everyone()
+def load_fsdp_optimizer(fsdp_plugin, optimizer, model, input_dir, optimizer_index=0, adapter_only=False):
+    state = PartialState()
+    state.wait_for_everyone()
     with FSDP.state_dict_type(
         model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
     ):
         if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
             optim_state = None
-            if accelerator.process_index == 0 or not fsdp_plugin.optim_state_dict_config.rank0_only:
+            if state.process_index == 0 or not fsdp_plugin.optim_state_dict_config.rank0_only:
                 optimizer_name = (
                     f"{OPTIMIZER_NAME}.bin" if optimizer_index == 0 else f"{OPTIMIZER_NAME}_{optimizer_index}.bin"
                 )
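
For context on the substitution: PartialState is Accelerate's process-state singleton, so the rank and world-size information that the removed accelerator argument used to carry is available from anywhere in the process. A minimal sketch of the three members the refactored helpers rely on (illustrative, not part of the diff):

    from accelerate.state import PartialState

    # PartialState() returns a shared singleton describing the launched processes
    state = PartialState()

    is_multi_process = state.num_processes > 1  # the offload_to_cpu / rank0_only check
    if state.process_index == 0:
        print(f"rank 0 of {state.num_processes}")  # rank-0-only work, e.g. torch.save
    state.wait_for_everyone()  # barrier, replaces accelerator.wait_for_everyone()

Because the singleton is initialized when the processes launch, constructing it inside each utility reads the same state the Accelerator would have passed in, which is what lets the parameter be dropped without changing behavior.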