Mirror of https://github.com/volcengine/verl.git (synced 2025-10-20 13:43:50 +08:00)
[merger] fix: move megatron import into megatron related branch (#958)
Users running the FSDP backend may not have Megatron installed, so running this script directly leads to an import error. Move the megatron import into the megatron-specific code path instead.
@@ -28,10 +28,8 @@ try:
 except ImportError:
     from torch.distributed._tensor import DTensor
 
-from verl.utils.megatron_utils import get_model_checkpoint_path, get_hf_model_checkpoint_path
-
 parser = argparse.ArgumentParser()
-parser.add_argument('--backend', type=str, required=True, help="The backend of the model")
+parser.add_argument('--backend', type=str, required=True, help="The backend of the model", choices=["fsdp", "megatron"])
 parser.add_argument('--tie-word-embedding', action='store_true', help="Whether to tie word embedding weights")
 parser.add_argument('--is-value-model', action='store_true', help="Whether the model loaded as value model")
 parser.add_argument('--hf_model_path', type=str, required=True, help="The path for the huggingface model")
@@ -227,6 +225,8 @@ def check_megatron_checkpoint_path(model_path):
 
 
 def convert_megatron_checkpoints_to_hfmodels():
+    from verl.utils.megatron_utils import get_model_checkpoint_path, get_hf_model_checkpoint_path
+
     local_path = args.local_dir
 
     model_ckpt_path = get_model_checkpoint_path(local_path)
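For context, the change applies the usual deferred-import pattern for optional dependencies: the megatron utilities are imported inside the Megatron-only code path, so an FSDP-only environment never evaluates that import. Below is a minimal, self-contained sketch of the idea; the FSDP conversion function and the argument wiring are illustrative, not the script's exact code.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--backend', type=str, required=True,
                    help="The backend of the model", choices=["fsdp", "megatron"])


def convert_megatron_checkpoints_to_hfmodels():
    # Deferred import: only evaluated when the megatron backend is selected,
    # so machines without Megatron never hit the ImportError.
    from verl.utils.megatron_utils import get_model_checkpoint_path, get_hf_model_checkpoint_path
    ...


def convert_fsdp_checkpoints_to_hfmodels():
    # Hypothetical FSDP path for illustration; it needs no megatron import.
    ...


if __name__ == "__main__":
    args = parser.parse_args()
    if args.backend == "megatron":
        convert_megatron_checkpoints_to_hfmodels()
    else:
        convert_fsdp_checkpoints_to_hfmodels()

With a module-level import, Python evaluates it as soon as the script starts, before --backend is even parsed, which is why FSDP users hit the error despite never touching the Megatron path.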