diff --git a/src/peft/utils/integrations.py b/src/peft/utils/integrations.py
index 5c23f404..7d472870 100644
--- a/src/peft/utils/integrations.py
+++ b/src/peft/utils/integrations.py
@@ -23,6 +23,8 @@ import torch
 import transformers
 from torch import nn

+from peft.import_utils import is_xpu_available
+

 @contextmanager
 def gather_params_ctx(param, modifier_rank: int = 0, fwd_module: torch.nn.Module = None):
@@ -90,8 +92,11 @@ def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
     # BNB requires CUDA weights
     device = weight.device
     is_cpu = device.type == torch.device("cpu").type
-    if is_cpu and torch.cuda.is_available():
-        weight = weight.to(torch.device("cuda"))
+    if is_cpu:
+        if torch.cuda.is_available():
+            weight = weight.to(torch.device("cuda"))
+        elif is_xpu_available():
+            weight = weight.to(torch.device("xpu"))

     cls_name = weight.__class__.__name__
     if cls_name == "Params4bit":
diff --git a/src/peft/utils/loftq_utils.py b/src/peft/utils/loftq_utils.py
index 2f143f30..1019326e 100644
--- a/src/peft/utils/loftq_utils.py
+++ b/src/peft/utils/loftq_utils.py
@@ -22,13 +22,14 @@ import os
 from typing import Callable, Optional, Union

 import torch
+from accelerate.utils.memory import clear_device_cache
 from huggingface_hub import snapshot_download
 from huggingface_hub.errors import HFValidationError, LocalEntryNotFoundError
 from safetensors import SafetensorError, safe_open
 from transformers.utils import cached_file
 from transformers.utils.hub import get_checkpoint_shard_files

-from peft.import_utils import is_bnb_4bit_available, is_bnb_available
+from peft.import_utils import is_bnb_4bit_available, is_bnb_available, is_xpu_available


 class NFQuantizer:
@@ -201,7 +202,6 @@ def loftq_init(weight: Union[torch.Tensor, torch.nn.Parameter], num_bits: int, r
     out_feature, in_feature = weight.size()
     device = weight.device
     dtype = weight.dtype
-
     logging.info(
         f"Weight: ({out_feature}, {in_feature}) | Rank: {reduced_rank} | Num Iter: {num_iter} | Num Bits: {num_bits}"
     )
@@ -209,12 +209,12 @@ def loftq_init(weight: Union[torch.Tensor, torch.nn.Parameter], num_bits: int, r
         quantizer = NFQuantizer(num_bits=num_bits, device=device, method="normal", block_size=64)
         compute_device = device
     else:
-        compute_device = "cuda"
+        compute_device = "xpu" if is_xpu_available() else "cuda"

     weight = weight.to(device=compute_device, dtype=torch.float32)
     res = weight.clone()
     for i in range(num_iter):
-        torch.cuda.empty_cache()
+        clear_device_cache()
         # Quantization
         if num_bits == 4 and is_bnb_4bit_available():
             qweight = bnb.nn.Params4bit(
@@ -246,12 +246,12 @@ def _loftq_init_new(qweight, weight, num_bits: int, reduced_rank: int):
     if not is_bnb_4bit_available():
         raise ValueError("bitsandbytes 4bit quantization is not available.")

-    compute_device = "cuda"
+    compute_device = "xpu" if is_xpu_available() else "cuda"
     dequantized_weight = bnb.functional.dequantize_4bit(qweight.data, qweight.quant_state)

     weight = weight.to(device=compute_device, dtype=torch.float32)
     residual = weight - dequantized_weight
-    torch.cuda.empty_cache()
+    clear_device_cache()
     # Decompose the residual by SVD
     output = _low_rank_decomposition(residual, reduced_rank=reduced_rank)
     L, R, reduced_rank = output["L"], output["R"], output["reduced_rank"]
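The two files above replace hard-coded CUDA assumptions with backend-agnostic equivalents: the compute device falls back to XPU when CUDA is absent, and `torch.cuda.empty_cache()` becomes accelerate's `clear_device_cache()`. A minimal standalone sketch of that pattern (illustrative only; `pick_compute_device` and `round_trip` are hypothetical names, not part of the patch):

```python
# Sketch of the device-selection and cache-clearing pattern used above.
# Assumes peft and accelerate are installed; the helper names are made up.
import torch
from accelerate.utils.memory import clear_device_cache
from peft.import_utils import is_xpu_available


def pick_compute_device() -> str:
    # Prefer CUDA, fall back to Intel XPU, otherwise stay on CPU.
    if torch.cuda.is_available():
        return "cuda"
    if is_xpu_available():
        return "xpu"
    return "cpu"


def round_trip(weight: torch.Tensor) -> torch.Tensor:
    device = pick_compute_device()
    weight = weight.to(device=device, dtype=torch.float32)
    result = weight.clone()
    # clear_device_cache() optionally runs gc.collect() first and then empties the
    # cache of whichever accelerator backend is active (CUDA, XPU, MPS, ...).
    clear_device_cache(garbage_collection=True)
    return result.cpu()
```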
diff --git a/tests/bnb/test_bnb_regression.py b/tests/bnb/test_bnb_regression.py
index 8fe628e7..68541974 100644
--- a/tests/bnb/test_bnb_regression.py
+++ b/tests/bnb/test_bnb_regression.py
@@ -27,10 +27,12 @@ import pytest
 import torch
 from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, BitsAndBytesConfig

+from peft.import_utils import is_xpu_available
+
 bnb = pytest.importorskip("bitsandbytes")

-device = torch.device("cuda")
+device = torch.device("xpu") if is_xpu_available() else torch.device("cuda")


 def bytes_from_tensor(x):
@@ -47,7 +49,7 @@ def bytes_from_tensor(x):
 ############


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_opt_350m_4bit():
     torch.manual_seed(0)
     bnb_config = BitsAndBytesConfig(
@@ -70,7 +72,7 @@ def test_opt_350m_4bit():
     torch.testing.assert_allclose(output, expected)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_opt_350m_8bit():
     torch.manual_seed(0)
     bnb_config = BitsAndBytesConfig(load_in_8bit=True)
@@ -89,7 +91,7 @@ def test_opt_350m_8bit():
     torch.testing.assert_allclose(output, expected)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_opt_350m_4bit_double_quant():
     torch.manual_seed(0)
     bnb_config = BitsAndBytesConfig(
@@ -112,7 +114,7 @@ def test_opt_350m_4bit_double_quant():
     torch.testing.assert_allclose(output, expected)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_opt_350m_4bit_compute_dtype_float16():
     torch.manual_seed(0)
     bnb_config = BitsAndBytesConfig(
@@ -135,7 +137,7 @@ def test_opt_350m_4bit_compute_dtype_float16():
     torch.testing.assert_allclose(output, expected)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_opt_350m_4bit_quant_type_nf4():
     torch.manual_seed(0)
     bnb_config = BitsAndBytesConfig(
@@ -159,7 +161,7 @@ def test_opt_350m_4bit_quant_type_nf4():
     torch.testing.assert_allclose(output, expected)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_opt_350m_4bit_quant_storage():
     # note: using torch.float32 instead of the default torch.uint8 does not seem to affect the result
     torch.manual_seed(0)
@@ -184,7 +186,7 @@ def test_opt_350m_4bit_quant_storage():
     torch.testing.assert_allclose(output, expected)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_opt_350m_8bit_threshold():
     torch.manual_seed(0)
     bnb_config = BitsAndBytesConfig(
@@ -211,7 +213,7 @@ def test_opt_350m_8bit_threshold():
 ###########


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 def test_flan_t5_4bit():
     torch.manual_seed(0)
     bnb_config = BitsAndBytesConfig(
@@ -235,7 +237,7 @@ def test_flan_t5_4bit():
     torch.testing.assert_allclose(output, expected)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device available.")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="No CUDA device available.")
 @pytest.mark.xfail  # might not be reproducible depending on hardware
 def test_flan_t5_8bit():
     torch.manual_seed(0)
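Every regression test above repeats the same skip condition. One option would be to bind that condition to a single reusable marker; the sketch below only illustrates the idea (`requires_accelerator` and `test_smoke` are made-up names, and the patch itself keeps the inline `skipif` decorators):

```python
# Hypothetical consolidation of the repeated skip condition into one marker.
import pytest
import torch

from peft.import_utils import is_xpu_available

requires_accelerator = pytest.mark.skipif(
    not (torch.cuda.is_available() or is_xpu_available()),
    reason="No CUDA or XPU device available.",
)

device = torch.device("xpu") if is_xpu_available() else torch.device("cuda")


@requires_accelerator
def test_smoke():
    # Runs only when a CUDA or XPU device is present.
    assert torch.ones(1, device=device).sum().item() == 1.0
```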
diff --git a/tests/test_common_gpu.py b/tests/test_common_gpu.py
index b366c5e3..e40ba92a 100644
--- a/tests/test_common_gpu.py
+++ b/tests/test_common_gpu.py
@@ -63,7 +63,6 @@ from .testing_utils import (
     require_multi_accelerator,
     require_non_cpu,
     require_torch_gpu,
-    require_torch_multi_gpu,
 )


@@ -594,7 +593,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         # this should work without any problem
         _ = model.generate(input_ids=input_ids)

-    @require_torch_multi_gpu
+    @require_multi_accelerator
     @pytest.mark.multi_gpu_tests
     @require_bitsandbytes
     def test_lora_seq2seq_lm_multi_gpu_inference(self):
@@ -622,7 +621,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         # this should work without any problem
         _ = model.generate(input_ids=input_ids)

-    @require_torch_multi_gpu
+    @require_multi_accelerator
     @pytest.mark.multi_gpu_tests
     @require_bitsandbytes
     def test_adaption_prompt_8bit(self):
@@ -645,7 +644,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
         _ = model(random_input)

-    @require_torch_multi_gpu
+    @require_multi_accelerator
     @pytest.mark.multi_gpu_tests
     @require_bitsandbytes
     def test_adaption_prompt_4bit(self):
@@ -668,7 +667,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
         _ = model(random_input)

-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_print_4bit_expected(self):
@@ -778,7 +777,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, bnb.nn.Linear8bitLt)
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, bnb.nn.Linear8bitLt)

-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_8bit_merge_and_disable_lora(self):
@@ -814,7 +813,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, LoraLinear8bitLt)
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear8bitLt)

-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_8bit_merge_lora_with_bias(self):
@@ -846,7 +845,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         assert not torch.allclose(out_base, out_before_merge, atol=atol, rtol=rtol)
         assert torch.allclose(out_before_merge, out_after_merge, atol=atol, rtol=rtol)

-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_4bit_merge_lora(self):
@@ -888,7 +887,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, bnb.nn.Linear4bit)
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, bnb.nn.Linear4bit)

-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_4bit_merge_and_disable_lora(self):
@@ -930,7 +929,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, LoraLinear4bit)
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear4bit)

-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_4bit_merge_lora_with_bias(self):
@@ -971,7 +970,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         assert not torch.allclose(out_base, out_before_merge, atol=atol, rtol=rtol)
         assert torch.allclose(out_before_merge, out_after_merge, atol=atol, rtol=rtol)

-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_4bit_lora_mixed_adapter_batches_lora(self):
@@ -1042,7 +1041,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         assert torch.allclose(out_adapter0[1::3], out_mixed[1::3], atol=atol, rtol=rtol)
         assert torch.allclose(out_adapter1[2::3], out_mixed[2::3], atol=atol, rtol=rtol)

-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_8bit_lora_mixed_adapter_batches_lora(self):
@@ -1124,7 +1123,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.save_pretrained(tmp_dir, safe_serialization=True)

-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_4bit_dora_inference(self):
@@ -1163,7 +1162,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, LoraLinear4bit)
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear4bit)

-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_8bit_dora_inference(self):
@@ -1197,7 +1196,7 @@ class PeftGPUCommonTests(unittest.TestCase):
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.q_proj, LoraLinear8bitLt)
         assert isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, LoraLinear8bitLt)

-    @require_torch_gpu
+    @require_non_cpu
     @pytest.mark.single_gpu_tests
     @require_bitsandbytes
     def test_4bit_dora_merging(self):
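The decorator swaps above lean on helpers imported from tests/testing_utils.py (`require_non_cpu`, `require_multi_accelerator`). Such helpers can be built on top of a vendor-agnostic accelerator count; the sketch below is an assumption about their shape, not PEFT's actual implementation:

```python
# Rough sketch of device-agnostic test requirements (names carry a _sketch suffix
# to make clear these are not the real helpers from tests/testing_utils.py).
import unittest

import torch
from peft.import_utils import is_xpu_available


def _accelerator_count() -> int:
    # Count visible accelerators regardless of vendor (CUDA or XPU here).
    if torch.cuda.is_available():
        return torch.cuda.device_count()
    if is_xpu_available():
        return torch.xpu.device_count()
    return 0


def require_non_cpu_sketch(test_case):
    """Skip the test unless at least one accelerator is present."""
    return unittest.skipUnless(_accelerator_count() > 0, "test requires an accelerator")(test_case)


def require_multi_accelerator_sketch(test_case):
    """Skip the test unless at least two accelerators are present."""
    return unittest.skipUnless(_accelerator_count() > 1, "test requires multiple accelerators")(test_case)
```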
diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py
index fb28d7db..fdcda7d6 100644
--- a/tests/test_gpu_examples.py
+++ b/tests/test_gpu_examples.py
@@ -27,6 +27,7 @@ import torch
 from accelerate import infer_auto_device_map
 from accelerate.test_utils.testing import run_command
 from accelerate.utils import patch_environment
+from accelerate.utils.memory import clear_device_cache
 from datasets import Audio, Dataset, DatasetDict, load_dataset
 from packaging import version
 from parameterized import parameterized
@@ -91,6 +92,7 @@ from .testing_utils import (
     require_torch_gpu,
     require_torch_multi_gpu,
     require_torchao,
+    torch_device,
 )


@@ -131,7 +133,7 @@ class DataCollatorSpeechSeq2SeqWithPadding:
         return batch


-@require_torch_gpu
+@require_non_cpu
 @require_bitsandbytes
 class PeftBnbGPUExampleTests(unittest.TestCase):
     r"""
@@ -160,10 +162,7 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
         Efficient mechanism to free GPU memory after each test. Based on
         https://github.com/huggingface/transformers/issues/21094
         """
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clear_device_cache(garbage_collection=True)

     def _check_inference_finite(self, model, batch):
         # try inference without Trainer class
@@ -351,7 +350,7 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
         assert trainer.state.log_history[-1]["train_loss"] is not None

     @pytest.mark.single_gpu_tests
-    @require_torch_gpu
+    @require_non_cpu
     def test_4bit_adalora_causalLM(self):
         r"""
         Tests the 4bit training with adalora
@@ -1504,8 +1503,7 @@ class PeftGPTQGPUTests(unittest.TestCase):
         Efficient mechanism to free GPU memory after each test. Based on
         https://github.com/huggingface/transformers/issues/21094
         """
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

     def _check_inference_finite(self, model, batch):
         # try inference without Trainer class
@@ -1752,8 +1750,7 @@ class OffloadSaveTests(unittest.TestCase):
         Efficient mechanism to free GPU memory after each test. Based on
         https://github.com/huggingface/transformers/issues/21094
         """
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

     def test_offload_load(self):
         r"""
@@ -1832,7 +1829,7 @@ class OffloadSaveTests(unittest.TestCase):
         assert torch.allclose(post_unload_merge_olayer, pre_merge_olayer)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="test requires a GPU")
 @pytest.mark.single_gpu_tests
 class TestPiSSA:
     r"""
@@ -1888,8 +1885,7 @@ class TestPiSSA:
         qlora_model = qlora_model.merge_and_unload()
         qlora_error = self.nuclear_norm(base_model, qlora_model)
         del qlora_model
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

         # logits from quantized LoRA model using PiSSA
         lora_config = LoraConfig(
@@ -1908,8 +1904,7 @@ class TestPiSSA:
         pissa_model.save_pretrained(tmp_path / "residual_model")

         del pissa_model
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

         # now load quantized model and apply PiSSA-initialized weights on top
         qpissa_model = self.quantize_model(
@@ -1919,8 +1914,7 @@ class TestPiSSA:
         qpissa_model = qpissa_model.merge_and_unload()
         qpissa_error = self.nuclear_norm(base_model, qpissa_model)
         del qpissa_model
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

         assert qlora_error > 0.0
         assert qpissa_error > 0.0
@@ -1928,7 +1922,7 @@ class TestPiSSA:
         # next, check that PiSSA quantization errors are smaller than LoRA errors by a certain margin
         assert qpissa_error < (qlora_error / self.error_factor)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_pissa_4bit(self, device, tmp_path):
         # In this test, we compare the logits of the base model, the quantized LoRA model, and the quantized model
         # using PiSSA. When quantizing, we expect a certain level of error. However, we expect the PiSSA quantized
@@ -1938,25 +1932,25 @@ class TestPiSSA:
         self.get_errors(bits=4, device=device, tmp_path=tmp_path)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_pissa_8bit(self, device, tmp_path):
         # Same test as test_bloomz_pissa_4bit but with 8 bits.
         self.get_errors(bits=8, device=device, tmp_path=tmp_path)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_t5_pissa_4bit(self, device, tmp_path):
         self.get_errors(bits=4, device=device, model_id="t5-small", tmp_path=tmp_path)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_t5_pissa_8bit(self, device, tmp_path):
         self.get_errors(bits=8, device=device, model_id="t5-small", tmp_path=tmp_path)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_gpt2_pissa_4bit(self, device, tmp_path):
         # see 2104
         self.get_errors(bits=4, device=device, model_id="gpt2", tmp_path=tmp_path)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_gpt2_pissa_8bit(self, device, tmp_path):
         # see 2104
         self.get_errors(bits=8, device=device, model_id="gpt2", tmp_path=tmp_path)
@@ -1968,7 +1962,7 @@ class TestPiSSA:
         import bitsandbytes as bnb

         torch.manual_seed(0)
-        data = torch.rand(10, 1000).to("cuda")
+        data = torch.rand(10, 1000).to(torch_device)

         class MyModule(torch.nn.Module):
             def __init__(self):
@@ -1983,7 +1977,7 @@ class TestPiSSA:
                 x_4d = x.flatten().reshape(1, 100, 10, 10)
                 return self.linear(x), self.embed(x_int), self.conv2d(x_4d)

-        model = MyModule().to("cuda")
+        model = MyModule().to(torch_device)
         output_base = model(data)[0]

         config = LoraConfig(init_lora_weights="pissa", target_modules=["linear"], r=8)
@@ -2047,7 +2041,7 @@ class TestPiSSA:
         assert not torch.allclose(output_finetuned_pissa, output_converted, atol=tol, rtol=tol)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
+@pytest.mark.skipif(not (torch.cuda.is_available() or is_xpu_available()), reason="test requires a GPU")
 @pytest.mark.single_gpu_tests
 class TestOLoRA:
     r"""
@@ -2103,8 +2097,7 @@ class TestOLoRA:
         qlora_model = qlora_model.merge_and_unload()
         qlora_error = self.nuclear_norm(base_model, qlora_model)
         del qlora_model
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

         # logits from quantized LoRA model using OLoRA
         lora_config = LoraConfig(
@@ -2123,8 +2116,7 @@ class TestOLoRA:
         olora_model.save_pretrained(tmp_path / "residual_model")

         del olora_model
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

         # now load quantized model and apply OLoRA-initialized weights on top
         qolora_model = self.quantize_model(
@@ -2134,8 +2126,7 @@ class TestOLoRA:
         qolora_model = qolora_model.merge_and_unload()
         qolora_error = self.nuclear_norm(base_model, qolora_model)
         del qolora_model
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

         assert qlora_error > 0.0
         assert qolora_error > 0.0
@@ -2143,7 +2134,7 @@ class TestOLoRA:
         # next, check that OLoRA quantization errors are smaller than LoRA errors by a certain margin
         assert qolora_error < (qlora_error / self.error_factor)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_olora_4bit(self, device, tmp_path):
         # In this test, we compare the logits of the base model, the quantized LoRA model, and the quantized model
         # using OLoRA. When quantizing, we expect a certain level of error. However, we expect the OLoRA quantized
@@ -2153,7 +2144,7 @@ class TestOLoRA:
         self.get_errors(bits=4, device=device, tmp_path=tmp_path)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_olora_8bit(self, device, tmp_path):
         # Same test as test_bloomz_olora_4bit but with 8 bits.
         self.get_errors(bits=8, device=device, tmp_path=tmp_path)
@@ -2248,8 +2239,7 @@ class TestLoftQ:
         logits_base = self.get_logits(model, inputs)
         # clean up
         del model
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

         # logits from the normal quantized LoRA model
         target_modules = "all-linear" if task_type != TaskType.SEQ_2_SEQ_LM else ["o", "k", "wi", "q", "v"]
@@ -2269,8 +2259,7 @@ class TestLoftQ:
         torch.manual_seed(0)
         logits_quantized = self.get_logits(quantized_model, inputs)
         del quantized_model
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

         # logits from quantized LoRA model using LoftQ
         loftq_config = LoftQConfig(loftq_bits=bits, loftq_iter=loftq_iter)
@@ -2282,11 +2271,11 @@ class TestLoftQ:
             target_modules=target_modules,
         )
         model = self.get_base_model(model_id, device)
-        if device == "cuda":
-            model = model.to("cuda")
+        if device != "cpu":
+            model = model.to(torch_device)
         loftq_model = get_peft_model(model, lora_config)
-        if device == "cuda":
-            loftq_model = loftq_model.to("cuda")
+        if device != "cpu":
+            loftq_model = loftq_model.to(torch_device)

         # save LoRA weights, they should be initialized such that they minimize the quantization error
         loftq_model.base_model.peft_config["default"].init_lora_weights = True
@@ -2296,8 +2285,7 @@ class TestLoftQ:
         loftq_model.save_pretrained(tmp_path / "base_model")

         del loftq_model
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

         # now load quantized model and apply LoftQ-initialized weights on top
         base_model = self.get_base_model(tmp_path / "base_model", device=None, **kwargs, torch_dtype=torch.float32)
@@ -2308,8 +2296,7 @@ class TestLoftQ:
         torch.manual_seed(0)
         logits_loftq = self.get_logits(loftq_model, inputs)
         del loftq_model
-        gc.collect()
-        torch.cuda.empty_cache()
+        clear_device_cache(garbage_collection=True)

         mae_quantized = torch.abs(logits_base - logits_quantized).mean()
         mse_quantized = torch.pow(logits_base - logits_quantized, 2).mean()
@@ -2317,7 +2304,7 @@ class TestLoftQ:
         mse_loftq = torch.pow(logits_base - logits_loftq, 2).mean()
         return mae_quantized, mse_quantized, mae_loftq, mse_loftq

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_loftq_4bit(self, device, tmp_path):
         # In this test, we compare the logits of the base model, the quantized LoRA model, and the quantized model
         # using LoftQ. When quantizing, we expect a certain level of error. However, we expect the LoftQ quantized
@@ -2336,7 +2323,7 @@ class TestLoftQ:
         assert mse_loftq < (mse_quantized / self.error_factor)
         assert mae_loftq < (mae_quantized / self.error_factor)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_loftq_4bit_iter_5(self, device, tmp_path):
         # Same test as the previous one but with 5 iterations. We should expect the error to be even smaller with more
         # iterations, but in practice the difference is not that large, at least not for this small base model.
@@ -2353,7 +2340,7 @@ class TestLoftQ:
         assert mse_loftq < (mse_quantized / self.error_factor)
         assert mae_loftq < (mae_quantized / self.error_factor)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_loftq_8bit(self, device, tmp_path):
         # Same test as test_bloomz_loftq_4bit but with 8 bits.
         mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=8, device=device, tmp_path=tmp_path)
@@ -2368,7 +2355,7 @@ class TestLoftQ:
         assert mse_loftq < (mse_quantized / self.error_factor)
         assert mae_loftq < (mae_quantized / self.error_factor)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_loftq_8bit_iter_5(self, device, tmp_path):
         # Same test as test_bloomz_loftq_4bit_iter_5 but with 8 bits.
         mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(
@@ -2385,7 +2372,7 @@ class TestLoftQ:
         assert mse_loftq < (mse_quantized / self.error_factor)
         assert mae_loftq < (mae_quantized / self.error_factor)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_t5_loftq_4bit(self, device, tmp_path):
         mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(
             bits=4, device=device, model_id="t5-small", tmp_path=tmp_path
@@ -2400,7 +2387,7 @@ class TestLoftQ:
         assert mse_loftq < (mse_quantized / self.error_factor)
         assert mae_loftq < (mae_quantized / self.error_factor)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_t5_loftq_8bit(self, device, tmp_path):
         mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(
             bits=8, device=device, model_id="t5-small", tmp_path=tmp_path
@@ -2416,7 +2403,7 @@ class TestLoftQ:
         assert mae_loftq < (mae_quantized / self.error_factor)

     @pytest.mark.xfail  # failing for now, but having DoRA pass is only a nice-to-have, not a must, so we're good
-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_loftq_4bit_dora(self, device, tmp_path):
         # same as test_bloomz_loftq_4bit but with DoRA
         mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(
@@ -2433,7 +2420,7 @@ class TestLoftQ:
         assert mae_loftq < (mae_quantized / factor)
         assert mse_loftq < (mse_quantized / factor)

-    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    @pytest.mark.parametrize("device", [torch_device, "cpu"])
     def test_bloomz_loftq_8bit_dora(self, device, tmp_path):
         # same as test_bloomz_loftq_8bit but with DoRA
         mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(
@@ -2462,7 +2449,7 @@ class TestLoftQ:
         """
         torch.manual_seed(0)
         model_id = "bigscience/bloomz-560m"
-        device = "cuda"
+        device = torch_device
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         inputs = tokenizer("The dog was", padding=True, return_tensors="pt").to(device)
@@ -2513,15 +2500,13 @@ class TestLoftQ:
         assert not all(logs)

         del model
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clear_device_cache(garbage_collection=True)

     def test_replace_lora_weights_with_local_model(self):
         # see issue 2020
         torch.manual_seed(0)
         model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
-        device = "cuda"
+        device = torch_device

         with tempfile.TemporaryDirectory() as tmp_dir:
             # save base model locally
@@ -2552,9 +2537,7 @@ class TestLoftQ:
             replace_lora_weights_loftq(model)

         del model
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clear_device_cache(garbage_collection=True)

     def test_config_no_loftq_init(self):
         with pytest.warns(
@@ -2600,6 +2583,8 @@ class MixedPrecisionTests(unittest.TestCase):
         gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
+        elif is_xpu_available():
+            torch.xpu.empty_cache()
         gc.collect()

     @pytest.mark.single_gpu_tests
@@ -3976,8 +3961,7 @@ class TestPTuningReproducibility:
         model.save_pretrained(tmp_path)

         del model
-        torch.cuda.empty_cache()
-        gc.collect()
+        clear_device_cache(garbage_collection=True)

         model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
         model = PeftModel.from_pretrained(model, tmp_path)
diff --git a/tests/test_loraplus.py b/tests/test_loraplus.py
index 1ba51fbd..64bb8bc3 100644
--- a/tests/test_loraplus.py
+++ b/tests/test_loraplus.py
@@ -19,7 +19,7 @@ from torch import nn
 from peft.import_utils import is_bnb_available
 from peft.optimizers import create_loraplus_optimizer

-from .testing_utils import require_bitsandbytes
+from .testing_utils import require_bitsandbytes, torch_device


 if is_bnb_available():
@@ -80,7 +80,7 @@ def test_lora_plus_optimizer_sucess():
         "betas": (0.9, 0.999),
         "loraplus_weight_decay": 0.0,
     }
-    model: SimpleNet = SimpleNet().cuda()
+    model: SimpleNet = SimpleNet().to(torch_device)
     optim = create_loraplus_optimizer(
         model=model,
         optimizer_cls=optimizer_cls,
@@ -91,9 +91,9 @@ def test_lora_plus_optimizer_sucess():
     )
     loss = torch.nn.CrossEntropyLoss()
     bnb.optim.GlobalOptimManager.get_instance().register_parameters(model.parameters())
-    x = torch.randint(100, (2, 4, 10)).cuda()
+    x = torch.randint(100, (2, 4, 10)).to(torch_device)
     output = model(x).permute(0, 3, 1, 2)
-    label = torch.randint(16, (2, 4, 10)).cuda()
+    label = torch.randint(16, (2, 4, 10)).to(torch_device)
     loss_value = loss(output, label)
     loss_value.backward()
     optim.step()
diff --git a/tests/test_tuners_utils.py b/tests/test_tuners_utils.py
index a8a74be1..2a536251 100644
--- a/tests/test_tuners_utils.py
+++ b/tests/test_tuners_utils.py
@@ -57,7 +57,7 @@ from peft.tuners.tuners_utils import (
 from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND, ModulesToSaveWrapper, infer_device
 from peft.utils.constants import DUMMY_MODEL_CONFIG, MIN_TARGET_MODULES_FOR_OPTIMIZATION

-from .testing_utils import require_bitsandbytes, require_non_cpu, require_torch_gpu
+from .testing_utils import require_bitsandbytes, require_non_cpu


 # Implements tests for regex matching logic common for all BaseTuner subclasses, and
@@ -271,7 +271,7 @@ class PeftCustomKwargsTester(unittest.TestCase):
     )

     @parameterized.expand(BNB_TEST_CASES)
-    @require_torch_gpu
+    @require_non_cpu
     @require_bitsandbytes
     def test_maybe_include_all_linear_layers_lora_bnb(
         self, model_id, model_type, initial_target_modules, expected_target_modules, quantization
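The test_loraplus.py hunks replace `.cuda()` calls with `.to(torch_device)`, where `torch_device` comes from tests/testing_utils.py. One plausible way such a constant is derived is from PEFT's `infer_device` helper, which is visible in the test_tuners_utils.py import above; treat this snippet as an assumption about the wiring, not the actual definition:

```python
# Hypothetical derivation of a module-level torch_device constant.
import torch
from peft.utils import infer_device

torch_device = infer_device()  # e.g. "cuda", "xpu", or "cpu"

# Device-agnostic tensor placement, as used by the updated tests:
x = torch.randint(100, (2, 4, 10)).to(torch_device)
label = torch.randint(16, (2, 4, 10)).to(torch_device)
```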
diff --git a/tests/test_vblora.py b/tests/test_vblora.py
index a676bf46..3db9778e 100644
--- a/tests/test_vblora.py
+++ b/tests/test_vblora.py
@@ -20,6 +20,7 @@ from safetensors import safe_open
 from torch import nn

 from peft import PeftModel, VBLoRAConfig, get_peft_model
+from peft.import_utils import is_xpu_available


 class MLP(nn.Module):
@@ -189,8 +190,9 @@ class TestVBLoRA:
     @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
     def test_vblora_dtypes(self, dtype):
         mlp = self.get_mlp()
-        if (dtype == torch.bfloat16) and not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()):
-            pytest.skip("bfloat16 not supported on this system, skipping the test")
+        if dtype == torch.bfloat16:
+            if not is_xpu_available() and not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()):
+                pytest.skip("bfloat16 not supported on this system, skipping the test")

         config = VBLoRAConfig(
             target_modules=["lin0", "lin1", "lin3"], vector_length=2, num_vectors=10, save_only_topk_weights=False
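The reworked dtype check above treats an XPU build as bf16-capable and otherwise falls back to `torch.cuda.is_bf16_supported()`. A compact restatement of that logic as a standalone sketch (the helper name is made up):

```python
# Sketch of the bfloat16 capability check mirrored from the VB-LoRA test above.
import torch
from peft.import_utils import is_xpu_available


def bf16_supported() -> bool:
    # XPU is assumed bf16-capable here, mirroring the test's logic; otherwise
    # require a CUDA device that reports bf16 support.
    if is_xpu_available():
        return True
    return torch.cuda.is_available() and torch.cuda.is_bf16_supported()


if __name__ == "__main__":
    dtype = torch.bfloat16 if bf16_supported() else torch.float32
    print(f"using dtype: {dtype}")
```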