DOC Make docs more device agnostic (e.g. XPU) (#2728)

Also adjusted some more examples.

Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
Author: Yao Matrix
Date: 2025-08-08 03:06:22 -07:00
Committed by: GitHub
parent 7f7463548a
commit e98a59ec2d
8 changed files with 38 additions and 18 deletions


@@ -42,7 +42,7 @@ Prepare a model for training with a PEFT method such as LoRA by wrapping the bas
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model
device = "cuda"
device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
model_id = "Qwen/Qwen2.5-3B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device)
peft_config = LoraConfig(
@@ -65,7 +65,7 @@ To load a PEFT model for inference:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
device = "cuda"
device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
model_id = "Qwen/Qwen2.5-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device)
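The device resolution pattern introduced above returns the accelerator type string at runtime ("cuda", "xpu", ...) and falls back to "cuda" on PyTorch versions that predate `torch.accelerator`; note that it needs `import torch` in scope. A minimal standalone sketch, with a helper name (`resolve_device`) that is not part of the docs:

```python
import torch

def resolve_device() -> str:
    """Return the current accelerator type, falling back to "cuda" on older PyTorch."""
    if hasattr(torch, "accelerator") and torch.accelerator.is_available():
        # torch.accelerator (PyTorch 2.6+) abstracts over CUDA, XPU and other backends.
        return torch.accelerator.current_accelerator().type
    return "cuda"

device = resolve_device()
print(device)  # e.g. "cuda" on NVIDIA GPUs, "xpu" on Intel GPUs
```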


@@ -109,7 +109,7 @@ peft_config = LoraConfig(
```
The parameter `rho` (≥ 1.0) determines how much redistribution is allowed. When `rho=1.0` and `r=16`, LoRA adapters are limited to exactly 16 ranks, preventing any redistribution from occurring. A recommended value for EVA with redistribution is 2.0, meaning the maximum rank allowed for a layer is 2r.
It is recommended to perform EVA initialization on a GPU as it is much faster. To optimize the amount of available memory for EVA, you can use the `low_cpu_mem_usage` flag in [`get_peft_model`]:
It is recommended to perform EVA initialization on an accelerator (e.g. a CUDA GPU or Intel XPU) as it is much faster. To optimize the amount of available memory for EVA, you can use the `low_cpu_mem_usage` flag in [`get_peft_model`]:
```python
peft_model = get_peft_model(model, peft_config, low_cpu_mem_usage=True)
```
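For reference, the `rho` value discussed above is configured through `EvaConfig`; a minimal sketch, assuming `model` is an already loaded transformers model and with illustrative rank and target modules:

```python
from peft import EvaConfig, LoraConfig, get_peft_model

# rho=2.0 allows rank redistribution up to 2*r per layer; rho=1.0 disables it.
peft_config = LoraConfig(
    r=16,
    init_lora_weights="eva",
    eva_config=EvaConfig(rho=2.0),
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
# low_cpu_mem_usage=True keeps more memory free while the EVA statistics are computed.
peft_model = get_peft_model(model, peft_config, low_cpu_mem_usage=True)
```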
@@ -203,7 +203,7 @@ model = PeftModel.from_pretrained(base_model, peft_model_id, ephemeral_gpu_offlo
DoRA is optimized (it computes faster and uses less memory) for models in evaluation mode, or when dropout is set to 0. We reuse the
base result in those cases to get the speedup.
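A minimal sketch of enabling this optimized path, with illustrative hyperparameters and `model` assumed to be an already loaded transformers model:

```python
from peft import LoraConfig, get_peft_model

# lora_dropout=0 together with eval mode lets DoRA reuse the cached base result.
peft_config = LoraConfig(r=16, lora_alpha=32, use_dora=True, lora_dropout=0.0)
dora_model = get_peft_model(model, peft_config)
dora_model.eval()  # evaluation mode activates the optimized DoRA code path
```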
Running [dora finetuning](https://github.com/huggingface/peft/blob/main/examples/dora_finetuning/dora_finetuning.py)
with `CUDA_VISIBLE_DEVICES=0 time python examples/dora_finetuning/dora_finetuning.py --quantize --lora_dropout 0 --batch_size 16 --eval_step 2 --use_dora`
with `CUDA_VISIBLE_DEVICES=0 ZE_AFFINITY_MASK=0 time python examples/dora_finetuning/dora_finetuning.py --quantize --lora_dropout 0 --batch_size 16 --eval_step 2 --use_dora`
on a 4090 with gradient accumulation set to 2 and max steps set to 20 resulted in the following observations:
| | Without Optimization | With Optimization |
@@ -471,11 +471,13 @@ There are several supported methods for `combination_type`. Refer to the [docume
Now, perform inference:
```python
device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
generate_ids = model.generate(**inputs, max_length=30)


@@ -99,12 +99,13 @@ Now you can use the merged model as an instruction-tuned model to write ad copy
<hfoption id="instruct">
```py
device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
messages = [
{"role": "user", "content": "Write an essay about Generative AI."},
]
text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
inputs = tokenizer(text, return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}
inputs = {k: v.to(device) for k, v in inputs.items()}
outputs = model.generate(**inputs, max_new_tokens=256, do_sample=True, top_p=0.95, temperature=0.2, repetition_penalty=1.2, eos_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(outputs[0]))
```
@@ -113,13 +114,14 @@ print(tokenizer.decode(outputs[0]))
<hfoption id="ad copy">
```py
device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
messages = [
{"role": "system", "content": "Create a text ad given the following product and description."},
{"role": "user", "content": "Product: Sony PS5 PlayStation Console\nDescription: The PS5 console unleashes new gaming possibilities that you never anticipated."},
]
text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
inputs = tokenizer(text, return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}
inputs = {k: v.to(device) for k, v in inputs.items()}
outputs = model.generate(**inputs, max_new_tokens=128, do_sample=True, top_p=0.95, temperature=0.2, repetition_penalty=1.2, eos_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(outputs[0]))
```
@@ -128,13 +130,15 @@ print(tokenizer.decode(outputs[0]))
<hfoption id="SQL">
```py
device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
text = """Table: 2-11365528-2
Columns: ['Team', 'Head Coach', 'President', 'Home Ground', 'Location']
Natural Query: Who is the Head Coach of the team whose President is Mario Volarevic?
SQL Query:"""
inputs = tokenizer(text, return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}
inputs = {k: v.to(device) for k, v in inputs.items()}
outputs = model.generate(**inputs, max_new_tokens=64, repetition_penalty=1.1, eos_token_id=tokenizer("</s>").input_ids[-1])
print(tokenizer.decode(outputs[0]))
```


@@ -197,7 +197,9 @@ The models that are quantized using Half-Quadratic Quantization of Large Machine
```python
from hqq.engine.hf import HQQModelForCausalLM
quantized_model = HQQModelForCausalLM.from_quantized(save_dir_or_hfhub, device='cuda')
device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
quantized_model = HQQModelForCausalLM.from_quantized(save_dir_or_hfhub, device=device)
peft_config = LoraConfig(...)
quantized_model = get_peft_model(quantized_model, peft_config)
```


@@ -92,7 +92,7 @@ processed_ds = ds.map(
)
```
Create a training and evaluation [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), and set `pin_memory=True` to speed up data transfer to the GPU during training if your dataset samples are on a CPU.
Create a training and evaluation [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), and set `pin_memory=True` to speed up data transfer to the accelerator during training if your dataset samples are on a CPU.
```py
from torch.utils.data import DataLoader
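from transformers import default_data_collator

# A minimal sketch of the DataLoaders described above; the batch size and the
# "train"/"test" split names are illustrative rather than taken from the tutorial.
# pin_memory=True speeds up host-to-accelerator copies when batches live on the CPU.
train_dataloader = DataLoader(
    processed_ds["train"], shuffle=True, collate_fn=default_data_collator, batch_size=8, pin_memory=True
)
eval_dataloader = DataLoader(
    processed_ds["test"], collate_fn=default_data_collator, batch_size=8, pin_memory=True
)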
@@ -159,12 +159,12 @@ lr_scheduler = get_linear_schedule_with_warmup(
)
```
Move the model to the GPU and create a training loop that reports the loss and perplexity for each epoch.
Move the model to the accelerator and create a training loop that reports the loss and perplexity for each epoch.
```py
from tqdm import tqdm
device = "cuda"
device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
model = model.to(device)
for epoch in range(num_epochs):
@@ -219,7 +219,9 @@ To load the model for inference, use the [`~AutoPeftModelForSeq2SeqLM.from_pretr
```py
from peft import AutoPeftModelForSeq2SeqLM
model = AutoPeftModelForSeq2SeqLM.from_pretrained("<your-hf-account-name>/mt0-large-ia3").to("cuda")
device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
model = AutoPeftModelForSeq2SeqLM.from_pretrained("<your-hf-account-name>/mt0-large-ia3").to(device)
tokenizer = AutoTokenizer.from_pretrained("bigscience/mt0-large")
i = 15


@@ -426,7 +426,10 @@ class LightControlNetPipeline(StableDiffusionControlNetPipeline):
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
torch.cuda.empty_cache()
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif torch.xpu.is_available():
torch.xpu.empty_cache()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
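The same CUDA/XPU dispatch appears in several of the touched files; a small helper could express it once. A minimal sketch, with a helper name (`empty_accelerator_cache`) that is not part of the diff:

```python
import torch

def empty_accelerator_cache() -> None:
    """Release cached memory blocks on whichever accelerator backend is present."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
    # Nothing to release on CPU-only setups.
```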


@@ -9,7 +9,8 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # force to use CUDA GPU device 0
os.environ["ZE_AFFINITY_MASK"] = "0" # force to use Intel XPU device 0
# -*- coding: utf-8 -*-
"""Finetune-opt-bnb-peft.ipynb


@@ -62,7 +62,10 @@ def load_or_quantize_model(
print(f"Model {base_model} is not GPTQ-quantized. Will quantize it.")
# Clean up the test model to free memory
del test_model
torch.cuda.empty_cache() if torch.cuda.is_available() else None
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif torch.xpu.is_available():
torch.xpu.empty_cache()
except Exception as e:
print(f"Could not load model {base_model} directly: {e}")
@@ -253,8 +256,11 @@ def train_model(
label_names=["labels"],
)
# Clear CUDA cache to free memory
torch.cuda.empty_cache()
# Clear accelerator cache to free memory
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif torch.xpu.is_available():
torch.xpu.empty_cache()
# Initialize trainer
trainer = Trainer(