FIX Poly issue with returned base model (#2702)

Also, add XPU support for Poly example.

---------

Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
This commit is contained in:
Yao Matrix
2025-08-06 03:16:49 -07:00
committed by GitHub
parent e3d8fc98f1
commit db5c00fad2
2 changed files with 31 additions and 146 deletions

View File

@ -18,13 +18,15 @@
"name": "stdout",
"output_type": "stream",
"text": [
"env: CUDA_VISIBLE_DEVICES=0\n",
"env: CUDA_VISIBLE_DEVICES=0 # force using CUDA GPU device 0\n",
"env: ZE_AFFINITY_MASK=0 # force using Intel XPU device 0\n",
"env: TOKENIZERS_PARALLELISM=false\n"
]
}
],
"source": [
"%env CUDA_VISIBLE_DEVICES=0\n",
"%env CUDA_VISIBLE_DEVICES=0 # force using CUDA GPU device 0\n",
"%env ZE_AFFINITY_MASK=0 # force using Intel XPU device 0\n",
"%env TOKENIZERS_PARALLELISM=false"
]
},
@ -41,7 +43,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"id": "1a5c7a99-5208-4d22-ac15-bacebe1b52f9",
"metadata": {
"execution": {
@ -52,36 +54,7 @@
"id": "1a5c7a99-5208-4d22-ac15-bacebe1b52f9",
"libroFormatter": "formatter-string"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"===================================BUG REPORT===================================\n",
"Welcome to bitsandbytes. For bug reports, please run\n",
"\n",
"python -m bitsandbytes\n",
"\n",
" and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
"================================================================================\n",
"bin /opt/conda/lib/python3.8/site-packages/bitsandbytes/libbitsandbytes_cuda121.so\n",
"CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so\n",
"CUDA SETUP: Highest compute capability among GPUs detected: 8.0\n",
"CUDA SETUP: Detected CUDA version 121\n",
"CUDA SETUP: Loading binary /opt/conda/lib/python3.8/site-packages/bitsandbytes/libbitsandbytes_cuda121.so...\n",
"[2023-12-22 11:34:24,536] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
]
}
],
"outputs": [],
"source": [
"import torch\n",
"from transformers import (\n",
@ -208,7 +181,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Loading checkpoint shards: 100%|██████████| 2/2 [00:39<00:00, 19.75s/it]\n"
"Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 22.43it/s]\n"
]
}
],
@ -239,7 +212,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"trainable params: 9,441,792 || all params: 2,859,198,976 || trainable%: 0.33022507629773296\n"
"trainable params: 9,441,792 || all params: 2,859,198,976 || trainable%: 0.3302\n"
]
}
],
@ -828,8 +801,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Map: 100%|██████████| 4000/4000 [00:02<00:00, 1365.07 examples/s]\n",
"Map: 100%|██████████| 400/400 [00:00<00:00, 548.46 examples/s]\n"
"Map: 0%| | 0/4000 [00:00<?, ? examples/s]"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Map: 100%|██████████| 4000/4000 [00:02<00:00, 1880.98 examples/s]\n",
"Map: 100%|██████████| 400/400 [00:00<00:00, 2124.88 examples/s]\n"
]
}
],
@ -851,7 +831,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"id": "b48135d6-0d83-4e8a-b1f0-c292663c84ec",
"metadata": {
"colab": {
@ -867,87 +847,7 @@
"libroFormatter": "formatter-string",
"outputId": "362dbaae-4a43-423b-d0d1-39839d721177"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.8/site-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='3784' max='4000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [3784/4000 2:11:42 < 07:31, 0.48 it/s, Epoch 7.57/8]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Epoch</th>\n",
" <th>Training Loss</th>\n",
" <th>Validation Loss</th>\n",
" <th>Accuracy</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>1.458600</td>\n",
" <td>0.968393</td>\n",
" <td>0.457500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.619800</td>\n",
" <td>0.874669</td>\n",
" <td>0.510000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.548800</td>\n",
" <td>0.837347</td>\n",
" <td>0.537500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.466800</td>\n",
" <td>0.784065</td>\n",
" <td>0.552500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>0.400800</td>\n",
" <td>0.768286</td>\n",
" <td>0.565000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>0.377200</td>\n",
" <td>0.764708</td>\n",
" <td>0.562500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>0.356300</td>\n",
" <td>0.765993</td>\n",
" <td>0.562500</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"outputs": [],
"source": [
"# training and evaluation\n",
"def compute_metrics(eval_preds):\n",
@ -1033,9 +933,9 @@
"output_type": "stream",
"text": [
"total 37M\n",
"-rw-r--r-- 1 root root 374 12月 22 14:59 adapter_config.json\n",
"-rw-r--r-- 1 root root 37M 12月 22 14:59 adapter_model.safetensors\n",
"-rw-r--r-- 1 root root 5.0K 12月 22 14:58 README.md\n"
"-rw-r--r-- 1 root root 5.1K Aug 4 20:25 README.md\n",
"-rw-r--r-- 1 root root 381 Aug 4 20:25 adapter_config.json\n",
"-rw-r--r-- 1 root root 37M Aug 4 20:25 adapter_model.safetensors\n"
]
}
],
@ -1069,7 +969,8 @@
},
"outputs": [],
"source": [
"device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\""
"device_type = torch.accelerator.current_accelerator().type if hasattr(torch, \"accelerator\") else \"cuda\"\n",
"device = f\"{device_type}:0\" if device_type != \"cpu\" else \"cpu\""
]
},
{
@ -1090,7 +991,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Loading checkpoint shards: 100%|██████████| 2/2 [00:36<00:00, 18.50s/it]\n"
"Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 22.17it/s]\n"
]
}
],
@ -1137,11 +1038,11 @@
" 122, 63, 102, 17, 11128, 7139, 47, 3814, 16, 15393,\n",
" 5, 27, 7, 8, 7142, 666, 3, 295, 10990, 57,\n",
" 8, 7142, 756, 58, 71, 5, 2163, 272, 5, 465,\n",
" 11801, 10, 1]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 11801, 10, 1]], device='xpu:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0'), 'task_ids': tensor([2], device='cuda:0')}\n",
"tensor([ 0, 71, 1], device='cuda:0')\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='xpu:0'), 'task_ids': tensor([2], device='xpu:0')}\n",
"tensor([ 0, 71, 1], device='xpu:0')\n",
"A\n"
]
}
@ -1160,18 +1061,6 @@
" print(outputs[0])\n",
" print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82ffbaf4-6ff1-41bd-8665-c51c402a45bc",
"metadata": {
"execution": {},
"id": "82ffbaf4-6ff1-41bd-8665-c51c402a45bc",
"libroFormatter": "formatter-string"
},
"outputs": [],
"source": []
}
],
"metadata": {
@ -1181,7 +1070,7 @@
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
@ -1195,7 +1084,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
"version": "3.11.13"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {

View File

@ -945,11 +945,7 @@ class PeftModel(PushToHubMixin, torch.nn.Module):
"""
Returns the base model.
"""
return (
self.base_model
if (self.active_peft_config.is_prompt_learning or self.peft_type == PeftType.POLY)
else self.base_model.model
)
return self.base_model if self.active_peft_config.is_prompt_learning else self.base_model.model
def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
"""