mirror of
https://github.com/huggingface/peft.git
synced 2025-10-20 23:43:47 +08:00
693 lines
31 KiB
Plaintext
693 lines
31 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "9ff5004e",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"===================================BUG REPORT===================================\n",
|
|
"Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
|
|
"For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link\n",
|
|
"================================================================================\n",
|
|
"CUDA SETUP: CUDA runtime path found: /home/sourab/miniconda3/envs/ml/lib/libcudart.so\n",
|
|
"CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n",
|
|
"CUDA SETUP: Detected CUDA version 117\n",
|
|
"CUDA SETUP: Loading binary /home/sourab/miniconda3/envs/ml/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import argparse\n",
|
|
"import os\n",
|
|
"\n",
|
|
"import torch\n",
|
|
"from torch.optim import AdamW\n",
|
|
"from torch.utils.data import DataLoader\n",
|
|
"from peft import (\n",
|
|
" get_peft_config,\n",
|
|
" get_peft_model,\n",
|
|
" get_peft_model_state_dict,\n",
|
|
" set_peft_model_state_dict,\n",
|
|
" PeftType,\n",
|
|
" PrefixTuningConfig,\n",
|
|
" PromptEncoderConfig,\n",
|
|
" PromptTuningConfig,\n",
|
|
")\n",
|
|
"\n",
|
|
"import evaluate\n",
|
|
"from datasets import load_dataset\n",
|
|
"from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n",
|
|
"from tqdm import tqdm"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "e32c4a9e",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"batch_size = 32\n",
|
|
"model_name_or_path = \"roberta-large\"\n",
|
|
"task = \"mrpc\"\n",
|
|
"peft_type = PeftType.PROMPT_TUNING\n",
|
|
"device = torch.accelerator.current_accelerator().type if hasattr(torch, \"accelerator\") else \"cuda\"\n",
|
|
"num_epochs = 20"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "622fe9c8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"peft_config = PromptTuningConfig(task_type=\"SEQ_CLS\", num_virtual_tokens=10)\n",
|
|
"lr = 1e-3"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "74e9efe0",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Found cached dataset glue (/home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "76198cec552441818ff107910275e5be",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/3 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9fa7887f9eaa03ae.arrow\n",
|
|
"Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-dc593149bbeafe80.arrow\n",
|
|
"Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-140ebe5b70e09817.arrow\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"if any(k in model_name_or_path for k in (\"gpt\", \"opt\", \"bloom\")):\n",
|
|
" padding_side = \"left\"\n",
|
|
"else:\n",
|
|
" padding_side = \"right\"\n",
|
|
"\n",
|
|
"tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)\n",
|
|
"if getattr(tokenizer, \"pad_token_id\") is None:\n",
|
|
" tokenizer.pad_token_id = tokenizer.eos_token_id\n",
|
|
"\n",
|
|
"datasets = load_dataset(\"glue\", task)\n",
|
|
"metric = evaluate.load(\"glue\", task)\n",
|
|
"\n",
|
|
"\n",
|
|
"def tokenize_function(examples):\n",
|
|
" # max_length=None => use the model max length (it's actually the default)\n",
|
|
" outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n",
|
|
" return outputs\n",
|
|
"\n",
|
|
"\n",
|
|
"tokenized_datasets = datasets.map(\n",
|
|
" tokenize_function,\n",
|
|
" batched=True,\n",
|
|
" remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n",
|
|
")\n",
|
|
"\n",
|
|
"# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n",
|
|
"# transformers library\n",
|
|
"tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n",
|
|
"\n",
|
|
"\n",
|
|
"def collate_fn(examples):\n",
|
|
" return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n",
|
|
"\n",
|
|
"\n",
|
|
"# Instantiate dataloaders.\n",
|
|
"train_dataloader = DataLoader(tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)\n",
|
|
"eval_dataloader = DataLoader(\n",
|
|
" tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "a3c15af0",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)\n",
|
|
"model = get_peft_model(model, peft_config)\n",
|
|
"model.print_trainable_parameters()\n",
|
|
"model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "6d3c5edb",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"optimizer = AdamW(params=model.parameters(), lr=lr)\n",
|
|
"\n",
|
|
"# Instantiate scheduler\n",
|
|
"lr_scheduler = get_linear_schedule_with_warmup(\n",
|
|
" optimizer=optimizer,\n",
|
|
" num_warmup_steps=0.06 * (len(train_dataloader) * num_epochs),\n",
|
|
" num_training_steps=(len(train_dataloader) * num_epochs),\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "4d279225",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" 0%| | 0/115 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
|
|
"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [02:09<00:00, 1.13s/it]\n",
|
|
"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:08<00:00, 1.62it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 0: {'accuracy': 0.678921568627451, 'f1': 0.7956318252730109}\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:50<00:00, 1.04it/s]\n",
|
|
"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.22it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 1: {'accuracy': 0.696078431372549, 'f1': 0.8171091445427728}\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:36<00:00, 1.19it/s]\n",
|
|
"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:06<00:00, 2.00it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 2: {'accuracy': 0.6985294117647058, 'f1': 0.8161434977578476}\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:37<00:00, 1.18it/s]\n",
|
|
"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:06<00:00, 2.09it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 3: {'accuracy': 0.7058823529411765, 'f1': 0.7979797979797979}\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [02:03<00:00, 1.07s/it]\n",
|
|
"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:07<00:00, 1.71it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 4: {'accuracy': 0.696078431372549, 'f1': 0.8132530120481929}\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:53<00:00, 1.01it/s]\n",
|
|
"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.19it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 5: {'accuracy': 0.7107843137254902, 'f1': 0.8121019108280254}\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:35<00:00, 1.20it/s]\n",
|
|
"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.20it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 6: {'accuracy': 0.6911764705882353, 'f1': 0.7692307692307693}\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:36<00:00, 1.20it/s]\n",
|
|
"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.18it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 7: {'accuracy': 0.7156862745098039, 'f1': 0.8209876543209876}\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:35<00:00, 1.20it/s]\n",
|
|
"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.22it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 8: {'accuracy': 0.7205882352941176, 'f1': 0.8240740740740742}\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:36<00:00, 1.19it/s]\n",
|
|
"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.21it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 9: {'accuracy': 0.7205882352941176, 'f1': 0.8229813664596273}\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:36<00:00, 1.20it/s]\n",
|
|
"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.35it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 10: {'accuracy': 0.7156862745098039, 'f1': 0.8164556962025317}\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:35<00:00, 1.20it/s]\n",
|
|
"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.22it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 11: {'accuracy': 0.7058823529411765, 'f1': 0.8113207547169811}\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:32<00:00, 1.24it/s]\n",
|
|
"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.48it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 12: {'accuracy': 0.7009803921568627, 'f1': 0.7946127946127945}\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:32<00:00, 1.24it/s]\n",
|
|
"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.38it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 13: {'accuracy': 0.7230392156862745, 'f1': 0.8186195826645265}\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:29<00:00, 1.29it/s]\n",
|
|
"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.31it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 14: {'accuracy': 0.7058823529411765, 'f1': 0.8130841121495327}\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:30<00:00, 1.27it/s]\n",
|
|
"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.39it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 15: {'accuracy': 0.7181372549019608, 'f1': 0.8194662480376768}\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:28<00:00, 1.29it/s]\n",
|
|
"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.35it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 16: {'accuracy': 0.7254901960784313, 'f1': 0.8181818181818181}\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:30<00:00, 1.27it/s]\n",
|
|
"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.30it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 17: {'accuracy': 0.7205882352941176, 'f1': 0.820754716981132}\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:30<00:00, 1.27it/s]\n",
|
|
"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.36it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 18: {'accuracy': 0.7254901960784313, 'f1': 0.821656050955414}\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:28<00:00, 1.29it/s]\n",
|
|
"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.43it/s]"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"epoch 19: {'accuracy': 0.7303921568627451, 'f1': 0.8242811501597445}\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"model.to(device)\n",
|
|
"for epoch in range(num_epochs):\n",
|
|
" model.train()\n",
|
|
" for step, batch in enumerate(tqdm(train_dataloader)):\n",
|
|
" batch.to(device)\n",
|
|
" outputs = model(**batch)\n",
|
|
" loss = outputs.loss\n",
|
|
" loss.backward()\n",
|
|
" optimizer.step()\n",
|
|
" lr_scheduler.step()\n",
|
|
" optimizer.zero_grad()\n",
|
|
"\n",
|
|
" model.eval()\n",
|
|
" for step, batch in enumerate(tqdm(eval_dataloader)):\n",
|
|
" batch.to(device)\n",
|
|
" with torch.no_grad():\n",
|
|
" outputs = model(**batch)\n",
|
|
" predictions = outputs.logits.argmax(dim=-1)\n",
|
|
" predictions, references = predictions, batch[\"labels\"]\n",
|
|
" metric.add_batch(\n",
|
|
" predictions=predictions,\n",
|
|
" references=references,\n",
|
|
" )\n",
|
|
"\n",
|
|
" eval_metric = metric.compute()\n",
|
|
" print(f\"epoch {epoch}:\", eval_metric)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "e1ff3f44",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Share adapters on the 🤗 Hub"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "0bf79cb5",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"CommitInfo(commit_url='https://huggingface.co/smangrul/roberta-large-peft-prompt-tuning/commit/893a909d8499aa8778d58c781d43c3a8d9360de8', commit_message='Upload model', commit_description='', oid='893a909d8499aa8778d58c781d43c3a8d9360de8', pr_url=None, pr_revision=None, pr_num=None)"
|
|
]
|
|
},
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"model.push_to_hub(\"smangrul/roberta-large-peft-prompt-tuning\", use_auth_token=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "73870ad7",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Load adapters from the Hub\n",
|
|
"\n",
|
|
"You can also directly load adapters from the Hub using the commands below:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "0654a552",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "24581bb98582444ca6114b9fa267847f",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Downloading: 0%| | 0.00/368 [00:00<?, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Some weights of the model checkpoint at roberta-large were not used when initializing RobertaForSequenceClassification: ['lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias']\n",
|
|
"- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
|
|
"- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
|
|
"Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.out_proj.weight', 'classifier.out_proj.bias', 'classifier.dense.bias', 'classifier.dense.weight']\n",
|
|
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "f1584da4d1c54cc3873a515182674980",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"Downloading: 0%| | 0.00/4.25M [00:00<?, ?B/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" 0%| | 0/13 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
|
|
"100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00, 2.58it/s]"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"{'accuracy': 0.7303921568627451, 'f1': 0.8242811501597445}\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import torch\n",
|
|
"from peft import PeftModel, PeftConfig\n",
|
|
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
|
"\n",
|
|
"peft_model_id = \"smangrul/roberta-large-peft-prompt-tuning\"\n",
|
|
"config = PeftConfig.from_pretrained(peft_model_id)\n",
|
|
"inference_model = AutoModelForSequenceClassification.from_pretrained(config.base_model_name_or_path)\n",
|
|
"tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n",
|
|
"\n",
|
|
"# Load the trained prompt-tuning adapter on top of the base model\n",
|
|
"inference_model = PeftModel.from_pretrained(inference_model, peft_model_id)\n",
|
|
"\n",
|
|
"inference_model.to(device)\n",
|
|
"inference_model.eval()\n",
|
|
"for step, batch in enumerate(tqdm(eval_dataloader)):\n",
|
|
" batch.to(device)\n",
|
|
" with torch.no_grad():\n",
|
|
" outputs = inference_model(**batch)\n",
|
|
" predictions = outputs.logits.argmax(dim=-1)\n",
|
|
" predictions, references = predictions, batch[\"labels\"]\n",
|
|
" metric.add_batch(\n",
|
|
" predictions=predictions,\n",
|
|
" references=references,\n",
|
|
" )\n",
|
|
"\n",
|
|
"eval_metric = metric.compute()\n",
|
|
"print(eval_metric)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.4"
|
|
},
|
|
"vscode": {
|
|
"interpreter": {
|
|
"hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
|
|
}
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|