[Feature] Enable Intel XPU support (#839)

* enable xpu support

* fix bug

* review commits

* fix style

* add xou decorator

* refactor review commit

* fix test

* review commit

* fix test

* Update benchmark.yml (#856)

* Standardise example scripts (#842)

* Standardise example scripts

* fix plotting script

* Rename run_xxx to xxx

* Fix doc

---------

Co-authored-by: Costa Huang <costa.huang@outlook.com>

* Fix version check in import_utils.py (#853)

* dont use get_peft_model if model is already peft (#857)

* merge conflict

* add xou decorator

* resolve

* resolves

* upstream

* refactor and precommit

* fix new tests

* add device mapping for xpu

---------

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Costa Huang <costa.huang@outlook.com>
Co-authored-by: Adam Pauls <adpauls@gmail.com>
Co-authored-by: abhishek thakur <1183441+abhishekkrthakur@users.noreply.github.com>
This commit is contained in:
Abhilash Majumder
2023-10-31 14:45:35 +05:30
committed by GitHub
parent d192244f54
commit ec9e76623e
14 changed files with 119 additions and 29 deletions

View File

@ -22,7 +22,7 @@ from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig
from trl import RewardConfig, RewardTrainer
from trl import RewardConfig, RewardTrainer, is_xpu_available
tqdm.pandas()
@ -83,7 +83,11 @@ if args.load_in_8bit and args.load_in_4bit:
elif args.load_in_8bit or args.load_in_4bit:
quantization_config = BitsAndBytesConfig(load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit)
# Copy the model to each device
device_map = {"": Accelerator().local_process_index}
device_map = (
{"": f"xpu:{Accelerator().local_process_index}"}
if is_xpu_available()
else {"": Accelerator().local_process_index}
)
else:
device_map = None
quantization_config = None