@ -294,6 +294,10 @@ def get_model():
|
||||
)
|
||||
else:
|
||||
init_kwargs["device_map"] = {"": get_current_device(args.device)}
|
||||
else:
|
||||
# zero3 does not support load model with device map
|
||||
if not is_deepspeed_zero3_enabled():
|
||||
init_kwargs["device_map"] = {"": get_current_device(os.getenv("LOCAL_RANK"))}
|
||||
|
||||
if args.load_in_4bit:
|
||||
patch_bnb()
|
||||
|
Reference in New Issue
Block a user