TST Enable BNB tests on XPU (#2396)
@@ -23,6 +23,8 @@ import torch
 import transformers
 from torch import nn

+from peft.import_utils import is_xpu_available
+

 @contextmanager
 def gather_params_ctx(param, modifier_rank: int = 0, fwd_module: torch.nn.Module = None):
@@ -90,8 +92,11 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
     # BNB requires CUDA weights
     device = weight.device
     is_cpu = device.type == torch.device("cpu").type
-    if is_cpu and torch.cuda.is_available():
-        weight = weight.to(torch.device("cuda"))
+    if is_cpu:
+        if torch.cuda.is_available():
+            weight = weight.to(torch.device("cuda"))
+        elif is_xpu_available():
+            weight = weight.to(torch.device("xpu"))

     cls_name = weight.__class__.__name__
     if cls_name == "Params4bit":
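The change above makes dequantize_bnb_weight fall back to XPU when no CUDA device is present. A minimal standalone sketch of that device-fallback pattern (the helper name move_to_accelerator is hypothetical, not part of PEFT):

import torch

def move_to_accelerator(weight: torch.Tensor) -> torch.Tensor:
    # bitsandbytes kernels expect the weight on an accelerator, not on the CPU
    if weight.device.type != "cpu":
        return weight
    if torch.cuda.is_available():
        return weight.to(torch.device("cuda"))
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return weight.to(torch.device("xpu"))
    # no accelerator present: leave the weight where it is
    return weight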
@@ -22,13 +22,14 @@ import os
 from typing import Callable, Optional, Union

 import torch
+from accelerate.utils.memory import clear_device_cache
 from huggingface_hub import snapshot_download
 from huggingface_hub.errors import HFValidationError, LocalEntryNotFoundError
 from safetensors import SafetensorError, safe_open
 from transformers.utils import cached_file
 from transformers.utils.hub import get_checkpoint_shard_files

-from peft.import_utils import is_bnb_4bit_available, is_bnb_available
+from peft.import_utils import is_bnb_4bit_available, is_bnb_available, is_xpu_available


 class NFQuantizer:
@@ -201,7 +202,6 @@ def loftq_init(weight: Union[torch.Tensor, torch.nn.Parameter], num_bits: int, r
     out_feature, in_feature = weight.size()
     device = weight.device
     dtype = weight.dtype
-
     logging.info(
         f"Weight: ({out_feature}, {in_feature}) | Rank: {reduced_rank} | Num Iter: {num_iter} | Num Bits: {num_bits}"
     )
@@ -209,12 +209,12 @@ def loftq_init(weight: Union[torch.Tensor, torch.nn.Parameter], num_bits: int, r
         quantizer = NFQuantizer(num_bits=num_bits, device=device, method="normal", block_size=64)
         compute_device = device
     else:
-        compute_device = "cuda"
+        compute_device = "xpu" if is_xpu_available() else "cuda"

     weight = weight.to(device=compute_device, dtype=torch.float32)
     res = weight.clone()
     for i in range(num_iter):
-        torch.cuda.empty_cache()
+        clear_device_cache()
         # Quantization
         if num_bits == 4 and is_bnb_4bit_available():
             qweight = bnb.nn.Params4bit(
@@ -246,12 +246,12 @@ def _loftq_init_new(qweight, weight, num_bits: int, reduced_rank: int):
     if not is_bnb_4bit_available():
         raise ValueError("bitsandbytes 4bit quantization is not available.")

-    compute_device = "cuda"
+    compute_device = "xpu" if is_xpu_available() else "cuda"
     dequantized_weight = bnb.functional.dequantize_4bit(qweight.data, qweight.quant_state)

     weight = weight.to(device=compute_device, dtype=torch.float32)
     residual = weight - dequantized_weight
-    torch.cuda.empty_cache()
+    clear_device_cache()
     # Decompose the residual by SVD
     output = _low_rank_decomposition(residual, reduced_rank=reduced_rank)
     L, R, reduced_rank = output["L"], output["R"], output["reduced_rank"]
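Two ideas recur in the LoftQ changes: the compute device is now chosen as "xpu" when available and "cuda" otherwise, and cache clearing goes through accelerate's clear_device_cache instead of torch.cuda.empty_cache. A rough sketch of the selection logic, assuming PEFT's is_xpu_available helper and at least one accelerator being present (as this quantization path requires):

import torch
from peft.import_utils import is_xpu_available

def select_compute_device() -> str:
    # Mirrors the `compute_device = "xpu" if is_xpu_available() else "cuda"` lines above:
    # prefer XPU when it is the available accelerator, otherwise assume CUDA.
    return "xpu" if is_xpu_available() else "cuda"

weight = torch.randn(64, 64)
# Raises if neither backend exists, just like the patched code path.
weight = weight.to(device=select_compute_device(), dtype=torch.float32)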
@@ -27,10 +27,12 @@ import pytest
 import torch
 from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, BitsAndBytesConfig

+from peft.import_utils import is_xpu_available
+

 bnb = pytest.importorskip("bitsandbytes")

-device = torch.device("cuda")
+device = torch.device("xpu") if is_xpu_available() else torch.device("cuda")


 def bytes_from_tensor(x):
@@ -47,7 +49,7 @@ def bytes_from_tensor(x):
 ############


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_opt_350m_4bit():
     torch.manual_seed(0)
     bnb_config = BitsAndBytesConfig(
@@ -70,7 +72,7 @@ def test_opt_350m_4bit():
     torch.testing.assert_allclose(output, expected)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_opt_350m_8bit():
     torch.manual_seed(0)
     bnb_config = BitsAndBytesConfig(load_in_8bit=True)
@@ -89,7 +91,7 @@ def test_opt_350m_8bit():
     torch.testing.assert_allclose(output, expected)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_opt_350m_4bit_double_quant():
     torch.manual_seed(0)
     bnb_config = BitsAndBytesConfig(
@@ -112,7 +114,7 @@ def test_opt_350m_4bit_double_quant():
     torch.testing.assert_allclose(output, expected)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_opt_350m_4bit_compute_dtype_float16():
     torch.manual_seed(0)
     bnb_config = BitsAndBytesConfig(
@@ -135,7 +137,7 @@ def test_opt_350m_4bit_compute_dtype_float16():
     torch.testing.assert_allclose(output, expected)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_opt_350m_4bit_quant_type_nf4():
     torch.manual_seed(0)
     bnb_config = BitsAndBytesConfig(
@@ -159,7 +161,7 @@ def test_opt_350m_4bit_quant_type_nf4():
     torch.testing.assert_allclose(output, expected)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_opt_350m_4bit_quant_storage():
     # note: using torch.float32 instead of the default torch.uint8 does not seem to affect the result
     torch.manual_seed(0)
@@ -184,7 +186,7 @@ def test_opt_350m_4bit_quant_storage():
     torch.testing.assert_allclose(output, expected)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_opt_350m_8bit_threshold():
     torch.manual_seed(0)
     bnb_config = BitsAndBytesConfig(
@@ -211,7 +213,7 @@ def test_opt_350m_8bit_threshold():
 ###########


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_flan_t5_4bit():
     torch.manual_seed(0)
     bnb_config = BitsAndBytesConfig(
@@ -235,7 +237,7 @@ def test_flan_t5_4bit():
     torch.testing.assert_allclose(output, expected)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 @pytest.mark.xfail  # might not be reproducible depending on hardware
 def test_flan_t5_8bit():
     torch.manual_seed(0)
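Every test above repeats the same accelerator check in its skipif. One hedged alternative (not what this commit does) is to define the condition once as a reusable marker:

import pytest
import torch
from peft.import_utils import is_xpu_available

requires_accelerator = pytest.mark.skipif(
    not (torch.cuda.is_available() or is_xpu_available()),
    reason="No CUDA or XPU device available.",
)

@requires_accelerator
def test_device_is_usable():
    # Same device choice as the module-level `device` above.
    device = torch.device("xpu") if is_xpu_available() else torch.device("cuda")
    assert torch.ones(1, device=device).item() == 1.0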
@@ -63,7 +63,6 @@ from .testing_utils import (
     require_multi_accelerator,
     require_non_cpu,
     require_torch_gpu,
     require_torch_multi_gpu,
 )


@@ -594,7 +593,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         # this should work without any problem
         _ = model.generate(input_ids=input_ids)

-    @require_torch_multi_gpu
+    @require_multi_accelerator
     @pytest.mark.multi_gpu_tests
     @require_bitsandbytes
     def test_lora_seq2seq_lm_multi_gpu_inference(self):
@@ -622,7 +621,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         # this should work without any problem
         _ = model.generate(input_ids=input_ids)

-    @require_torch_multi_gpu
+    @require_multi_accelerator
     @pytest.mark.multi_gpu_tests
     @require_bitsandbytes
     def test_adaption_prompt_8bit(self):
@@ -645,7 +644,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
         _ = model(random_input)

-    @require_torch_multi_gpu
+    @require_multi_accelerator
     @pytest.mark.multi_gpu_tests
     @require_bitsandbytes
     def test_adaption_prompt_4bit(self):
@@ -668,7 +667,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
         _ = model(random_input)

-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_print_4bit_expected(self):
@@ -778,7 +777,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, bnb.nn.Linear8bitLt)
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, bnb.nn.Linear8bitLt)

-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_8bit_merge_and_disable_lora(self):
@@ -814,7 +813,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, LoraLinear8bitLt)
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear8bitLt)

-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_8bit_merge_lora_with_bias(self):
@@ -846,7 +845,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         assert not torch.allclose(out_base, out_before_merge, atol=atol, rtol=rtol)
         assert torch.allclose(out_before_merge, out_after_merge, atol=atol, rtol=rtol)

-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_4bit_merge_lora(self):
@@ -888,7 +887,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, bnb.nn.Linear4bit)
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, bnb.nn.Linear4bit)

-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_4bit_merge_and_disable_lora(self):
@@ -930,7 +929,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, LoraLinear4bit)
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear4bit)

-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_4bit_merge_lora_with_bias(self):
@@ -971,7 +970,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         assert not torch.allclose(out_base, out_before_merge, atol=atol, rtol=rtol)
         assert torch.allclose(out_before_merge, out_after_merge, atol=atol, rtol=rtol)

-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_4bit_lora_mixed_adapter_batches_lora(self):
@@ -1042,7 +1041,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         assert torch.allclose(out_adapter0[1::3], out_mixed[1::3], atol=atol, rtol=rtol)
         assert torch.allclose(out_adapter1[2::3], out_mixed[2::3], atol=atol, rtol=rtol)

-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_8bit_lora_mixed_adapter_batches_lora(self):
@@ -1124,7 +1123,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.save_pretrained(tmp_dir, safe_serialization=True)

-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_4bit_dora_inference(self):
@@ -1163,7 +1162,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, LoraLinear4bit)
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear4bit)

-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_8bit_dora_inference(self):
@@ -1197,7 +1196,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, LoraLinear8bitLt)
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear8bitLt)

-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_4bit_dora_merging(self):
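The decorators swapped in above (require_non_cpu, require_multi_accelerator) come from tests/testing_utils.py. The sketch below shows the general shape such accelerator-agnostic decorators tend to have; it is an assumption for illustration, not PEFT's actual implementation:

import unittest
import torch

def _accelerator_count() -> int:
    # Count whichever accelerator backend is present (CUDA or XPU considered here).
    if torch.cuda.is_available():
        return torch.cuda.device_count()
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.xpu.device_count()
    return 0

def require_non_cpu(test_case):
    # Skip unless at least one accelerator is available.
    return unittest.skipUnless(_accelerator_count() > 0, "test requires an accelerator")(test_case)

def require_multi_accelerator(test_case):
    # Skip unless more than one accelerator is available.
    return unittest.skipUnless(_accelerator_count() > 1, "test requires multiple accelerators")(test_case)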
@@ -27,6 +27,7 @@ import torch
 from accelerate import infer_auto_device_map
 from accelerate.test_utils.testing import run_command
 from accelerate.utils import patch_environment
+from accelerate.utils.memory import clear_device_cache
 from datasets import Audio, Dataset, DatasetDict, load_dataset
 from packaging import version
 from parameterized import parameterized
@@ -91,6 +92,7 @@ from .testing_utils import (
     require_torch_gpu,
     require_torch_multi_gpu,
     require_torchao,
+    torch_device,
 )


@@ -131,7 +133,7 @@ class DataCollatorSpeechSeq2SeqWithPadding:
         return batch


-@require_torch_gpu
+@require_non_cpu
 @require_bitsandbytes
 class PeftBnbGPUExampleTests(unittest.TestCase):
     r"""
@@ -160,10 +162,7 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
         Efficient mechanism to free GPU memory after each test. Based on
         https://github.com/huggingface/transformers/issues/21094
         """
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clear_device_cache(garbage_collection=True)

     def _check_inference_finite(self, model, batch):
         # try inference without Trainer class
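clear_device_cache(garbage_collection=True) from accelerate replaces the hand-written gc.collect()/torch.cuda.empty_cache() pairs in the teardown methods. Roughly, it behaves like the sketch below (paraphrased; accelerate's real helper also covers further backends such as MPS and NPU):

import gc
import torch

def clear_device_cache_sketch(garbage_collection: bool = False) -> None:
    # Optionally run Python garbage collection first, then empty the cache of
    # whichever accelerator backend is available.
    if garbage_collection:
        gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()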
@@ -351,7 +350,7 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
         assert trainer.state.log_history[-1]["train_loss"] is not None

     @pytest.mark.single_gpu_tests
-    @require_torch_gpu
+    @require_non_cpu
     def test_4bit_adalora_causalLM(self):
         r"""
         Tests the 4bit training with adalora
@@ -1504,8 +1503,7 @@ class PeftGPTQGPUTests(unittest.TestCase):
         Efficient mechanism to free GPU memory after each test. Based on
         https://github.com/huggingface/transformers/issues/21094
         """
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

     def _check_inference_finite(self, model, batch):
         # try inference without Trainer class
@@ -1752,8 +1750,7 @@ class OffloadSaveTests(unittest.TestCase):
         Efficient mechanism to free GPU memory after each test. Based on
         https://github.com/huggingface/transformers/issues/21094
         """
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

     def test_offload_load(self):
         r"""
@@ -1832,7 +1829,7 @@ class OffloadSaveTests(unittest.TestCase):
         assert torch.allclose(post_unload_merge_olayer, pre_merge_olayer)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="test requires a GPU")
 @pytest.mark.single_gpu_tests
 class TestPiSSA:
     r"""
@@ -1888,8 +1885,7 @@ class TestPiSSA:
         qlora_model = qlora_model.merge_and_unload()
         qlora_error = self.nuclear_norm(base_model, qlora_model)
         del qlora_model
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

         # logits from quantized LoRA model using PiSSA
         lora_config = LoraConfig(
@@ -1908,8 +1904,7 @@ class TestPiSSA:
         pissa_model.save_pretrained(tmp_path / "residual_model")

         del pissa_model
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

         # now load quantized model and apply PiSSA-initialized weights on top
         qpissa_model = self.quantize_model(
@@ -1919,8 +1914,7 @@ class TestPiSSA:
         qpissa_model = qpissa_model.merge_and_unload()
         qpissa_error = self.nuclear_norm(base_model, qpissa_model)
         del qpissa_model
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

         assert qlora_error > 0.0
         assert qpissa_error > 0.0
@@ -1928,7 +1922,7 @@ class TestPiSSA:
         # next, check that PiSSA quantization errors are smaller than LoRA errors by a certain margin
         assert qpissa_error < (qlora_error / self.error_factor)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_pissa_4bit(self, device, tmp_path):
         # In this test, we compare the logits of the base model, the quantized LoRA model, and the quantized model
         # using PiSSA. When quantizing, we expect a certain level of error. However, we expect the PiSSA quantized
@@ -1938,25 +1932,25 @@ class TestPiSSA:

         self.get_errors(bits=4, device=device, tmp_path=tmp_path)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_pissa_8bit(self, device, tmp_path):
         # Same test as test_bloomz_pissa_4bit but with 8 bits.
         self.get_errors(bits=8, device=device, tmp_path=tmp_path)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_t5_pissa_4bit(self, device, tmp_path):
         self.get_errors(bits=4, device=device, model_id="t5-small", tmp_path=tmp_path)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_t5_pissa_8bit(self, device, tmp_path):
         self.get_errors(bits=8, device=device, model_id="t5-small", tmp_path=tmp_path)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_gpt2_pissa_4bit(self, device, tmp_path):
         # see 2104
         self.get_errors(bits=4, device=device, model_id="gpt2", tmp_path=tmp_path)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_gpt2_pissa_8bit(self, device, tmp_path):
         # see 2104
         self.get_errors(bits=8, device=device, model_id="gpt2", tmp_path=tmp_path)
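Parametrizing on [torch_device, "cpu"] means each test runs once on whatever accelerator the machine has and once on CPU. torch_device comes from tests/testing_utils.py; a plausible definition, assuming it wraps peft.utils.infer_device (which the tuners-utils test further below already imports), would be:

from peft.utils import infer_device

# e.g. "cuda", "xpu", or "cpu" depending on the machine; on a CPU-only machine
# the parametrization then simply covers the CPU path twice.
torch_device = infer_device()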
@@ -1968,7 +1962,7 @@ class TestPiSSA:
         import bitsandbytes as bnb

         torch.manual_seed(0)
-        data = torch.rand(10, 1000).to("cuda")
+        data = torch.rand(10, 1000).to(torch_device)

         class MyModule(torch.nn.Module):
             def __init__(self):
@@ -1983,7 +1977,7 @@ class TestPiSSA:
                 x_4d = x.flatten().reshape(1, 100, 10, 10)
                 return self.linear(x), self.embed(x_int), self.conv2d(x_4d)

-        model = MyModule().to("cuda")
+        model = MyModule().to(torch_device)
         output_base = model(data)[0]

         config = LoraConfig(init_lora_weights="pissa", target_modules=["linear"], r=8)
@@ -2047,7 +2041,7 @@ class TestPiSSA:
         assert not torch.allclose(output_finetuned_pissa, output_converted, atol=tol, rtol=tol)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="test requires a GPU")
 @pytest.mark.single_gpu_tests
 class TestOLoRA:
     r"""
@@ -2103,8 +2097,7 @@ class TestOLoRA:
         qlora_model = qlora_model.merge_and_unload()
         qlora_error = self.nuclear_norm(base_model, qlora_model)
         del qlora_model
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

         # logits from quantized LoRA model using OLoRA
         lora_config = LoraConfig(
@@ -2123,8 +2116,7 @@ class TestOLoRA:
         olora_model.save_pretrained(tmp_path / "residual_model")

         del olora_model
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

         # now load quantized model and apply OLoRA-initialized weights on top
         qolora_model = self.quantize_model(
@@ -2134,8 +2126,7 @@ class TestOLoRA:
         qolora_model = qolora_model.merge_and_unload()
         qolora_error = self.nuclear_norm(base_model, qolora_model)
         del qolora_model
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

         assert qlora_error > 0.0
         assert qolora_error > 0.0
@@ -2143,7 +2134,7 @@ class TestOLoRA:
         # next, check that OLoRA quantization errors are smaller than LoRA errors by a certain margin
         assert qolora_error < (qlora_error / self.error_factor)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_olora_4bit(self, device, tmp_path):
         # In this test, we compare the logits of the base model, the quantized LoRA model, and the quantized model
         # using OLoRA. When quantizing, we expect a certain level of error. However, we expect the OLoRA quantized
@@ -2153,7 +2144,7 @@ class TestOLoRA:

         self.get_errors(bits=4, device=device, tmp_path=tmp_path)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_olora_8bit(self, device, tmp_path):
         # Same test as test_bloomz_olora_4bit but with 8 bits.
         self.get_errors(bits=8, device=device, tmp_path=tmp_path)
@@ -2248,8 +2239,7 @@ class TestLoftQ:
         logits_base = self.get_logits(model, inputs)
         # clean up
         del model
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

         # logits from the normal quantized LoRA model
         target_modules = "all-linear" if task_type != TaskType.SEQ_2_SEQ_LM else ["o", "k", "wi", "q", "v"]
@@ -2269,8 +2259,7 @@ class TestLoftQ:
         torch.manual_seed(0)
         logits_quantized = self.get_logits(quantized_model, inputs)
         del quantized_model
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

         # logits from quantized LoRA model using LoftQ
         loftq_config = LoftQConfig(loftq_bits=bits, loftq_iter=loftq_iter)
@@ -2282,11 +2271,11 @@ class TestLoftQ:
             target_modules=target_modules,
         )
         model = self.get_base_model(model_id, device)
-        if device == "cuda":
-            model = model.to("cuda")
+        if device != "cpu":
+            model = model.to(torch_device)
         loftq_model = get_peft_model(model, lora_config)
-        if device == "cuda":
-            loftq_model = loftq_model.to("cuda")
+        if device != "cpu":
+            loftq_model = loftq_model.to(torch_device)

         # save LoRA weights, they should be initialized such that they minimize the quantization error
         loftq_model.base_model.peft_config["default"].init_lora_weights = True
@@ -2296,8 +2285,7 @@ class TestLoftQ:
         loftq_model.save_pretrained(tmp_path / "base_model")

         del loftq_model
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

         # now load quantized model and apply LoftQ-initialized weights on top
         base_model = self.get_base_model(tmp_path / "base_model", device=None, **kwargs, torch_dtype=torch.float32)
@@ -2308,8 +2296,7 @@ class TestLoftQ:
         torch.manual_seed(0)
         logits_loftq = self.get_logits(loftq_model, inputs)
         del loftq_model
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

         mae_quantized = torch.abs(logits_base - logits_quantized).mean()
         mse_quantized = torch.pow(logits_base - logits_quantized, 2).mean()
@@ -2317,7 +2304,7 @@ class TestLoftQ:
         mse_loftq = torch.pow(logits_base - logits_loftq, 2).mean()
         return mae_quantized, mse_quantized, mae_loftq, mse_loftq

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_loftq_4bit(self, device, tmp_path):
         # In this test, we compare the logits of the base model, the quantized LoRA model, and the quantized model
         # using LoftQ. When quantizing, we expect a certain level of error. However, we expect the LoftQ quantized
@@ -2336,7 +2323,7 @@ class TestLoftQ:
         assert mse_loftq < (mse_quantized / self.error_factor)
         assert mae_loftq < (mae_quantized / self.error_factor)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_loftq_4bit_iter_5(self, device, tmp_path):
         # Same test as the previous one but with 5 iterations. We should expect the error to be even smaller with more
         # iterations, but in practice the difference is not that large, at least not for this small base model.
@@ -2353,7 +2340,7 @@ class TestLoftQ:
         assert mse_loftq < (mse_quantized / self.error_factor)
         assert mae_loftq < (mae_quantized / self.error_factor)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_loftq_8bit(self, device, tmp_path):
         # Same test as test_bloomz_loftq_4bit but with 8 bits.
         mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=8, device=device, tmp_path=tmp_path)
@@ -2368,7 +2355,7 @@ class TestLoftQ:
         assert mse_loftq < (mse_quantized / self.error_factor)
         assert mae_loftq < (mae_quantized / self.error_factor)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_loftq_8bit_iter_5(self, device, tmp_path):
         # Same test as test_bloomz_loftq_4bit_iter_5 but with 8 bits.
         mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(
@@ -2385,7 +2372,7 @@ class TestLoftQ:
         assert mse_loftq < (mse_quantized / self.error_factor)
         assert mae_loftq < (mae_quantized / self.error_factor)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_t5_loftq_4bit(self, device, tmp_path):
         mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(
             bits=4, device=device, model_id="t5-small", tmp_path=tmp_path
@@ -2400,7 +2387,7 @@ class TestLoftQ:
         assert mse_loftq < (mse_quantized / self.error_factor)
         assert mae_loftq < (mae_quantized / self.error_factor)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_t5_loftq_8bit(self, device, tmp_path):
         mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(
             bits=8, device=device, model_id="t5-small", tmp_path=tmp_path
@@ -2416,7 +2403,7 @@ class TestLoftQ:
         assert mae_loftq < (mae_quantized / self.error_factor)

     @pytest.mark.xfail  # failing for now, but having DoRA pass is only a nice-to-have, not a must, so we're good
-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_loftq_4bit_dora(self, device, tmp_path):
         # same as test_bloomz_loftq_4bit but with DoRA
         mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(
@@ -2433,7 +2420,7 @@ class TestLoftQ:
         assert mae_loftq < (mae_quantized / factor)
         assert mse_loftq < (mse_quantized / factor)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_loftq_8bit_dora(self, device, tmp_path):
         # same as test_bloomz_loftq_8bit but with DoRA
         mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(
@@ -2462,7 +2449,7 @@ class TestLoftQ:
         """
         torch.manual_seed(0)
         model_id = "bigscience/bloomz-560m"
-        device = "cuda"
+        device = torch_device
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         inputs = tokenizer("The dog was", padding=True, return_tensors="pt").to(device)
@@ -2513,15 +2500,13 @@ class TestLoftQ:
         assert not all(logs)

         del model
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clear_device_cache(garbage_collection=True)

     def test_replace_lora_weights_with_local_model(self):
         # see issue 2020
         torch.manual_seed(0)
         model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
-        device = "cuda"
+        device = torch_device

         with tempfile.TemporaryDirectory() as tmp_dir:
             # save base model locally
@@ -2552,9 +2537,7 @@ class TestLoftQ:
             replace_lora_weights_loftq(model)

             del model
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-            gc.collect()
+            clear_device_cache(garbage_collection=True)

     def test_config_no_loftq_init(self):
         with pytest.warns(
@@ -2600,6 +2583,8 @@ class MixedPrecisionTests(unittest.TestCase):
         gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
+        elif is_xpu_available():
+            torch.xpu.empty_cache()
         gc.collect()

     @pytest.mark.single_gpu_tests
@@ -3976,8 +3961,7 @@ class TestPTuningReproducibility:

         model.save_pretrained(tmp_path)
         del model
-        torch.cuda.empty_cache()
-        gc.collect()
+        clear_device_cache(garbage_collection=True)

         model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
         model = PeftModel.from_pretrained(model, tmp_path)
@@ -19,7 +19,7 @@ from torch import nn
 from peft.import_utils import is_bnb_available
 from peft.optimizers import create_loraplus_optimizer

-from .testing_utils import require_bitsandbytes
+from .testing_utils import require_bitsandbytes, torch_device


 if is_bnb_available():
@@ -80,7 +80,7 @@ def test_lora_plus_optimizer_sucess():
         "betas": (0.9, 0.999),
         "loraplus_weight_decay": 0.0,
     }
-    model: SimpleNet = SimpleNet().cuda()
+    model: SimpleNet = SimpleNet().to(torch_device)
    optim = create_loraplus_optimizer(
        model=model,
        optimizer_cls=optimizer_cls,
@@ -91,9 +91,9 @@ def test_lora_plus_optimizer_sucess():
    )
    loss = torch.nn.CrossEntropyLoss()
    bnb.optim.GlobalOptimManager.get_instance().register_parameters(model.parameters())
-    x = torch.randint(100, (2, 4, 10)).cuda()
+    x = torch.randint(100, (2, 4, 10)).to(torch_device)
    output = model(x).permute(0, 3, 1, 2)
-    label = torch.randint(16, (2, 4, 10)).cuda()
+    label = torch.randint(16, (2, 4, 10)).to(torch_device)
    loss_value = loss(output, label)
    loss_value.backward()
    optim.step()
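Replacing .cuda() with .to(torch_device) is what lets the LoRA+ optimizer test run on XPU: .to() accepts any device string, while .cuda() hard-codes NVIDIA GPUs. A small self-contained illustration, with torch_device resolved locally here rather than imported from the shared test utilities:

import torch

torch_device = "cuda" if torch.cuda.is_available() else "cpu"  # simplified stand-in for the test helper
x = torch.randint(100, (2, 4, 10)).to(torch_device)     # portable across backends
label = torch.randint(16, (2, 4, 10)).to(torch_device)  # same call works for "cuda", "xpu", or "cpu"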
@@ -57,7 +57,7 @@ from peft.tuners.tuners_utils import (
 from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND, ModulesToSaveWrapper, infer_device
 from peft.utils.constants import DUMMY_MODEL_CONFIG, MIN_TARGET_MODULES_FOR_OPTIMIZATION

-from .testing_utils import require_bitsandbytes, require_non_cpu, require_torch_gpu
+from .testing_utils import require_bitsandbytes, require_non_cpu


 # Implements tests for regex matching logic common for all BaseTuner subclasses, and
@@ -271,7 +271,7 @@ class PeftCustomKwargsTester(unittest.TestCase):
         )

     @parameterized.expand(BNB_TEST_CASES)
-    @require_torch_gpu
+    @require_non_cpu
     @require_bitsandbytes
     def test_maybe_include_all_linear_layers_lora_bnb(
         self, model_id, model_type, initial_target_modules, expected_target_modules, quantization
@@ -20,6 +20,7 @@ from safetensors import safe_open
 from torch import nn

 from peft import PeftModel, VBLoRAConfig, get_peft_model
+from peft.import_utils import is_xpu_available


 class MLP(nn.Module):
@@ -189,8 +190,9 @@ class TestVBLoRA:
     @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
     def test_vblora_dtypes(self, dtype):
         mlp = self.get_mlp()
-        if (dtype == torch.bfloat16) and not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()):
-            pytest.skip("bfloat16 not supported on this system, skipping the test")
+        if dtype == torch.bfloat16:
+            if not is_xpu_available() and not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()):
+                pytest.skip("bfloat16 not supported on this system, skipping the test")

         config = VBLoRAConfig(
             target_modules=["lin0", "lin1", "lin3"], vector_length=2, num_vectors=10, save_only_topk_weights=False
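The reworked VB-LoRA dtype check treats XPU as bf16-capable and otherwise requires a CUDA device that reports bf16 support. As a standalone predicate it might look like this (a sketch following the test's logic, not a PEFT API):

import torch
from peft.import_utils import is_xpu_available

def bf16_supported() -> bool:
    # XPU devices are assumed to handle bfloat16; otherwise defer to CUDA's capability check.
    if is_xpu_available():
        return True
    return torch.cuda.is_available() and torch.cuda.is_bf16_supported()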