mirror of
https://github.com/huggingface/peft.git
synced 2025-10-20 15:33:48 +08:00
ENH XPU support for dna_language_model example (#2689)
--------- Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
This commit is contained in:
@ -114,7 +114,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": null,
|
||||
"id": "ca43b893-2d66-4e93-a08f-b17a92040709",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@ -185,7 +185,8 @@
|
||||
],
|
||||
"source": [
|
||||
"lm.eval()\n",
|
||||
"lm.to(\"cuda\");"
|
||||
"device = torch.accelerator.current_accelerator().type if hasattr(torch, \"accelerator\") else \"cuda\"\n",
|
||||
"lm.to(device);"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -210,7 +211,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": null,
|
||||
"id": "f5c0b3df-911a-4645-9140-99ee489515e8",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@ -327,7 +328,8 @@
|
||||
"source": [
|
||||
"from datasets import load_dataset\n",
|
||||
"\n",
|
||||
"raw_data = load_dataset(\"InstaDeepAI/nucleotide_transformer_downstream_tasks\", \"H3\")"
|
||||
"raw_data_full = load_dataset(\"InstaDeepAI/nucleotide_transformer_downstream_tasks\")\n",
|
||||
"raw_data = raw_data_full.filter(lambda example: example['task'] == 'H3')"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -592,7 +594,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"execution_count": null,
|
||||
"id": "700540f4-0ab8-4f8a-a75c-416a6908af47",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@ -720,7 +722,7 @@
|
||||
"# Number of classes for your classification task\n",
|
||||
"num_labels = 2\n",
|
||||
"classification_model = DNA_LM(lm, num_labels)\n",
|
||||
"classification_model.to('cuda');"
|
||||
"classification_model.to(device);"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -991,7 +993,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"execution_count": null,
|
||||
"id": "021641ae-f604-4d69-8724-743b7d7c613c",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@ -1094,7 +1096,7 @@
|
||||
"# Number of classes for your classification task\n",
|
||||
"num_labels = 2\n",
|
||||
"classification_model = DNA_LM(lm, num_labels)\n",
|
||||
"classification_model.to('cuda');"
|
||||
"classification_model.to(device);"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
Reference in New Issue
Block a user