!220 修正模型加载到cpu的问题

Merge pull request !220 from 金勇旭/ds_branch
This commit is contained in:
金勇旭
2025-06-04 01:42:06 +00:00
committed by i-robot
parent a7fb713eea
commit 6d975d04b3

View File

@ -294,6 +294,10 @@ def get_model():
)
else:
init_kwargs["device_map"] = {"": get_current_device(args.device)}
else:
# zero3 does not support load model with device map
if not is_deepspeed_zero3_enabled():
init_kwargs["device_map"] = {"": get_current_device(os.getenv("LOCAL_RANK"))}
if args.load_in_4bit:
patch_bnb()