mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
[Feature] Enable Intel XPU support (#839)
* enable xpu support * fix bug * review commits * fix style * add xou decorator * refactor review commit * fix test * review commit * fix test * Update benchmark.yml (#856) * Standardise example scripts (#842) * Standardise example scripts * fix plotting script * Rename run_xxx to xxx * Fix doc --------- Co-authored-by: Costa Huang <costa.huang@outlook.com> * Fix version check in import_utils.py (#853) * dont use get_peft_model if model is already peft (#857) * merge conflict * add xou decorator * resolve * resolves * upstream * refactor and precommit * fix new tests * add device mapping for xpu --------- Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com> Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> Co-authored-by: Costa Huang <costa.huang@outlook.com> Co-authored-by: Adam Pauls <adpauls@gmail.com> Co-authored-by: abhishek thakur <1183441+abhishekkrthakur@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
d192244f54
commit
ec9e76623e
@ -22,7 +22,7 @@ from peft import LoraConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig
|
||||
|
||||
from trl import RewardConfig, RewardTrainer
|
||||
from trl import RewardConfig, RewardTrainer, is_xpu_available
|
||||
|
||||
|
||||
tqdm.pandas()
|
||||
@ -83,7 +83,11 @@ if args.load_in_8bit and args.load_in_4bit:
|
||||
elif args.load_in_8bit or args.load_in_4bit:
|
||||
quantization_config = BitsAndBytesConfig(load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit)
|
||||
# Copy the model to each device
|
||||
device_map = {"": Accelerator().local_process_index}
|
||||
device_map = (
|
||||
{"": f"xpu:{Accelerator().local_process_index}"}
|
||||
if is_xpu_available()
|
||||
else {"": Accelerator().local_process_index}
|
||||
)
|
||||
else:
|
||||
device_map = None
|
||||
quantization_config = None
|
||||
|
Reference in New Issue
Block a user