ENH Support XPU for seq clf examples (#2732)

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
This commit is contained in:
kaixuanliu
2025-08-08 18:07:20 +08:00
committed by GitHub
parent a4b41e7924
commit 9b420cc9c7
11 changed files with 5819 additions and 6099 deletions

View File

@ -66,7 +66,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "e3b13308",
"metadata": {},
"outputs": [
@ -86,7 +86,7 @@
"model_name_or_path = \"roberta-base\"\n",
"task = \"mrpc\"\n",
"peft_type = PeftType.FOURIERFT\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"device = torch.accelerator.current_accelerator().type if hasattr(torch, \"accelerator\") else \"cuda\"\n",
"num_epochs = 5 # for better results, increase this number\n",
"n_frequency = 1000 # for better results, increase this number\n",
"scaling = 150.0\n",

File diff suppressed because it is too large Load Diff

View File

@ -56,7 +56,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "e3b13308",
"metadata": {},
"outputs": [],
@ -64,7 +64,7 @@
"batch_size = 16\n",
"model_name_or_path = \"google/gemma-2-2b\"\n",
"task = \"mrpc\"\n",
"device = \"cuda\"\n",
"device = torch.accelerator.current_accelerator().type if hasattr(torch, \"accelerator\") else \"cuda\"\n",
"num_epochs = 5\n",
"lr = 2e-5\n",
"\n",

View File

@ -56,7 +56,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "e3b13308",
"metadata": {},
"outputs": [],
@ -64,7 +64,7 @@
"batch_size = 16\n",
"model_name_or_path = \"google/gemma-2-2b\"\n",
"task = \"mrpc\"\n",
"device = \"cuda\"\n",
"device = torch.accelerator.current_accelerator().type if hasattr(torch, \"accelerator\") else \"cuda\"\n",
"num_epochs = 5\n",
"lr = 2e-5\n",
"\n",

View File

@ -48,7 +48,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "e3b13308",
"metadata": {},
"outputs": [],
@ -57,7 +57,7 @@
"model_name_or_path = \"roberta-large\"\n",
"task = \"mrpc\"\n",
"peft_type = PeftType.LORA\n",
"device = \"cuda\"\n",
"device = torch.accelerator.current_accelerator().type if hasattr(torch, \"accelerator\") else \"cuda\"\n",
"num_epochs = 20"
]
},

View File

@ -47,7 +47,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "2bd7cbb2",
"metadata": {},
"outputs": [],
@ -56,7 +56,7 @@
"model_name_or_path = \"roberta-large\"\n",
"task = \"mrpc\"\n",
"peft_type = PeftType.P_TUNING\n",
"device = \"cuda\"\n",
"device = torch.accelerator.current_accelerator().type if hasattr(torch, \"accelerator\") else \"cuda\"\n",
"num_epochs = 20"
]
},

View File

@ -48,7 +48,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "e32c4a9e",
"metadata": {},
"outputs": [],
@ -57,7 +57,7 @@
"model_name_or_path = \"roberta-large\"\n",
"task = \"mrpc\"\n",
"peft_type = PeftType.PROMPT_TUNING\n",
"device = \"cuda\"\n",
"device = torch.accelerator.current_accelerator().type if hasattr(torch, \"accelerator\") else \"cuda\"\n",
"num_epochs = 20"
]
},

View File

@ -58,7 +58,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "e3b13308",
"metadata": {},
"outputs": [
@ -78,7 +78,7 @@
"model_name_or_path = \"roberta-large\"\n",
"task = \"mrpc\"\n",
"peft_type = PeftType.VBLORA\n",
"device = \"cuda\"\n",
"device = torch.accelerator.current_accelerator().type if hasattr(torch, \"accelerator\") else \"cuda\"\n",
"num_epochs = 20\n",
"rank = 4\n",
"max_length = 128\n",

View File

@ -76,7 +76,7 @@
"model_name_or_path = \"roberta-base\"\n",
"task = \"mrpc\"\n",
"peft_type = PeftType.VERA\n",
"device = \"cuda\"\n",
"device = torch.accelerator.current_accelerator().type if hasattr(torch, \"accelerator\") else \"cuda\"\n",
"num_epochs = 5 # for best results, increase this number\n",
"rank = 8 # for best results, increase this number\n",
"max_length = 128\n",

View File

@ -47,7 +47,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "2bd7cbb2",
"metadata": {},
"outputs": [],
@ -56,7 +56,7 @@
"model_name_or_path = \"roberta-large\"\n",
"task = \"mrpc\"\n",
"peft_type = PeftType.PREFIX_TUNING\n",
"device = \"cuda\"\n",
"device = torch.accelerator.current_accelerator().type if hasattr(torch, \"accelerator\") else \"cuda\"\n",
"num_epochs = 20"
]
},

View File

@ -2,4 +2,5 @@ transformers
accelerate
evaluate
tqdm
datasets
datasets
torchao