mirror of https://github.com/huggingface/accelerate.git
synced 2025-11-14 22:24:32 +08:00

Compare commits: v1.9.0 ... use-partia (2 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 72e214f561 |  |
|  | ab14a5e6a1 |  |
src/accelerate/accelerator.py

@@ -2903,7 +2903,7 @@ class Accelerator:
         for i, model in enumerate(self._models):
             if self.distributed_type == DistributedType.FSDP:
                 logger.info("Saving FSDP model")
-                save_fsdp_model(self.state.fsdp_plugin, self, model, output_dir, i)
+                save_fsdp_model(self.state.fsdp_plugin, model, output_dir, i)
                 logger.info(f"FSDP Model saved to output dir {output_dir}")
             elif self.distributed_type == DistributedType.DEEPSPEED:
                 logger.info("Saving DeepSpeed Model and Optimizer")
@@ -2922,7 +2922,7 @@ class Accelerator:
         if self.distributed_type == DistributedType.FSDP:
             for i, opt in enumerate(self._optimizers):
                 logger.info("Saving FSDP Optimizer")
-                save_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], output_dir, i)
+                save_fsdp_optimizer(self.state.fsdp_plugin, opt, self._models[i], output_dir, i)
                 logger.info(f"FSDP Optimizer saved to output dir {output_dir}")
         elif self.distributed_type not in [DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
             optimizers = self._optimizers
@@ -3047,7 +3047,7 @@ class Accelerator:
         for i, model in enumerate(self._models):
             if self.distributed_type == DistributedType.FSDP:
                 logger.info("Loading FSDP model")
-                load_fsdp_model(self.state.fsdp_plugin, self, model, input_dir, i)
+                load_fsdp_model(self.state.fsdp_plugin, model, input_dir, i)
                 logger.info(f"FSDP Model loaded from input dir {input_dir}")
             elif self.distributed_type == DistributedType.DEEPSPEED:
                 logger.info("Loading DeepSpeed Model and Optimizer")
@@ -3066,7 +3066,7 @@ class Accelerator:
         if self.distributed_type == DistributedType.FSDP:
             for i, opt in enumerate(self._optimizers):
                 logger.info("Loading FSDP Optimizer")
-                load_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], input_dir, i)
+                load_fsdp_optimizer(self.state.fsdp_plugin, opt, self._models[i], input_dir, i)
                 logger.info(f"FSDP Optimizer loaded from input dir {input_dir}")
         elif self.distributed_type not in [DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
             optimizers = self._optimizers
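Across all four call sites above, the only change is that `self` is no longer passed: the FSDP helpers now resolve process information on their own. Below is a minimal sketch of the resulting call pattern when invoking the helpers directly; the `accelerator`, `model`, `optimizer`, and `"ckpt"` directory are assumed to come from an already prepared FSDP training setup.

```python
from accelerate.utils import save_fsdp_model, save_fsdp_optimizer

# New signatures: plugin, module(s), directory, index; no Accelerator argument.
save_fsdp_model(accelerator.state.fsdp_plugin, model, "ckpt", 0)
save_fsdp_optimizer(accelerator.state.fsdp_plugin, optimizer, model, "ckpt", 0)
```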
src/accelerate/utils/fsdp_utils.py

@@ -16,6 +16,7 @@ import os
 import torch
 
 from ..logging import get_logger
+from ..state import PartialState
 from .constants import FSDP_MODEL_NAME, FSDP_PYTORCH_VERSION, OPTIMIZER_NAME
 from .imports import is_torch_distributed_available
 from .modeling import is_peft_model
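The new import brings in `PartialState`, the lightweight process-state singleton that the helpers below query instead of the full `Accelerator`. A minimal sketch of the attributes the patched functions rely on (the print call is only illustrative):

```python
from accelerate import PartialState

# PartialState attaches to (or initializes) the distributed environment and is
# shared by every object that instantiates it.
state = PartialState()
print(f"process {state.process_index} of {state.num_processes}")
state.wait_for_everyone()  # barrier across all processes, used before loading checkpoints
```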
@@ -51,13 +52,14 @@ def _set_model_state_dict(model, state_dict, adapter_only=False):
         return model.load_state_dict(state_dict)
 
 
-def save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, model_index=0, adapter_only=False):
+def save_fsdp_model(fsdp_plugin, model, output_dir, model_index=0, adapter_only=False):
+    state = PartialState()
     os.makedirs(output_dir, exist_ok=True)
 
     if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
         # FSDP raises error when single GPU is used with `offload_to_cpu=True` for FULL_STATE_DICT
         # so, only enable it when num_processes>1
-        is_multi_process = accelerator.num_processes > 1
+        is_multi_process = state.num_processes > 1
         fsdp_plugin.state_dict_config.offload_to_cpu = is_multi_process
         fsdp_plugin.state_dict_config.rank0_only = is_multi_process
 
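The two config fields toggled above correspond to PyTorch's `FullStateDictConfig`. A sketch of the equivalent plain-PyTorch setup, assuming an already FSDP-wrapped `model`, an initialized process group, and a hypothetical `world_size`:

```python
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

world_size = 2  # assumed; plays the role of state.num_processes in the helper
is_multi_process = world_size > 1
cfg = FullStateDictConfig(offload_to_cpu=is_multi_process, rank0_only=is_multi_process)

# Inside this context, state_dict() returns full (unsharded) parameters,
# materialized on rank 0 only when rank0_only=True.
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
    state_dict = model.state_dict()
```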
@@ -68,15 +70,15 @@ def save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, model_index=0, adapter_only=False):
         if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
             weights_name = f"{FSDP_MODEL_NAME}.bin" if model_index == 0 else f"{FSDP_MODEL_NAME}_{model_index}.bin"
             output_model_file = os.path.join(output_dir, weights_name)
-            if accelerator.process_index == 0:
+            if state.process_index == 0:
                 logger.info(f"Saving model to {output_model_file}")
                 torch.save(state_dict, output_model_file)
                 logger.info(f"Model saved to {output_model_file}")
         elif fsdp_plugin.state_dict_type == StateDictType.LOCAL_STATE_DICT:
             weights_name = (
-                f"{FSDP_MODEL_NAME}_rank{accelerator.process_index}.bin"
+                f"{FSDP_MODEL_NAME}_rank{state.process_index}.bin"
                 if model_index == 0
-                else f"{FSDP_MODEL_NAME}_{model_index}_rank{accelerator.process_index}.bin"
+                else f"{FSDP_MODEL_NAME}_{model_index}_rank{state.process_index}.bin"
             )
             output_model_file = os.path.join(output_dir, weights_name)
             logger.info(f"Saving model to {output_model_file}")
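The file names built above are the only contract between the save and load paths: `FULL_STATE_DICT` yields a single file written by rank 0, while `LOCAL_STATE_DICT` yields one file per rank. A small self-contained sketch, assuming `FSDP_MODEL_NAME` has the value `"pytorch_model_fsdp"` from `accelerate.utils.constants`:

```python
FSDP_MODEL_NAME = "pytorch_model_fsdp"  # assumed value of the imported constant

def full_weights_name(model_index: int) -> str:
    # Single file, written by rank 0 only.
    return f"{FSDP_MODEL_NAME}.bin" if model_index == 0 else f"{FSDP_MODEL_NAME}_{model_index}.bin"

def local_weights_name(model_index: int, rank: int) -> str:
    # One file per rank.
    return (
        f"{FSDP_MODEL_NAME}_rank{rank}.bin"
        if model_index == 0
        else f"{FSDP_MODEL_NAME}_{model_index}_rank{rank}.bin"
    )

assert full_weights_name(0) == "pytorch_model_fsdp.bin"
assert local_weights_name(1, 3) == "pytorch_model_fsdp_1_rank3.bin"
```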
@@ -96,19 +98,20 @@ def save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, model_index=0, adapter_only=False):
             logger.info(f"Model saved to {ckpt_dir}")
 
 
-def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0, adapter_only=False):
-    accelerator.wait_for_everyone()
+def load_fsdp_model(fsdp_plugin, model, input_dir, model_index=0, adapter_only=False):
+    state = PartialState()
+    state.wait_for_everyone()
     if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
         # FSDP raises error when single GPU is used with `offload_to_cpu=True` for FULL_STATE_DICT
         # so, only enable it when num_processes>1
-        is_multi_process = accelerator.num_processes > 1
+        is_multi_process = state.num_processes > 1
         fsdp_plugin.state_dict_config.offload_to_cpu = is_multi_process
         fsdp_plugin.state_dict_config.rank0_only = is_multi_process
     with FSDP.state_dict_type(
         model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
     ):
         if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
-            if type(model) != FSDP and accelerator.process_index != 0:
+            if type(model) != FSDP and state.process_index != 0:
                 if not fsdp_plugin.sync_module_states:
                     raise ValueError(
                         "Set the `sync_module_states` flag to `True` so that model states are synced across processes when "
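The `ValueError` above fires when a full state dict is loaded into a model that is not yet FSDP-wrapped on a non-zero rank without `sync_module_states`, since only rank 0 holds the loaded weights at that point. One way to satisfy the flag from user code is sketched below, assuming the public `FullyShardedDataParallelPlugin` constructor arguments:

```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_type="FULL_STATE_DICT",
    sync_module_states=True,  # rank 0 broadcasts its weights to the other ranks
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```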
@@ -122,9 +125,9 @@ def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0, adapter_only=False):
             logger.info(f"Model loaded from {input_model_file}")
         elif fsdp_plugin.state_dict_type == StateDictType.LOCAL_STATE_DICT:
             weights_name = (
-                f"{FSDP_MODEL_NAME}_rank{accelerator.process_index}.bin"
+                f"{FSDP_MODEL_NAME}_rank{state.process_index}.bin"
                 if model_index == 0
-                else f"{FSDP_MODEL_NAME}_{model_index}_rank{accelerator.process_index}.bin"
+                else f"{FSDP_MODEL_NAME}_{model_index}_rank{state.process_index}.bin"
             )
             input_model_file = os.path.join(input_dir, weights_name)
             logger.info(f"Loading model from {input_model_file}")
@@ -149,14 +152,15 @@ def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0, adapter_only=False):
     return load_result
 
 
-def save_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, output_dir, optimizer_index=0):
+def save_fsdp_optimizer(fsdp_plugin, optimizer, model, output_dir, optimizer_index=0):
+    state = PartialState()
     os.makedirs(output_dir, exist_ok=True)
     with FSDP.state_dict_type(
         model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
     ):
         optim_state = FSDP.optim_state_dict(model, optimizer)
         if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
-            if accelerator.process_index == 0:
+            if state.process_index == 0:
                 optim_state_name = (
                     f"{OPTIMIZER_NAME}.bin" if optimizer_index == 0 else f"{OPTIMIZER_NAME}_{optimizer_index}.bin"
                 )
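Both optimizer helpers are built around the `FSDP.optim_state_dict` / `FSDP.optim_state_dict_to_load` pair from PyTorch. A minimal sketch of that round trip, assuming an FSDP-wrapped `model`, its `optimizer`, the `state = PartialState()` used above, and a `rank0_only` optimizer state-dict config; the path is illustrative:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Save: gather the sharded optimizer state across ranks, then let rank 0 write it.
optim_state = FSDP.optim_state_dict(model, optimizer)
if state.process_index == 0:
    torch.save(optim_state, "ckpt/optimizer.bin")

# Load: with a rank0_only config only rank 0 needs to read the file back;
# the conversion call re-shards the state so every rank can load its piece.
full_osd = torch.load("ckpt/optimizer.bin") if state.process_index == 0 else None
flattened_osd = FSDP.optim_state_dict_to_load(model=model, optim=optimizer, optim_state_dict=full_osd)
optimizer.load_state_dict(flattened_osd)
```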
@@ -176,14 +180,15 @@ def save_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, output_dir, optimizer_index=0):
             logger.info(f"Optimizer state saved in {ckpt_dir}")
 
 
-def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, input_dir, optimizer_index=0, adapter_only=False):
-    accelerator.wait_for_everyone()
+def load_fsdp_optimizer(fsdp_plugin, optimizer, model, input_dir, optimizer_index=0, adapter_only=False):
+    state = PartialState()
+    state.wait_for_everyone()
     with FSDP.state_dict_type(
         model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config
     ):
         if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT:
             optim_state = None
-            if accelerator.process_index == 0 or not fsdp_plugin.optim_state_dict_config.rank0_only:
+            if state.process_index == 0 or not fsdp_plugin.optim_state_dict_config.rank0_only:
                 optimizer_name = (
                     f"{OPTIMIZER_NAME}.bin" if optimizer_index == 0 else f"{OPTIMIZER_NAME}_{optimizer_index}.bin"
                 )