TST Enable BNB tests on XPU (#2396)

Author: Fanli Lin
Date: 2025-03-06 23:18:47 +08:00
Committed by: GitHub
Parent: 461f6426ef
Commit: 24150d0e41
8 changed files with 97 additions and 105 deletions

View File

@ -23,6 +23,8 @@ import torch
import transformers
from torch import nn
from peft.import_utils import is_xpu_available
@contextmanager
def gather_params_ctx(param, modifier_rank: int = 0, fwd_module: torch.nn.Module = None):
@ -90,8 +92,11 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
# BNB requires CUDA weights
device = weight.device
is_cpu = device.type == torch.device("cpu").type
if is_cpu and torch.cuda.is_available():
weight = weight.to(torch.device("cuda"))
if is_cpu:
if torch.cuda.is_available():
weight = weight.to(torch.device("cuda"))
elif is_xpu_available():
weight = weight.to(torch.device("xpu"))
cls_name = weight.__class__.__name__
if cls_name == "Params4bit":
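The branch above keeps bitsandbytes dequantization working on machines without CUDA: a CPU weight is moved to CUDA when available, otherwise to an Intel XPU. A minimal standalone sketch of the same pattern (the helper name is illustrative, not part of the commit):

```python
import torch

from peft.import_utils import is_xpu_available


def move_off_cpu(weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring the device-selection branch above."""
    if weight.device.type == "cpu":
        if torch.cuda.is_available():
            # BNB kernels need the tensor on an accelerator; prefer CUDA
            weight = weight.to(torch.device("cuda"))
        elif is_xpu_available():
            # fall back to Intel XPU when no CUDA device is present
            weight = weight.to(torch.device("xpu"))
    return weight
```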

View File

@ -22,13 +22,14 @@ import os
from typing import Callable, Optional, Union
import torch
from accelerate.utils.memory import clear_device_cache
from huggingface_hub import snapshot_download
from huggingface_hub.errors import HFValidationError, LocalEntryNotFoundError
from safetensors import SafetensorError, safe_open
from transformers.utils import cached_file
from transformers.utils.hub import get_checkpoint_shard_files
from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.import_utils import is_bnb_4bit_available, is_bnb_available, is_xpu_available
class NFQuantizer:
@ -201,7 +202,6 @@ def loftq_init(weight: Union[torch.Tensor, torch.nn.Parameter], num_bits: int, r
out_feature, in_feature = weight.size()
device = weight.device
dtype = weight.dtype
logging.info(
f"Weight: ({out_feature}, {in_feature}) | Rank: {reduced_rank} | Num Iter: {num_iter} | Num Bits: {num_bits}"
)
@ -209,12 +209,12 @@ def loftq_init(weight: Union[torch.Tensor, torch.nn.Parameter], num_bits: int, r
quantizer = NFQuantizer(num_bits=num_bits, device=device, method="normal", block_size=64)
compute_device = device
else:
compute_device = "cuda"
compute_device = "xpu" if is_xpu_available() else "cuda"
weight = weight.to(device=compute_device, dtype=torch.float32)
res = weight.clone()
for i in range(num_iter):
torch.cuda.empty_cache()
clear_device_cache()
# Quantization
if num_bits == 4 and is_bnb_4bit_available():
qweight = bnb.nn.Params4bit(
@ -246,12 +246,12 @@ def _loftq_init_new(qweight, weight, num_bits: int, reduced_rank: int):
if not is_bnb_4bit_available():
raise ValueError("bitsandbytes 4bit quantization is not available.")
compute_device = "cuda"
compute_device = "xpu" if is_xpu_available() else "cuda"
dequantized_weight = bnb.functional.dequantize_4bit(qweight.data, qweight.quant_state)
weight = weight.to(device=compute_device, dtype=torch.float32)
residual = weight - dequantized_weight
torch.cuda.empty_cache()
clear_device_cache()
# Decompose the residual by SVD
output = _low_rank_decomposition(residual, reduced_rank=reduced_rank)
L, R, reduced_rank = output["L"], output["R"], output["reduced_rank"]
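The LoftQ helpers now pick the compute device based on what is available and call accelerate's `clear_device_cache` instead of the CUDA-only `torch.cuda.empty_cache()`. A short sketch of the pattern, assuming a CUDA or XPU device is present:

```python
import torch

from accelerate.utils.memory import clear_device_cache
from peft.import_utils import is_xpu_available

# pick the accelerator the same way the LoftQ helpers above do
compute_device = "xpu" if is_xpu_available() else "cuda"

weight = torch.randn(64, 64).to(device=compute_device, dtype=torch.float32)

# empties the cache of whichever accelerator backend is available,
# replacing the CUDA-only torch.cuda.empty_cache() call
clear_device_cache()
```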

View File

@ -27,10 +27,12 @@ import pytest
import torch
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, BitsAndBytesConfig
from peft.import_utils import is_xpu_available
bnb = pytest.importorskip("bitsandbytes")
device = torch.device("cuda")
device = torch.device("xpu") if is_xpu_available() else torch.device("cuda")
def bytes_from_tensor(x):
@ -47,7 +49,7 @@ def bytes_from_tensor(x):
############
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
def test_opt_350m_4bit():
torch.manual_seed(0)
bnb_config = BitsAndBytesConfig(
@ -70,7 +72,7 @@ def test_opt_350m_4bit():
torch.testing.assert_allclose(output, expected)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
def test_opt_350m_8bit():
torch.manual_seed(0)
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
@ -89,7 +91,7 @@ def test_opt_350m_8bit():
torch.testing.assert_allclose(output, expected)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
def test_opt_350m_4bit_double_quant():
torch.manual_seed(0)
bnb_config = BitsAndBytesConfig(
@ -112,7 +114,7 @@ def test_opt_350m_4bit_double_quant():
torch.testing.assert_allclose(output, expected)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
def test_opt_350m_4bit_compute_dtype_float16():
torch.manual_seed(0)
bnb_config = BitsAndBytesConfig(
@ -135,7 +137,7 @@ def test_opt_350m_4bit_compute_dtype_float16():
torch.testing.assert_allclose(output, expected)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
def test_opt_350m_4bit_quant_type_nf4():
torch.manual_seed(0)
bnb_config = BitsAndBytesConfig(
@ -159,7 +161,7 @@ def test_opt_350m_4bit_quant_type_nf4():
torch.testing.assert_allclose(output, expected)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
def test_opt_350m_4bit_quant_storage():
# note: using torch.float32 instead of the default torch.uint8 does not seem to affect the result
torch.manual_seed(0)
@ -184,7 +186,7 @@ def test_opt_350m_4bit_quant_storage():
torch.testing.assert_allclose(output, expected)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
def test_opt_350m_8bit_threshold():
torch.manual_seed(0)
bnb_config = BitsAndBytesConfig(
@ -211,7 +213,7 @@ def test_opt_350m_8bit_threshold():
###########
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
def test_flan_t5_4bit():
torch.manual_seed(0)
bnb_config = BitsAndBytesConfig(
@ -235,7 +237,7 @@ def test_flan_t5_4bit():
torch.testing.assert_allclose(output, expected)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
@pytest.mark.xfail # might not be reproducible depending on hardware
def test_flan_t5_8bit():
torch.manual_seed(0)
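Every regression test in this file gains the same accelerator-aware skip condition, and the module-level `device` now resolves to XPU when no CUDA device exists. A minimal sketch of that pattern (the reusable marker name is an illustration; the tests above inline the condition):

```python
import pytest
import torch

from peft.import_utils import is_xpu_available

# hypothetical reusable marker expressing the inlined skipif condition above
requires_accelerator = pytest.mark.skipif(
    not (torch.cuda.is_available() or is_xpu_available()),
    reason="No CUDA or XPU device available.",
)

device = torch.device("xpu") if is_xpu_available() else torch.device("cuda")


@requires_accelerator
def test_device_smoke():
    assert torch.ones(1, device=device).sum().item() == 1.0
```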

View File

@ -63,7 +63,6 @@ from .testing_utils import (
require_multi_accelerator,
require_non_cpu,
require_torch_gpu,
require_torch_multi_gpu,
)
@ -594,7 +593,7 @@ class PeftGPUCommonTests(unittest.TestCase):
# this should work without any problem
_ = model.generate(input_ids=input_ids)
@require_torch_multi_gpu
@require_multi_accelerator
@pytest.mark.multi_gpu_tests
@require_bitsandbytes
def test_lora_seq2seq_lm_multi_gpu_inference(self):
@ -622,7 +621,7 @@ class PeftGPUCommonTests(unittest.TestCase):
# this should work without any problem
_ = model.generate(input_ids=input_ids)
@require_torch_multi_gpu
@require_multi_accelerator
@pytest.mark.multi_gpu_tests
@require_bitsandbytes
def test_adaption_prompt_8bit(self):
@ -645,7 +644,7 @@ class PeftGPUCommonTests(unittest.TestCase):
random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
_ = model(random_input)
@require_torch_multi_gpu
@require_multi_accelerator
@pytest.mark.multi_gpu_tests
@require_bitsandbytes
def test_adaption_prompt_4bit(self):
@ -668,7 +667,7 @@ class PeftGPUCommonTests(unittest.TestCase):
random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
_ = model(random_input)
@require_torch_gpu
@require_non_cpu
@pytest.mark.single_gpu_tests
@require_bitsandbytes
def test_print_4bit_expected(self):
@ -778,7 +777,7 @@ class PeftGPUCommonTests(unittest.TestCase):
assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, bnb.nn.Linear8bitLt)
assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, bnb.nn.Linear8bitLt)
@require_torch_gpu
@require_non_cpu
@pytest.mark.single_gpu_tests
@require_bitsandbytes
def test_8bit_merge_and_disable_lora(self):
@ -814,7 +813,7 @@ class PeftGPUCommonTests(unittest.TestCase):
assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, LoraLinear8bitLt)
assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear8bitLt)
@require_torch_gpu
@require_non_cpu
@pytest.mark.single_gpu_tests
@require_bitsandbytes
def test_8bit_merge_lora_with_bias(self):
@ -846,7 +845,7 @@ class PeftGPUCommonTests(unittest.TestCase):
assert not torch.allclose(out_base, out_before_merge, atol=atol, rtol=rtol)
assert torch.allclose(out_before_merge, out_after_merge, atol=atol, rtol=rtol)
@require_torch_gpu
@require_non_cpu
@pytest.mark.single_gpu_tests
@require_bitsandbytes
def test_4bit_merge_lora(self):
@ -888,7 +887,7 @@ class PeftGPUCommonTests(unittest.TestCase):
assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, bnb.nn.Linear4bit)
assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, bnb.nn.Linear4bit)
@require_torch_gpu
@require_non_cpu
@pytest.mark.single_gpu_tests
@require_bitsandbytes
def test_4bit_merge_and_disable_lora(self):
@ -930,7 +929,7 @@ class PeftGPUCommonTests(unittest.TestCase):
assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, LoraLinear4bit)
assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear4bit)
@require_torch_gpu
@require_non_cpu
@pytest.mark.single_gpu_tests
@require_bitsandbytes
def test_4bit_merge_lora_with_bias(self):
@ -971,7 +970,7 @@ class PeftGPUCommonTests(unittest.TestCase):
assert not torch.allclose(out_base, out_before_merge, atol=atol, rtol=rtol)
assert torch.allclose(out_before_merge, out_after_merge, atol=atol, rtol=rtol)
@require_torch_gpu
@require_non_cpu
@pytest.mark.single_gpu_tests
@require_bitsandbytes
def test_4bit_lora_mixed_adapter_batches_lora(self):
@ -1042,7 +1041,7 @@ class PeftGPUCommonTests(unittest.TestCase):
assert torch.allclose(out_adapter0[1::3], out_mixed[1::3], atol=atol, rtol=rtol)
assert torch.allclose(out_adapter1[2::3], out_mixed[2::3], atol=atol, rtol=rtol)
@require_torch_gpu
@require_non_cpu
@pytest.mark.single_gpu_tests
@require_bitsandbytes
def test_8bit_lora_mixed_adapter_batches_lora(self):
@ -1124,7 +1123,7 @@ class PeftGPUCommonTests(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir, safe_serialization=True)
@require_torch_gpu
@require_non_cpu
@pytest.mark.single_gpu_tests
@require_bitsandbytes
def test_4bit_dora_inference(self):
@ -1163,7 +1162,7 @@ class PeftGPUCommonTests(unittest.TestCase):
assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, LoraLinear4bit)
assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear4bit)
@require_torch_gpu
@require_non_cpu
@pytest.mark.single_gpu_tests
@require_bitsandbytes
def test_8bit_dora_inference(self):
@ -1197,7 +1196,7 @@ class PeftGPUCommonTests(unittest.TestCase):
assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, LoraLinear8bitLt)
assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear8bitLt)
@require_torch_gpu
@require_non_cpu
@pytest.mark.single_gpu_tests
@require_bitsandbytes
def test_4bit_dora_merging(self):
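The swapped decorators generalize the old CUDA-only requirements: `require_non_cpu` asks for any accelerator, `require_multi_accelerator` for more than one device. The sketch below is an assumption about what such decorators express, not the actual `testing_utils` implementation:

```python
import unittest

import torch

from peft.import_utils import is_xpu_available


def accelerator_count() -> int:
    # hypothetical helper: count whichever accelerator backend is present
    if torch.cuda.is_available():
        return torch.cuda.device_count()
    if is_xpu_available():
        return torch.xpu.device_count()
    return 0


def require_non_cpu(test_case):
    # skip unless at least one accelerator (CUDA, XPU, ...) is available
    return unittest.skipUnless(accelerator_count() > 0, "test requires an accelerator")(test_case)


def require_multi_accelerator(test_case):
    # skip unless more than one accelerator device is available
    return unittest.skipUnless(accelerator_count() > 1, "test requires multiple accelerators")(test_case)
```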

View File

@ -27,6 +27,7 @@ import torch
from accelerate import infer_auto_device_map
from accelerate.test_utils.testing import run_command
from accelerate.utils import patch_environment
from accelerate.utils.memory import clear_device_cache
from datasets import Audio, Dataset, DatasetDict, load_dataset
from packaging import version
from parameterized import parameterized
@ -91,6 +92,7 @@ from .testing_utils import (
require_torch_gpu,
require_torch_multi_gpu,
require_torchao,
torch_device,
)
@ -131,7 +133,7 @@ class DataCollatorSpeechSeq2SeqWithPadding:
return batch
@require_torch_gpu
@require_non_cpu
@require_bitsandbytes
class PeftBnbGPUExampleTests(unittest.TestCase):
r"""
@ -160,10 +162,7 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
Efficient mechanism to free GPU memory after each test. Based on
https://github.com/huggingface/transformers/issues/21094
"""
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clear_device_cache(garbage_collection=True)
def _check_inference_finite(self, model, batch):
# try inference without Trainer class
@ -351,7 +350,7 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
assert trainer.state.log_history[-1]["train_loss"] is not None
@pytest.mark.single_gpu_tests
@require_torch_gpu
@require_non_cpu
def test_4bit_adalora_causalLM(self):
r"""
Tests the 4bit training with adalora
@ -1504,8 +1503,7 @@ class PeftGPTQGPUTests(unittest.TestCase):
Efficient mechanism to free GPU memory after each test. Based on
https://github.com/huggingface/transformers/issues/21094
"""
gc.collect()
torch.cuda.empty_cache()
clear_device_cache(garbage_collection=True)
def _check_inference_finite(self, model, batch):
# try inference without Trainer class
@ -1752,8 +1750,7 @@ class OffloadSaveTests(unittest.TestCase):
Efficient mechanism to free GPU memory after each test. Based on
https://github.com/huggingface/transformers/issues/21094
"""
gc.collect()
torch.cuda.empty_cache()
clear_device_cache(garbage_collection=True)
def test_offload_load(self):
r"""
@ -1832,7 +1829,7 @@ class OffloadSaveTests(unittest.TestCase):
assert torch.allclose(post_unload_merge_olayer, pre_merge_olayer)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="test requires a GPU")
@pytest.mark.single_gpu_tests
class TestPiSSA:
r"""
@ -1888,8 +1885,7 @@ class TestPiSSA:
qlora_model = qlora_model.merge_and_unload()
qlora_error = self.nuclear_norm(base_model, qlora_model)
del qlora_model
gc.collect()
torch.cuda.empty_cache()
clear_device_cache(garbage_collection=True)
# logits from quantized LoRA model using PiSSA
lora_config = LoraConfig(
@ -1908,8 +1904,7 @@ class TestPiSSA:
pissa_model.save_pretrained(tmp_path / "residual_model")
del pissa_model
gc.collect()
torch.cuda.empty_cache()
clear_device_cache(garbage_collection=True)
# now load quantized model and apply PiSSA-initialized weights on top
qpissa_model = self.quantize_model(
@ -1919,8 +1914,7 @@ class TestPiSSA:
qpissa_model = qpissa_model.merge_and_unload()
qpissa_error = self.nuclear_norm(base_model, qpissa_model)
del qpissa_model
gc.collect()
torch.cuda.empty_cache()
clear_device_cache(garbage_collection=True)
assert qlora_error > 0.0
assert qpissa_error > 0.0
@ -1928,7 +1922,7 @@ class TestPiSSA:
# next, check that PiSSA quantization errors are smaller than LoRA errors by a certain margin
assert qpissa_error < (qlora_error / self.error_factor)
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("device", [torch_device, "cpu"])
def test_bloomz_pissa_4bit(self, device, tmp_path):
# In this test, we compare the logits of the base model, the quantized LoRA model, and the quantized model
# using PiSSA. When quantizing, we expect a certain level of error. However, we expect the PiSSA quantized
@ -1938,25 +1932,25 @@ class TestPiSSA:
self.get_errors(bits=4, device=device, tmp_path=tmp_path)
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("device", [torch_device, "cpu"])
def test_bloomz_pissa_8bit(self, device, tmp_path):
# Same test as test_bloomz_pissa_4bit but with 8 bits.
self.get_errors(bits=8, device=device, tmp_path=tmp_path)
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("device", [torch_device, "cpu"])
def test_t5_pissa_4bit(self, device, tmp_path):
self.get_errors(bits=4, device=device, model_id="t5-small", tmp_path=tmp_path)
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("device", [torch_device, "cpu"])
def test_t5_pissa_8bit(self, device, tmp_path):
self.get_errors(bits=8, device=device, model_id="t5-small", tmp_path=tmp_path)
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("device", [torch_device, "cpu"])
def test_gpt2_pissa_4bit(self, device, tmp_path):
# see 2104
self.get_errors(bits=4, device=device, model_id="gpt2", tmp_path=tmp_path)
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("device", [torch_device, "cpu"])
def test_gpt2_pissa_8bit(self, device, tmp_path):
# see 2104
self.get_errors(bits=8, device=device, model_id="gpt2", tmp_path=tmp_path)
@ -1968,7 +1962,7 @@ class TestPiSSA:
import bitsandbytes as bnb
torch.manual_seed(0)
data = torch.rand(10, 1000).to("cuda")
data = torch.rand(10, 1000).to(torch_device)
class MyModule(torch.nn.Module):
def __init__(self):
@ -1983,7 +1977,7 @@ class TestPiSSA:
x_4d = x.flatten().reshape(1, 100, 10, 10)
return self.linear(x), self.embed(x_int), self.conv2d(x_4d)
model = MyModule().to("cuda")
model = MyModule().to(torch_device)
output_base = model(data)[0]
config = LoraConfig(init_lora_weights="pissa", target_modules=["linear"], r=8)
@ -2047,7 +2041,7 @@ class TestPiSSA:
assert not torch.allclose(output_finetuned_pissa, output_converted, atol=tol, rtol=tol)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="test requires a GPU")
@pytest.mark.single_gpu_tests
class TestOLoRA:
r"""
@ -2103,8 +2097,7 @@ class TestOLoRA:
qlora_model = qlora_model.merge_and_unload()
qlora_error = self.nuclear_norm(base_model, qlora_model)
del qlora_model
gc.collect()
torch.cuda.empty_cache()
clear_device_cache(garbage_collection=True)
# logits from quantized LoRA model using OLoRA
lora_config = LoraConfig(
@ -2123,8 +2116,7 @@ class TestOLoRA:
olora_model.save_pretrained(tmp_path / "residual_model")
del olora_model
gc.collect()
torch.cuda.empty_cache()
clear_device_cache(garbage_collection=True)
# now load quantized model and apply OLoRA-initialized weights on top
qolora_model = self.quantize_model(
@ -2134,8 +2126,7 @@ class TestOLoRA:
qolora_model = qolora_model.merge_and_unload()
qolora_error = self.nuclear_norm(base_model, qolora_model)
del qolora_model
gc.collect()
torch.cuda.empty_cache()
clear_device_cache(garbage_collection=True)
assert qlora_error > 0.0
assert qolora_error > 0.0
@ -2143,7 +2134,7 @@ class TestOLoRA:
# next, check that OLoRA quantization errors are smaller than LoRA errors by a certain margin
assert qolora_error < (qlora_error / self.error_factor)
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("device", [torch_device, "cpu"])
def test_bloomz_olora_4bit(self, device, tmp_path):
# In this test, we compare the logits of the base model, the quantized LoRA model, and the quantized model
# using OLoRA. When quantizing, we expect a certain level of error. However, we expect the OLoRA quantized
@ -2153,7 +2144,7 @@ class TestOLoRA:
self.get_errors(bits=4, device=device, tmp_path=tmp_path)
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("device", [torch_device, "cpu"])
def test_bloomz_olora_8bit(self, device, tmp_path):
# Same test as test_bloomz_olora_4bit but with 8 bits.
self.get_errors(bits=8, device=device, tmp_path=tmp_path)
@ -2248,8 +2239,7 @@ class TestLoftQ:
logits_base = self.get_logits(model, inputs)
# clean up
del model
gc.collect()
torch.cuda.empty_cache()
clear_device_cache(garbage_collection=True)
# logits from the normal quantized LoRA model
target_modules = "all-linear" if task_type != TaskType.SEQ_2_SEQ_LM else ["o", "k", "wi", "q", "v"]
@ -2269,8 +2259,7 @@ class TestLoftQ:
torch.manual_seed(0)
logits_quantized = self.get_logits(quantized_model, inputs)
del quantized_model
gc.collect()
torch.cuda.empty_cache()
clear_device_cache(garbage_collection=True)
# logits from quantized LoRA model using LoftQ
loftq_config = LoftQConfig(loftq_bits=bits, loftq_iter=loftq_iter)
@ -2282,11 +2271,11 @@ class TestLoftQ:
target_modules=target_modules,
)
model = self.get_base_model(model_id, device)
if device == "cuda":
model = model.to("cuda")
if device != "cpu":
model = model.to(torch_device)
loftq_model = get_peft_model(model, lora_config)
if device == "cuda":
loftq_model = loftq_model.to("cuda")
if device != "cpu":
loftq_model = loftq_model.to(torch_device)
# save LoRA weights, they should be initialized such that they minimize the quantization error
loftq_model.base_model.peft_config["default"].init_lora_weights = True
@ -2296,8 +2285,7 @@ class TestLoftQ:
loftq_model.save_pretrained(tmp_path / "base_model")
del loftq_model
gc.collect()
torch.cuda.empty_cache()
clear_device_cache(garbage_collection=True)
# now load quantized model and apply LoftQ-initialized weights on top
base_model = self.get_base_model(tmp_path / "base_model", device=None, **kwargs, torch_dtype=torch.float32)
@ -2308,8 +2296,7 @@ class TestLoftQ:
torch.manual_seed(0)
logits_loftq = self.get_logits(loftq_model, inputs)
del loftq_model
gc.collect()
torch.cuda.empty_cache()
clear_device_cache(garbage_collection=True)
mae_quantized = torch.abs(logits_base - logits_quantized).mean()
mse_quantized = torch.pow(logits_base - logits_quantized, 2).mean()
@ -2317,7 +2304,7 @@ class TestLoftQ:
mse_loftq = torch.pow(logits_base - logits_loftq, 2).mean()
return mae_quantized, mse_quantized, mae_loftq, mse_loftq
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("device", [torch_device, "cpu"])
def test_bloomz_loftq_4bit(self, device, tmp_path):
# In this test, we compare the logits of the base model, the quantized LoRA model, and the quantized model
# using LoftQ. When quantizing, we expect a certain level of error. However, we expect the LoftQ quantized
@ -2336,7 +2323,7 @@ class TestLoftQ:
assert mse_loftq < (mse_quantized / self.error_factor)
assert mae_loftq < (mae_quantized / self.error_factor)
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("device", [torch_device, "cpu"])
def test_bloomz_loftq_4bit_iter_5(self, device, tmp_path):
# Same test as the previous one but with 5 iterations. We should expect the error to be even smaller with more
# iterations, but in practice the difference is not that large, at least not for this small base model.
@ -2353,7 +2340,7 @@ class TestLoftQ:
assert mse_loftq < (mse_quantized / self.error_factor)
assert mae_loftq < (mae_quantized / self.error_factor)
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("device", [torch_device, "cpu"])
def test_bloomz_loftq_8bit(self, device, tmp_path):
# Same test as test_bloomz_loftq_4bit but with 8 bits.
mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=8, device=device, tmp_path=tmp_path)
@ -2368,7 +2355,7 @@ class TestLoftQ:
assert mse_loftq < (mse_quantized / self.error_factor)
assert mae_loftq < (mae_quantized / self.error_factor)
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("device", [torch_device, "cpu"])
def test_bloomz_loftq_8bit_iter_5(self, device, tmp_path):
# Same test as test_bloomz_loftq_4bit_iter_5 but with 8 bits.
mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(
@ -2385,7 +2372,7 @@ class TestLoftQ:
assert mse_loftq < (mse_quantized / self.error_factor)
assert mae_loftq < (mae_quantized / self.error_factor)
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("device", [torch_device, "cpu"])
def test_t5_loftq_4bit(self, device, tmp_path):
mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(
bits=4, device=device, model_id="t5-small", tmp_path=tmp_path
@ -2400,7 +2387,7 @@ class TestLoftQ:
assert mse_loftq < (mse_quantized / self.error_factor)
assert mae_loftq < (mae_quantized / self.error_factor)
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("device", [torch_device, "cpu"])
def test_t5_loftq_8bit(self, device, tmp_path):
mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(
bits=8, device=device, model_id="t5-small", tmp_path=tmp_path
@ -2416,7 +2403,7 @@ class TestLoftQ:
assert mae_loftq < (mae_quantized / self.error_factor)
@pytest.mark.xfail # failing for now, but having DoRA pass is only a nice-to-have, not a must, so we're good
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("device", [torch_device, "cpu"])
def test_bloomz_loftq_4bit_dora(self, device, tmp_path):
# same as test_bloomz_loftq_4bit but with DoRA
mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(
@ -2433,7 +2420,7 @@ class TestLoftQ:
assert mae_loftq < (mae_quantized / factor)
assert mse_loftq < (mse_quantized / factor)
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("device", [torch_device, "cpu"])
def test_bloomz_loftq_8bit_dora(self, device, tmp_path):
# same as test_bloomz_loftq_8bit but with DoRA
mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(
@ -2462,7 +2449,7 @@ class TestLoftQ:
"""
torch.manual_seed(0)
model_id = "bigscience/bloomz-560m"
device = "cuda"
device = torch_device
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("The dog was", padding=True, return_tensors="pt").to(device)
@ -2513,15 +2500,13 @@ class TestLoftQ:
assert not all(logs)
del model
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clear_device_cache(garbage_collection=True)
def test_replace_lora_weights_with_local_model(self):
# see issue 2020
torch.manual_seed(0)
model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
device = "cuda"
device = torch_device
with tempfile.TemporaryDirectory() as tmp_dir:
# save base model locally
@ -2552,9 +2537,7 @@ class TestLoftQ:
replace_lora_weights_loftq(model)
del model
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
clear_device_cache(garbage_collection=True)
def test_config_no_loftq_init(self):
with pytest.warns(
@ -2600,6 +2583,8 @@ class MixedPrecisionTests(unittest.TestCase):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif is_xpu_available():
torch.xpu.empty_cache()
gc.collect()
@pytest.mark.single_gpu_tests
@ -3976,8 +3961,7 @@ class TestPTuningReproducibility:
model.save_pretrained(tmp_path)
del model
torch.cuda.empty_cache()
gc.collect()
clear_device_cache(garbage_collection=True)
model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
model = PeftModel.from_pretrained(model, tmp_path)
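Throughout this file the manual `gc.collect()` / `torch.cuda.empty_cache()` pairs are collapsed into a single `clear_device_cache(garbage_collection=True)` call from accelerate, which also covers XPU. A minimal sketch of the resulting teardown pattern:

```python
import unittest

from accelerate.utils.memory import clear_device_cache


class ExampleAcceleratorTest(unittest.TestCase):
    def tearDown(self):
        # one call runs garbage collection and empties the cache of whichever
        # accelerator backend is available (CUDA, XPU, ...)
        clear_device_cache(garbage_collection=True)
```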

View File

@ -19,7 +19,7 @@ from torch import nn
from peft.import_utils import is_bnb_available
from peft.optimizers import create_loraplus_optimizer
from .testing_utils import require_bitsandbytes
from .testing_utils import require_bitsandbytes, torch_device
if is_bnb_available():
@ -80,7 +80,7 @@ def test_lora_plus_optimizer_sucess():
"betas": (0.9, 0.999),
"loraplus_weight_decay": 0.0,
}
model: SimpleNet = SimpleNet().cuda()
model: SimpleNet = SimpleNet().to(torch_device)
optim = create_loraplus_optimizer(
model=model,
optimizer_cls=optimizer_cls,
@ -91,9 +91,9 @@ def test_lora_plus_optimizer_sucess():
)
loss = torch.nn.CrossEntropyLoss()
bnb.optim.GlobalOptimManager.get_instance().register_parameters(model.parameters())
x = torch.randint(100, (2, 4, 10)).cuda()
x = torch.randint(100, (2, 4, 10)).to(torch_device)
output = model(x).permute(0, 3, 1, 2)
label = torch.randint(16, (2, 4, 10)).cuda()
label = torch.randint(16, (2, 4, 10)).to(torch_device)
loss_value = loss(output, label)
loss_value.backward()
optim.step()
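The LoRA+ test now places the model and tensors on `torch_device` imported from `.testing_utils` instead of calling `.cuda()`. As an assumption (the actual `testing_utils` definition is not shown in this diff), such a constant could be derived from `peft.utils.infer_device`:

```python
from peft.utils import infer_device

# returns a device type string such as "cuda", "xpu", or "cpu"
torch_device = infer_device()
```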

View File

@ -57,7 +57,7 @@ from peft.tuners.tuners_utils import (
from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND, ModulesToSaveWrapper, infer_device
from peft.utils.constants import DUMMY_MODEL_CONFIG, MIN_TARGET_MODULES_FOR_OPTIMIZATION
from .testing_utils import require_bitsandbytes, require_non_cpu, require_torch_gpu
from .testing_utils import require_bitsandbytes, require_non_cpu
# Implements tests for regex matching logic common for all BaseTuner subclasses, and
@ -271,7 +271,7 @@ class PeftCustomKwargsTester(unittest.TestCase):
)
@parameterized.expand(BNB_TEST_CASES)
@require_torch_gpu
@require_non_cpu
@require_bitsandbytes
def test_maybe_include_all_linear_layers_lora_bnb(
self, model_id, model_type, initial_target_modules, expected_target_modules, quantization

View File

@ -20,6 +20,7 @@ from safetensors import safe_open
from torch import nn
from peft import PeftModel, VBLoRAConfig, get_peft_model
from peft.import_utils import is_xpu_available
class MLP(nn.Module):
@ -189,8 +190,9 @@ class TestVBLoRA:
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_vblora_dtypes(self, dtype):
mlp = self.get_mlp()
if (dtype == torch.bfloat16) and not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()):
pytest.skip("bfloat16 not supported on this system, skipping the test")
if dtype == torch.bfloat16:
if not is_xpu_available() and not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()):
pytest.skip("bfloat16 not supported on this system, skipping the test")
config = VBLoRAConfig(
target_modules=["lin0", "lin1", "lin3"], vector_length=2, num_vectors=10, save_only_topk_weights=False
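The reworked bfloat16 check treats XPU as bf16-capable and otherwise requires a CUDA device that reports bf16 support. A hedged sketch of the same condition pulled into a hypothetical helper:

```python
import pytest
import torch

from peft.import_utils import is_xpu_available


def skip_if_bf16_unsupported(dtype):
    # hypothetical helper mirroring the in-test check above
    if dtype == torch.bfloat16:
        if not is_xpu_available() and not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()):
            pytest.skip("bfloat16 not supported on this system, skipping the test")
```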