mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
- FastPersist - ZeRO-Inference+SGLang --------- Signed-off-by: Olatunji Ruwase <olruwase@microsoft.com> Signed-off-by: Olatunji Ruwase <tunji.ruwase@snowflake.com> Co-authored-by: jerryyangli <jerryyangli@gmail.com> Co-authored-by: Yang Li <yangli2@microsoft.com> Co-authored-by: Guanhua Wang <alexwgh333@gmail.com> Co-authored-by: Connor Holmes <connorholmes@microsoft.com> Co-authored-by: Bing Xie <67908712+xiexbing@users.noreply.github.com> Co-authored-by: cassieesvelt <73311224+cassieesvelt@users.noreply.github.com> Co-authored-by: Jeff Rasley <jerasley@microsoft.com> Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: swli <47371259+lucasleesw@users.noreply.github.com> Co-authored-by: Cheng Li <pistasable@gmail.com> Co-authored-by: Molly Smith <112220543+molly-smith@users.noreply.github.com> Co-authored-by: Ubuntu <jomayeri@microsoft.com> Co-authored-by: Olatunji Ruwase <tunji.ruwase@snowflake.com> Co-authored-by: Zhipeng Wang <zhipeng.rainbowserie@gmail.com>
109 lines
5.0 KiB
Python
109 lines
5.0 KiB
Python
# Copyright (c) Microsoft Corporation.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
# DeepSpeed Team
|
|
|
|
import os
|
|
import torch
|
|
import torch_nebula
|
|
|
|
from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
|
|
CheckpointEngine, CheckpointCommitInfo
|
|
from deepspeed.utils import logger, log_dist
|
|
from deepspeed.nebula.constants import *
|
|
|
|
|
|
def _get_tag_from_path(path):
|
|
return os.path.basename(os.path.dirname(path))
|
|
|
|
|
|
class NebulaCheckpointEngine(CheckpointEngine):
    """Checkpoint engine that persists DeepSpeed checkpoints through torch_nebula.

    Saves go to Nebula's fast tier (tier1) and are persisted asynchronously;
    ``commit()`` tells Nebula that every file of a tag has been written and
    persistence may start.  Loads resolve the DeepSpeed tag against Nebula
    metadata, falling back to the latest Nebula checkpoint (tier3, then tier1)
    when the tag cannot be found.
    """

    def __init__(self, config_params=None):
        """Initialize the engine and the torch_nebula service.

        Args:
            config_params: DeepSpeed nebula config object.  Must provide
                ``enable_nebula_load``, ``load_path``, ``persistent_storage_path``,
                ``persistent_time_interval`` and ``num_of_version_in_retention``
                (accessing these raises AttributeError when ``config_params`` is
                ``None`` — presumably callers always pass a config; TODO confirm).
        """
        super().__init__(config_params)
        self.name = "NebulaCheckpointEngine"
        # Handle to the in-flight torch_nebula.Checkpoint created by create().
        self.checkpoint = None
        # None until the first load; set to the first tag loaded via torch.load,
        # and to -1 once a checkpoint has been loaded through nebula.
        self.tag_flag = None
        self.enable_nebula_load = config_params.enable_nebula_load
        # Prefer the explicit load path; fall back to the persistent storage path.
        self.nebula_load_path = config_params.load_path
        if self.nebula_load_path is None:
            self.nebula_load_path = config_params.persistent_storage_path

        nebula_config_params = {
            NEBULA_PERSISTENT_STORAGE_PATH: config_params.persistent_storage_path,
            NEBULA_PERSISTENT_TIME_INTERVAL: config_params.persistent_time_interval,
            NEBULA_NUM_OF_VERSION_IN_RETENTION: config_params.num_of_version_in_retention,
        }
        torch_nebula.init(**nebula_config_params)

    def create(self, info: CheckpointCommitInfo):
        """Open a new nebula checkpoint for ``info.tag``.

        Subsequent ``save()`` calls write into this checkpoint until
        ``commit()`` is called.
        """
        log_dist(f"[Nebula] Start Checkpoint for tag:{info.tag}", ranks=[0])
        # -2 means: customer needs to explicitly tell nebula
        # current checkpoint is complete by commit method.
        self.checkpoint = torch_nebula.Checkpoint(info.tag, -2)

    def save(self, state_dict, path: str):
        """Save ``state_dict`` into the open nebula checkpoint.

        A dummy (empty-string) file is written at ``path`` so that code which
        checks for the file's existence keeps working; the real payload is
        stored in nebula keyed by the file's basename.
        """
        log_dist(f"[Nebula] Create dummy files for loading.")
        torch.save("", path)

        tag = _get_tag_from_path(path)
        partition_name = os.path.basename(path)
        logger.info(f"[Nebula] Saving {partition_name} under tag {tag}...")
        self.checkpoint.save(partition_name, state_dict)
        logger.info(f"[Nebula] Saved {partition_name} under tag {tag}.")

    def load(self, path: str, map_location=None):
        """Load one checkpoint partition.

        When nebula load is disabled and this is the first load (or the same
        tag is re-loaded), fall back to plain ``torch.load``.  Otherwise
        resolve the checkpoint through nebula, retrying with the latest
        tier3 and then tier1 checkpoint when the tag cannot be found.

        Returns:
            The loaded partition object, or ``None`` when no valid checkpoint
            could be resolved from nebula.
        """
        tag = _get_tag_from_path(path)
        first_load_flag = self.tag_flag is None or self.tag_flag == tag
        if not self.enable_nebula_load and first_load_flag:
            self.tag_flag = tag
            logger.info(f"[Nebula] Disable nebula load. Loading checkpoint from {path} ...")
            partition = torch.load(path, map_location=map_location, weights_only=False)
            logger.info(f"[Nebula] Disable nebula load. Loaded checkpoint from {path} .")
            return partition

        partition_name = os.path.basename(path)
        logger.info(f"[Nebula] Loading {path} under tag {tag} from nebula path {self.nebula_load_path}...")

        checkpoint = None
        if tag in (None, 'latest', 'latest_universal'):
            # In some cases, there is the inconsistent tag between deepspeed metadata (latest file)
            # and nebula metadata, will lead to the failure on loading with deepspeed tag. Then we
            # will try to load the valid latest checkpoint from nebula(tier3 > tier1). So, in summary
            # when met failure loading for given tag, the loading priority would be like:
            # nebula tier3 latest > nebula tier1 latest.
            checkpoint = torch_nebula.get_latest_checkpoint(persist_path=self.nebula_load_path)
        else:
            checkpoint = torch_nebula.get_checkpoint(tag=tag, persist_path=self.nebula_load_path)

        # An empty tag marks an invalid/placeholder checkpoint object.
        if checkpoint is None or checkpoint.tag == '':
            logger.info(
                f"Unable to find valid checkpoint tag:{tag} from Nebula, try to get latest checkpoint again from nebula {self.nebula_load_path} path!"
            )
            # nebula tier3 latest
            checkpoint = torch_nebula.get_latest_checkpoint(persist_path=self.nebula_load_path)
            if checkpoint is None or checkpoint.tag == '':
                logger.info(
                    f"Unable to find latest checkpoint from Nebula tier3, try to get latest checkpoint again from nebula tier1 path!"
                )
                # nebula tier1 latest
                # NOTE(review): the tier1 checkpoint fetched here is never used —
                # we return None unconditionally right after.  Preserved as-is
                # since torch_nebula semantics (possible side effects of
                # get_latest_checkpoint) are not visible here; confirm intent.
                checkpoint = torch_nebula.get_latest_checkpoint()
                logger.warning(f"Unable to find valid checkpoint from Nebula under tag:{tag}.")
                return None

        tag = checkpoint.tag
        # Mark that a nebula load has happened so the torch.load fallback in
        # future calls is skipped.
        self.tag_flag = -1
        partition = checkpoint.load(partition_name, map_location=map_location)
        logger.info(f"[Nebula] Loaded {path} under tag {tag} from {self.nebula_load_path}.")
        return partition

    def commit(self, info: CheckpointCommitInfo):
        """Signal nebula that all files for ``info.tag`` are written.

        Nebula then persists the checkpoint asynchronously.  Returns the
        (truthy) commit result, or ``False`` when the commit failed.
        """
        tag = info.tag
        # nebula commit will be called when all files under the given tag are
        # ready to be persisted in the async way.
        logger.info(f"[Nebula] all files for {tag} are saved in tier1. It is ready to start persisting")
        commit_rls = self.checkpoint.commit()
        if not commit_rls:
            logger.error(f"[Nebula] failed to commit the checkpoint, please check the log.")
            return False
        return commit_rls