diff --git a/src/peft/utils/__init__.py b/src/peft/utils/__init__.py index 3163a879..fe004356 100644 --- a/src/peft/utils/__init__.py +++ b/src/peft/utils/__init__.py @@ -35,6 +35,7 @@ from .other import ( WEIGHTS_NAME, AuxiliaryTrainingWrapper, ModulesToSaveWrapper, + TrainableTokensWrapper, _freeze_adapter, _get_batch_size, _get_input_embeddings_name, @@ -82,6 +83,7 @@ __all__ = [ "ModulesToSaveWrapper", "PeftType", "TaskType", + "TrainableTokensWrapper", "_freeze_adapter", "_get_batch_size", "_get_input_embeddings_name", diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py index b8352818..0ff338cf 100644 --- a/tests/test_decoder_models.py +++ b/tests/test_decoder_models.py @@ -48,8 +48,8 @@ from peft import ( get_peft_model, ) -from .testing_common import PeftCommonTester, hub_online_once -from .testing_utils import device_count, load_dataset_english_quotes, set_init_weights_false +from .testing_common import PeftCommonTester +from .testing_utils import device_count, hub_online_once, load_dataset_english_quotes, set_init_weights_false PEFT_DECODER_MODELS_TO_TEST = [ diff --git a/tests/test_hub_features.py b/tests/test_hub_features.py index 257c487f..f705167c 100644 --- a/tests/test_hub_features.py +++ b/tests/test_hub_features.py @@ -20,7 +20,7 @@ from transformers import AutoModelForCausalLM from peft import AutoPeftModelForCausalLM, BoneConfig, LoraConfig, PeftConfig, PeftModel, TaskType, get_peft_model -from .testing_common import hub_online_once +from .testing_utils import hub_online_once PEFT_MODELS_TO_TEST = [("peft-internal-testing/test-lora-subfolder", "test")] diff --git a/tests/test_low_level_api.py b/tests/test_low_level_api.py index e2cf6532..7f4347b0 100644 --- a/tests/test_low_level_api.py +++ b/tests/test_low_level_api.py @@ -34,7 +34,7 @@ from peft import ( from peft.tuners import lora from peft.utils import ModulesToSaveWrapper -from .testing_common import hub_online_once +from .testing_utils import hub_online_once class DummyModel(torch.nn.Module): diff --git a/tests/test_seq_classifier.py b/tests/test_seq_classifier.py index d946dcce..ad3815f3 100644 --- a/tests/test_seq_classifier.py +++ b/tests/test_seq_classifier.py @@ -37,7 +37,8 @@ from peft import ( ) from peft.utils.other import ModulesToSaveWrapper -from .testing_common import PeftCommonTester, hub_online_once +from .testing_common import PeftCommonTester +from .testing_utils import hub_online_once PEFT_SEQ_CLS_MODELS_TO_TEST = [ diff --git a/tests/test_target_parameters.py b/tests/test_target_parameters.py index c01937f1..70b90ca7 100644 --- a/tests/test_target_parameters.py +++ b/tests/test_target_parameters.py @@ -18,8 +18,8 @@ from transformers import AutoModelForCausalLM from peft import LoraConfig, TaskType, get_peft_model -from .testing_common import PeftCommonTester, hub_online_once -from .testing_utils import set_init_weights_false +from .testing_common import PeftCommonTester +from .testing_utils import hub_online_once, set_init_weights_false PEFT_DECODER_MODELS_TO_TEST = [ diff --git a/tests/test_trainable_tokens.py b/tests/test_trainable_tokens.py index 15f4a41a..38b32b06 100644 --- a/tests/test_trainable_tokens.py +++ b/tests/test_trainable_tokens.py @@ -23,8 +23,9 @@ from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokeni from peft import AutoPeftModel, LoraConfig, PeftModel, TrainableTokensConfig, get_peft_model from peft.tuners.trainable_tokens.layer import TrainableTokensLayer -from peft.utils import get_peft_model_state_dict -from peft.utils.other 
import TrainableTokensWrapper +from peft.utils import TrainableTokensWrapper, get_peft_model_state_dict + +from .testing_utils import hub_online_once class ModelEmb(torch.nn.Module): @@ -103,7 +104,10 @@ class TestTrainableTokens: @pytest.fixture def model(self, model_id): - return AutoModelForCausalLM.from_pretrained(model_id) + with hub_online_once(model_id): + # This must not be a yield fixture so that we don't carry the hub_online_once + # behavior over to the rest of the test that uses this fixture + return AutoModelForCausalLM.from_pretrained(model_id) @pytest.fixture def tokenizer(self, model_id): diff --git a/tests/test_tuners_utils.py b/tests/test_tuners_utils.py index ee6be105..6a32a590 100644 --- a/tests/test_tuners_utils.py +++ b/tests/test_tuners_utils.py @@ -58,8 +58,7 @@ from peft.tuners.tuners_utils import ( from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND, ModulesToSaveWrapper, infer_device from peft.utils.constants import DUMMY_MODEL_CONFIG, MIN_TARGET_MODULES_FOR_OPTIMIZATION -from .testing_common import hub_online_once -from .testing_utils import require_bitsandbytes, require_non_cpu +from .testing_utils import hub_online_once, require_bitsandbytes, require_non_cpu # Implements tests for regex matching logic common for all BaseTuner subclasses, and diff --git a/tests/testing_common.py b/tests/testing_common.py index e0b7d497..9e5c9582 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -20,9 +20,7 @@ import re import shutil import tempfile import warnings -from contextlib import contextmanager from dataclasses import replace -from unittest import mock import pytest import torch @@ -61,14 +59,17 @@ from peft import ( ) from peft.tuners.lora import LoraLayer from peft.tuners.tuners_utils import BaseTunerLayer -from peft.utils import _get_submodules, infer_device -from peft.utils.other import AuxiliaryTrainingWrapper, ModulesToSaveWrapper, TrainableTokensWrapper +from peft.utils import ( + AuxiliaryTrainingWrapper, + ModulesToSaveWrapper, + TrainableTokensWrapper, + _get_submodules, + infer_device, +) -from .testing_utils import get_state_dict +from .testing_utils import get_state_dict, hub_online_once -HUB_MODEL_ACCESSES = {} - CONFIG_TESTING_KWARGS = ( # IA³ { @@ -189,59 +190,6 @@ CLASSES_MAPPING = { DECODER_MODELS_EXTRA = {"cpt": (CPTConfig, CONFIG_TESTING_KWARGS[15])} -@contextmanager -def hub_online_once(model_id: str): - """Set env[HF_HUB_OFFLINE]=1 (and patch transformers/hugging_face_hub to think that it was always that way) - for model ids that were seen already so that the hub is not contacted twice for the same model id in said context. - The cache (`HUB_MODEL_ACCESSES`) also tracks the number of cache hits per model id. - - The reason for doing a context manager and not patching specific methods (e.g., `from_pretrained`) is that there - are a lot of places (`PeftConfig.from_pretrained`, `get_peft_state_dict`, `load_adapter`, ...) that possibly - communicate with the hub to download files / check versions / etc. - - Note that using this context manager can cause problems when used in code sections that access different resources. - Example: - - ``` - def test_something(model_id, config_kwargs): - with hub_online_once(model_id): - model = ...from_pretrained(model_id) - self.do_something_specific_with_model(model) - ``` - It is assumed that `do_something_specific_with_model` is an absract method that is implement by several tests. - Imagine the first test simply does `model.generate([1,2,3])`. 
The second call from another test suite however uses
-    a tokenizer (`AutoTokenizer.from_pretrained(model_id)`) - this will fail since the first pass was online but didn't
-    use the tokenizer and we're now in offline mode and cannot fetch the tokenizer. The recommended workaround is to
-    extend the cache key (`model_id` passed to `hub_online_once` in this case) by something in case the tokenizer is
-    used, so that these tests don't share a cache pool with the tests that don't use a tokenizer.
-    """
-    global HUB_MODEL_ACCESSES
-    override = {}
-
-    try:
-        if model_id in HUB_MODEL_ACCESSES:
-            override = {"HF_HUB_OFFLINE": "1"}
-            HUB_MODEL_ACCESSES[model_id] += 1
-        else:
-            if model_id not in HUB_MODEL_ACCESSES:
-                HUB_MODEL_ACCESSES[model_id] = 0
-        with (
-            # strictly speaking it is not necessary to set the environment variable since most code that's out there
-            # is evaluating it at import time and we'd have to reload the modules for it to take effect. It's
-            # probably still a good idea to have it if there's some dynamic code that checks it.
-            mock.patch.dict(os.environ, override),
-            mock.patch("huggingface_hub.constants.HF_HUB_OFFLINE", override.get("HF_HUB_OFFLINE", False) == "1"),
-            mock.patch("transformers.utils.hub._is_offline_mode", override.get("HF_HUB_OFFLINE", False) == "1"),
-        ):
-            yield
-    except Exception:
-        # in case of an error we have to assume that we didn't access the model properly from the hub
-        # for the first time, so the next call cannot be considered cached.
-        if HUB_MODEL_ACCESSES.get(model_id) == 0:
-            del HUB_MODEL_ACCESSES[model_id]
-        raise
-
-
 class PeftCommonTester:
     r"""
     A large testing suite for testing common functionality of the PEFT models.
diff --git a/tests/testing_utils.py b/tests/testing_utils.py
index 7aeb5320..1ed0df90 100644
--- a/tests/testing_utils.py
+++ b/tests/testing_utils.py
@@ -11,9 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import unittest
 from contextlib import contextmanager
 from functools import lru_cache, wraps
+from unittest import mock
 
 import numpy as np
 import pytest
@@ -41,6 +43,10 @@ from peft.import_utils import (
 )
 
 
+# Globally shared model cache used by `hub_online_once`.
+_HUB_MODEL_ACCESSES = {}
+
+
 torch_device, device_count, memory_allocated_func = get_backend()
 
 
@@ -241,3 +247,60 @@ def set_init_weights_false(config_cls, kwargs):
     else:
         kwargs["init_weights"] = False
     return kwargs
+
+
+@contextmanager
+def hub_online_once(model_id: str):
+    """Set env[HF_HUB_OFFLINE]=1 (and patch transformers/huggingface_hub to think that it was always that way)
+    for model ids that were already seen to avoid contacting the hub twice for the same model id in the context. The
+    global variable `_HUB_MODEL_ACCESSES` tracks the number of cache hits per model id across `hub_online_once` calls.
+
+    The reason for doing a context manager and not patching specific methods (e.g., `from_pretrained`) is that there
+    are a lot of places (`PeftConfig.from_pretrained`, `get_peft_state_dict`, `load_adapter`, ...) that possibly
+    communicate with the hub to download files / check versions / etc.
+
+    Note that using this context manager can cause problems when used in code sections that access different resources.
+    Example:
+
+    ```
+    def test_something(model_id, config_kwargs):
+        with hub_online_once(model_id):
+            model = ...from_pretrained(model_id)
+            self.do_something_specific_with_model(model)
+    ```
+    It is assumed that `do_something_specific_with_model` is an abstract method that is implemented by several tests.
+    Imagine the first test simply does `model.generate([1,2,3])`. The second call from another test suite however uses
+    a tokenizer (`AutoTokenizer.from_pretrained(model_id)`) - this will fail since the first pass was online but didn't
+    use the tokenizer and we're now in offline mode and cannot fetch the tokenizer. The recommended workaround is to
+    extend the cache key (`model_id` passed to `hub_online_once` in this case) by something in case the tokenizer is
+    used, so that these tests don't share a cache pool with the tests that don't use a tokenizer.
+
+    It is best to avoid using this context manager in *yield* fixtures (normal fixtures are fine) as this is equivalent
+    to wrapping the whole test in the context manager without explicitly writing it out, leading to unexpected
+    `HF_HUB_OFFLINE` behavior in the test body.
+    """
+    global _HUB_MODEL_ACCESSES
+    override = {}
+
+    try:
+        if model_id in _HUB_MODEL_ACCESSES:
+            override = {"HF_HUB_OFFLINE": "1"}
+            _HUB_MODEL_ACCESSES[model_id] += 1
+        else:
+            if model_id not in _HUB_MODEL_ACCESSES:
+                _HUB_MODEL_ACCESSES[model_id] = 0
+        with (
+            # strictly speaking it is not necessary to set the environment variable since most code that's out there
+            # is evaluating it at import time and we'd have to reload the modules for it to take effect. It's
+            # probably still a good idea to have it if there's some dynamic code that checks it.
+            mock.patch.dict(os.environ, override),
+            mock.patch("huggingface_hub.constants.HF_HUB_OFFLINE", override.get("HF_HUB_OFFLINE", False) == "1"),
+            mock.patch("transformers.utils.hub._is_offline_mode", override.get("HF_HUB_OFFLINE", False) == "1"),
+        ):
+            yield
+    except Exception:
+        # in case of an error we have to assume that we didn't access the model properly from the hub
+        # for the first time, so the next call cannot be considered cached.
+        if _HUB_MODEL_ACCESSES.get(model_id) == 0:
+            del _HUB_MODEL_ACCESSES[model_id]
+        raise
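
A minimal usage sketch of the cache-key workaround described in the docstring above; the test name and the `::tokenizer` suffix are hypothetical choices, any suffix that keeps tokenizer-using tests out of the model-only cache pool would do:

```
from transformers import AutoModelForCausalLM, AutoTokenizer

from .testing_utils import hub_online_once


def test_generate_with_tokenizer(model_id):
    # Extend the cache key so this test does not share a cache pool with tests that
    # only load the model; otherwise, once a model-only test has cached `model_id`,
    # this test would run offline and the tokenizer files could not be fetched.
    with hub_online_once(model_id + "::tokenizer"):
        model = AutoModelForCausalLM.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        inputs = tokenizer("some prompt", return_tensors="pt")
        model.generate(**inputs)
```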