!243 Modify the model loading logic and the fused operator enabling logic.

Merge pull request !243 from 金勇旭/master0
Authored by 金勇旭 on 2025-06-18 10:22:08 +00:00, committed by i-robot
commit 6e39f9ca2c, parent 49302340c8
3 changed files with 17 additions and 13 deletions


@@ -37,7 +37,7 @@ from transformers import (
 from transformers.dynamic_module_utils import get_relative_imports
 from transformers.integrations import is_deepspeed_zero3_enabled
 from openmind.utils import logging, is_torch_npu_available
-from openmind.integrations.transformers.npu_fused_ops.sdk import SUPPORTED_FUSED_MODELS, map_fused_kernel_to_model
+from openmind.integrations.transformers.npu_fused_ops.sdk import map_fused_kernel_to_model
 from openmind.flow.arguments import get_args
 from openmind.flow.model.model_registry import SUPPORTED_MODELS
 from openmind.flow.model.adapter import apply_adapter
@@ -315,10 +315,6 @@ def get_model():
             )
         else:
             init_kwargs["device_map"] = {"": get_current_device(args.device)}
-    else:
-        # zero3 does not support load model with device map
-        if not is_deepspeed_zero3_enabled():
-            init_kwargs["device_map"] = {"": get_current_device(os.getenv("LOCAL_RANK", 0))}
     if args.load_in_4bit:
         patch_bnb()
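
The removed else branch was the only place a default device_map was derived from LOCAL_RANK; after this hunk, from_pretrained() receives no device_map unless the user asked for a specific device, so weights land on CPU first (and DeepSpeed ZeRO-3, which per the removed comment does not support loading a model with a device map, no longer needs a special case). A minimal sketch of the before/after defaulting, with openMind's helpers stubbed out; the function names here are illustrative, not part of the codebase:

import os

def default_device_map_before(zero3_enabled: bool):
    # Old fallback: pin the model to the local rank's device, except under
    # ZeRO-3, which does not support loading with a device map.
    if not zero3_enabled:
        return {"": f"npu:{os.getenv('LOCAL_RANK', 0)}"}
    return None

def default_device_map_after(zero3_enabled: bool):
    # New behaviour: no fallback at all; from_pretrained() loads on CPU and
    # the model is moved to the accelerator later (see the hunk at +357).
    return None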
@@ -337,10 +333,9 @@ def get_model():
     nf4_config = None
     patch_config(config)
-    if config.architectures and config.architectures[0] in SUPPORTED_FUSED_MODELS and args.do_train:
-        logger.warning_rank0(f"Unsupported model architecture for npu fused options: {config.architectures[0]}")
+    if args.do_train:
         map_fused_kernel_to_model(
-            config.architectures[0],
+            config.architectures,
             use_npu_fusion_attention=args.use_npu_fusion_attention,
             use_fused_rms_norm=args.use_fused_rms_norm,
             use_fused_rope=args.use_fused_rope,
@@ -362,6 +357,9 @@ def get_model():
         **init_kwargs,
     )
+    if args.init_lora_weights:
+        model = model.to(get_current_device(os.getenv("LOCAL_RANK", 0)))
     apply_sequence_parallel(args, config)
     model = apply_adapter(model, args.do_train)
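
Taken together with the removed fallback above, this file's changes make device placement explicit: the freshly loaded model stays on CPU and is moved with a single model.to(...) call, and only when LoRA weights are initialized eagerly. A runnable sketch of that flow; get_current_device here is a stand-in that resolves to CPU so the example runs anywhere, whereas openMind resolves the local rank to an NPU device:

import os
import torch.nn as nn

def get_current_device(index):
    # Stand-in for openMind's helper; the real one returns e.g. an NPU device
    # for the given local rank. CPU keeps this sketch runnable everywhere.
    return "cpu"

def place_model(model: nn.Module, init_lora_weights: bool) -> nn.Module:
    # Loading left the model on CPU; move it eagerly only for LoRA init,
    # mirroring the new `if args.init_lora_weights:` branch.
    if init_lora_weights:
        model = model.to(get_current_device(os.getenv("LOCAL_RANK", 0)))
    return model

model = place_model(nn.Linear(8, 8), init_lora_weights=True)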


@@ -198,7 +198,13 @@ SUPPORTED_FUSED_MODELS = {
 }
-def map_fused_kernel_to_model(architecture, **kwargs):
-    if architecture not in SUPPORTED_FUSED_MODELS:
+def map_fused_kernel_to_model(architectures, **kwargs):
+    if not architectures:
+        logger.warning_rank0("Unknown model architectures for npu fused options")
+        return
-    SUPPORTED_FUSED_MODELS.get(architecture)(inner=True, **kwargs)
+    if architectures[0] not in SUPPORTED_FUSED_MODELS:
+        logger.warning_rank0(f"Unsupported model architecture for npu fused options: {architectures[0]}")
+        return
+    SUPPORTED_FUSED_MODELS.get(architectures[0])(inner=True, **kwargs)
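
The membership check that previously lived at the call site in get_model() now sits inside the SDK, preceded by a new guard for configs whose architectures field is None or empty. A self-contained sketch of the new guard order; the registry entry is a dummy stand-in for openMind's real fused-kernel patchers, and plain logger.warning replaces the project's warning_rank0:

import logging

logger = logging.getLogger(__name__)

# Dummy registry: the real SUPPORTED_FUSED_MODELS maps architecture names
# to patcher callables.
SUPPORTED_FUSED_MODELS = {
    "Qwen2ForCausalLM": lambda inner, **kwargs: print("patched", kwargs),
}

def map_fused_kernel_to_model(architectures, **kwargs):
    if not architectures:  # config.architectures can be None or []
        logger.warning("Unknown model architectures for npu fused options")
        return
    if architectures[0] not in SUPPORTED_FUSED_MODELS:
        logger.warning("Unsupported model architecture for npu fused options: %s", architectures[0])
        return
    SUPPORTED_FUSED_MODELS.get(architectures[0])(inner=True, **kwargs)

map_fused_kernel_to_model(["Qwen2ForCausalLM"], use_npu_fusion_attention=True)
map_fused_kernel_to_model(None)  # now warns instead of failing at the call site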


@@ -165,7 +165,7 @@ class TestMapFusedKernel:
         config = Config()
         sdk.map_fused_kernel_to_model(
-            architecture="Qwen2ForCausalLM",
+            architectures=["Qwen2ForCausalLM"],
             use_npu_fusion_attention=True,
             use_fused_rms_norm=True,
             use_fused_rope=True,
@@ -191,7 +191,7 @@ class TestMapFusedKernel:
         mock_raw_get_dynamic_module.model_name = "InternLM2ForCausalLM"
         mock_raw_get_dynamic_module.return_value = Model(config)
         sdk.map_fused_kernel_to_model(
-            architecture="InternLM2ForCausalLM",
+            architectures=["InternLM2ForCausalLM"],
             use_npu_fusion_attention=True,
             use_fused_rms_norm=False,
             use_fused_rope=False,