ENH Support XPU in DoRA FT example (#2700)

kaixuanliu
2025-09-25 23:57:41 +08:00
committed by GitHub
parent 4f868bd7c9
commit c15daaa5aa
3 changed files with 19 additions and 11 deletions

View File

@@ -6,7 +6,7 @@
"id": "CV_gQs58bsvM"
},
"source": [
"# Fine-tuning [Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) on [timdettmers/openassistant-guanaco](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) Dataset using QDora (quantized Lora w/ use_dora=True) on T4 Free Colab GPU."
"# Fine-tuning [Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) on [timdettmers/openassistant-guanaco](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) Dataset using QDora (quantized Lora w/ use_dora=True)."
]
},
{
@@ -1010,6 +1010,7 @@
"top_p = 0.9\n",
"temperature = 0.7\n",
"user_question = \"What is the purpose of quantization in LLMs?\"\n",
"device = torch.accelerator.current_accelerator().type if hasattr(torch, \"accelerator\") else \"cuda\"\n",
"\n",
"\n",
"prompt = (\n",
@@ -1021,7 +1022,7 @@
"\n",
"\n",
"def generate(model, user_question, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature):\n",
" inputs = tokenizer(prompt.format(user_question=user_question), return_tensors=\"pt\").to(\"cuda\")\n",
" inputs = tokenizer(prompt.format(user_question=user_question), return_tensors=\"pt\").to(device)\n",
"\n",
" outputs = model.generate(\n",
" **inputs,\n",

View File

@@ -13,7 +13,7 @@ from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer
from datasets import load_dataset
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", device_map="cuda")
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
lora_config = LoraConfig(
@@ -70,7 +70,6 @@ python dora_finetuning.py \
--quantize \
--eval_step 10 \
--save_step 100 \
--device "cuda:0" \
--lora_r 16 \
--lora_alpha 32 \
--lora_dropout 0.05 \
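
With `device_map="auto"` (shown in the first hunk of this README), accelerate decides where the weights go, which is why the explicit `--device "cuda:0"` argument could be dropped from the quick-start command. A small sketch for checking that placement, assuming the `accelerate` package is installed; the print is only for inspection:

```python
from transformers import AutoModelForCausalLM

# device_map="auto" lets accelerate pick the available backend (CUDA, XPU, or CPU)
# instead of pinning the model to CUDA.
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", device_map="auto")

# hf_device_map records which device each module was dispatched to.
print(model.hf_device_map)
```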

View File

@@ -39,7 +39,10 @@ def train_model(
hf_token = os.getenv("HF_TOKEN")
# Setup device
-device = torch.device(device)
+if device == "auto":
+    device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
+else:
+    device = torch.device(device)
print(f"Using device: {device}")
# load tokenizer
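
The `"auto"` branch added here is the script-side counterpart of the notebook change: the detected accelerator type is used when no device is given, while an explicit `--device` value still goes through `torch.device` as before. A sketch of the same branching as a helper (`resolve_device` is a hypothetical name used only for illustration, and it assumes, as the script does, that an accelerator is present when "auto" is requested):

```python
import torch

def resolve_device(device_arg: str):
    # "auto" picks the detected accelerator type ("cuda", "xpu", ...);
    # anything else (e.g. "cuda:0", "xpu", "cpu") is handed to torch.device unchanged.
    if device_arg == "auto":
        return torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
    return torch.device(device_arg)

print(resolve_device("auto"))  # "cuda" or "xpu", depending on the machine
print(resolve_device("cpu"))   # an explicit value goes through torch.device
```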
@@ -47,14 +50,16 @@
# QDoRA (quantized dora): IF YOU WANNA QUANTIZE THE MODEL
if quantize:
+if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) or torch.xpu.is_available():
+    bnb_4bit_compute_dtype = torch.bfloat16
+else:
+    bnb_4bit_compute_dtype = torch.float16
model = AutoModelForCausalLM.from_pretrained(
base_model,
token=hf_token,
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
-bnb_4bit_compute_dtype=(
-    torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
-),
+bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
),
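
The dtype selection above extends the old CUDA-only check: bf16 is used either on a bf16-capable CUDA GPU or on an XPU, and fp16 otherwise. A self-contained sketch of how that value feeds `BitsAndBytesConfig`, assuming a PyTorch build that exposes `torch.xpu` and a bitsandbytes install with 4-bit support for the target backend:

```python
import torch
from transformers import BitsAndBytesConfig

# bf16 on bf16-capable CUDA devices or on Intel XPUs, fp16 everywhere else.
if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) or torch.xpu.is_available():
    compute_dtype = torch.bfloat16
else:
    compute_dtype = torch.float16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=compute_dtype,  # dtype used for the dequantized matmuls
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
```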
@@ -117,8 +122,11 @@
hub_token=hf_token,
)
-# Clear CUDA cache to free memory
-torch.cuda.empty_cache()
+# Clear device cache to free memory
+if torch.cuda.is_available():
+    torch.cuda.empty_cache()
+elif torch.xpu.is_available():
+    torch.xpu.empty_cache()
# Initialize the Trainer
trainer = Trainer(
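
The cache clearing is the same operation on either backend; as a reusable sketch (the helper name is illustrative, not part of the script, and the extra `hasattr` guard lets it also run on PyTorch builds without `torch.xpu`):

```python
import torch

def empty_device_cache() -> None:
    # Release the caching allocator's unused blocks on whichever backend is active,
    # matching what the script does right before building the Trainer.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()

empty_device_cache()
```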
@@ -162,7 +170,7 @@
parser.add_argument("--quantize", action="store_true", help="Use quantization")
parser.add_argument("--eval_step", type=int, default=10, help="Evaluation step interval")
parser.add_argument("--save_step", type=int, default=100, help="Save step interval")
parser.add_argument("--device", type=str, default="cuda:0", help="Device to use for training")
parser.add_argument("--device", type=str, default="auto", help="Device to use for training")
parser.add_argument("--lora_r", type=int, default=8, help="LoRA rank")
parser.add_argument("--lora_alpha", type=int, default=16, help="LoRA alpha")
parser.add_argument("--lora_dropout", type=float, default=0.05, help="LoRA dropout rate")