Use hub_online_once in trainable token tests (#2701)

Also fix a minor import nit where `TrainableTokensWrapper` was not
added to `utils/__init__.py`; the corresponding imports are fixed as well.

Another housekeeping change: move hub_online_once to testing_utils.py, since it is now
used in many places and testing_utils.py is the better place to keep such utilities.
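
For reference, a minimal sketch of the resulting usage in a test (the test name and model id
are placeholders, not taken from the test suite; the import path matches the new location):

```
# Hedged sketch: "some-org/some-model" and the test name are illustrative only.
from transformers import AutoModelForCausalLM

from .testing_utils import hub_online_once


def test_model_loads_with_hub_caching():
    model_id = "some-org/some-model"
    # The first test that requests this model id goes online; later uses of the
    # same id run with HF_HUB_OFFLINE patched to avoid re-contacting the hub.
    with hub_online_once(model_id):
        model = AutoModelForCausalLM.from_pretrained(model_id)
    assert model is not None
```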
githubnemo
2025-08-05 12:58:55 +02:00
committed by GitHub
parent ff12d13be6
commit 44f001c695
10 changed files with 89 additions and 72 deletions

View File

@@ -35,6 +35,7 @@ from .other import (
WEIGHTS_NAME,
AuxiliaryTrainingWrapper,
ModulesToSaveWrapper,
TrainableTokensWrapper,
_freeze_adapter,
_get_batch_size,
_get_input_embeddings_name,
@@ -82,6 +83,7 @@ __all__ = [
"ModulesToSaveWrapper",
"PeftType",
"TaskType",
"TrainableTokensWrapper",
"_freeze_adapter",
"_get_batch_size",
"_get_input_embeddings_name",

View File

@@ -48,8 +48,8 @@ from peft import (
get_peft_model,
)
from .testing_common import PeftCommonTester, hub_online_once
from .testing_utils import device_count, load_dataset_english_quotes, set_init_weights_false
from .testing_common import PeftCommonTester
from .testing_utils import device_count, hub_online_once, load_dataset_english_quotes, set_init_weights_false
PEFT_DECODER_MODELS_TO_TEST = [

View File

@@ -20,7 +20,7 @@ from transformers import AutoModelForCausalLM
from peft import AutoPeftModelForCausalLM, BoneConfig, LoraConfig, PeftConfig, PeftModel, TaskType, get_peft_model
from .testing_common import hub_online_once
from .testing_utils import hub_online_once
PEFT_MODELS_TO_TEST = [("peft-internal-testing/test-lora-subfolder", "test")]

View File

@@ -34,7 +34,7 @@ from peft import (
from peft.tuners import lora
from peft.utils import ModulesToSaveWrapper
from .testing_common import hub_online_once
from .testing_utils import hub_online_once
class DummyModel(torch.nn.Module):

View File

@@ -37,7 +37,8 @@ from peft import (
)
from peft.utils.other import ModulesToSaveWrapper
from .testing_common import PeftCommonTester, hub_online_once
from .testing_common import PeftCommonTester
from .testing_utils import hub_online_once
PEFT_SEQ_CLS_MODELS_TO_TEST = [

View File

@@ -18,8 +18,8 @@ from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model
from .testing_common import PeftCommonTester, hub_online_once
from .testing_utils import set_init_weights_false
from .testing_common import PeftCommonTester
from .testing_utils import hub_online_once, set_init_weights_false
PEFT_DECODER_MODELS_TO_TEST = [

View File

@@ -23,8 +23,9 @@ from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
from peft import AutoPeftModel, LoraConfig, PeftModel, TrainableTokensConfig, get_peft_model
from peft.tuners.trainable_tokens.layer import TrainableTokensLayer
from peft.utils import get_peft_model_state_dict
from peft.utils.other import TrainableTokensWrapper
from peft.utils import TrainableTokensWrapper, get_peft_model_state_dict
from .testing_utils import hub_online_once
class ModelEmb(torch.nn.Module):
@@ -103,6 +104,9 @@ class TestTrainableTokens:
@pytest.fixture
def model(self, model_id):
with hub_online_once(model_id):
# This must not be a yield fixture so that we don't carry the hub_online_once
# behavior over to the rest of the test that uses this fixture
return AutoModelForCausalLM.from_pretrained(model_id)
@pytest.fixture

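For contrast, a hedged sketch (not part of the diff; the fixture name is made up) of the yield-fixture
anti-pattern that the comment above guards against: with `yield`, the context only exits after the test
body, so the `HF_HUB_OFFLINE` patching would leak into everything the test does.

```
import pytest
from transformers import AutoModelForCausalLM

from .testing_utils import hub_online_once


# Anti-pattern sketch: do NOT write the fixture like this.
@pytest.fixture
def model_yielding(model_id):
    with hub_online_once(model_id):
        yield AutoModelForCausalLM.from_pretrained(model_id)
        # the context manager only exits here, after the test body has run,
        # so the offline patching applies to every hub access in the test
```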
View File

@@ -58,8 +58,7 @@ from peft.tuners.tuners_utils import (
from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND, ModulesToSaveWrapper, infer_device
from peft.utils.constants import DUMMY_MODEL_CONFIG, MIN_TARGET_MODULES_FOR_OPTIMIZATION
from .testing_common import hub_online_once
from .testing_utils import require_bitsandbytes, require_non_cpu
from .testing_utils import hub_online_once, require_bitsandbytes, require_non_cpu
# Implements tests for regex matching logic common for all BaseTuner subclasses, and

View File

@@ -20,9 +20,7 @@ import re
import shutil
import tempfile
import warnings
from contextlib import contextmanager
from dataclasses import replace
from unittest import mock
import pytest
import torch
@@ -61,14 +59,17 @@ from peft import (
)
from peft.tuners.lora import LoraLayer
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import _get_submodules, infer_device
from peft.utils.other import AuxiliaryTrainingWrapper, ModulesToSaveWrapper, TrainableTokensWrapper
from peft.utils import (
AuxiliaryTrainingWrapper,
ModulesToSaveWrapper,
TrainableTokensWrapper,
_get_submodules,
infer_device,
)
from .testing_utils import get_state_dict
from .testing_utils import get_state_dict, hub_online_once
HUB_MODEL_ACCESSES = {}
CONFIG_TESTING_KWARGS = (
# IA³
{
@@ -189,59 +190,6 @@ CLASSES_MAPPING = {
DECODER_MODELS_EXTRA = {"cpt": (CPTConfig, CONFIG_TESTING_KWARGS[15])}
@contextmanager
def hub_online_once(model_id: str):
"""Set env[HF_HUB_OFFLINE]=1 (and patch transformers/hugging_face_hub to think that it was always that way)
for model ids that were seen already so that the hub is not contacted twice for the same model id in said context.
The cache (`HUB_MODEL_ACCESSES`) also tracks the number of cache hits per model id.
The reason for doing a context manager and not patching specific methods (e.g., `from_pretrained`) is that there
are a lot of places (`PeftConfig.from_pretrained`, `get_peft_state_dict`, `load_adapter`, ...) that possibly
communicate with the hub to download files / check versions / etc.
Note that using this context manager can cause problems when used in code sections that access different resources.
Example:
```
def test_something(model_id, config_kwargs):
with hub_online_once(model_id):
model = ...from_pretrained(model_id)
self.do_something_specific_with_model(model)
```
It is assumed that `do_something_specific_with_model` is an abstract method that is implemented by several tests.
Imagine the first test simply does `model.generate([1,2,3])`. The second call from another test suite however uses
a tokenizer (`AutoTokenizer.from_pretrained(model_id)`) - this will fail since the first pass was online but didn't
use the tokenizer and we're now in offline mode and cannot fetch the tokenizer. The recommended workaround is to
extend the cache key (`model_id` passed to `hub_online_once` in this case) by something in case the tokenizer is
used, so that these tests don't share a cache pool with the tests that don't use a tokenizer.
"""
global HUB_MODEL_ACCESSES
override = {}
try:
if model_id in HUB_MODEL_ACCESSES:
override = {"HF_HUB_OFFLINE": "1"}
HUB_MODEL_ACCESSES[model_id] += 1
else:
if model_id not in HUB_MODEL_ACCESSES:
HUB_MODEL_ACCESSES[model_id] = 0
with (
# strictly speaking it is not necessary to set the environment variable since most code that's out there
# is evaluating it at import time and we'd have to reload the modules for it to take effect. It's
# probably still a good idea to have it if there's some dynamic code that checks it.
mock.patch.dict(os.environ, override),
mock.patch("huggingface_hub.constants.HF_HUB_OFFLINE", override.get("HF_HUB_OFFLINE", False) == "1"),
mock.patch("transformers.utils.hub._is_offline_mode", override.get("HF_HUB_OFFLINE", False) == "1"),
):
yield
except Exception:
# in case of an error we have to assume that we didn't access the model properly from the hub
# for the first time, so the next call cannot be considered cached.
if HUB_MODEL_ACCESSES.get(model_id) == 0:
del HUB_MODEL_ACCESSES[model_id]
raise
class PeftCommonTester:
r"""
A large testing suite for testing common functionality of the PEFT models.

View File

@@ -11,9 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from contextlib import contextmanager
from functools import lru_cache, wraps
from unittest import mock
import numpy as np
import pytest
@@ -41,6 +43,10 @@ from peft.import_utils import (
)
# Globally shared model cache used by `hub_online_once`.
_HUB_MODEL_ACCESSES = {}
torch_device, device_count, memory_allocated_func = get_backend()
@@ -241,3 +247,60 @@ def set_init_weights_false(config_cls, kwargs):
else:
kwargs["init_weights"] = False
return kwargs
@contextmanager
def hub_online_once(model_id: str):
"""Set env[HF_HUB_OFFLINE]=1 (and patch transformers/hugging_face_hub to think that it was always that way)
for model ids that were already to avoid contacting the hub twice for the same model id in the context. The global
variable `_HUB_MODEL_ACCESSES` tracks the number of hits per model id between `hub_online_once` calls.
The reason for doing a context manager and not patching specific methods (e.g., `from_pretrained`) is that there
are a lot of places (`PeftConfig.from_pretrained`, `get_peft_state_dict`, `load_adapter`, ...) that possibly
communicate with the hub to download files / check versions / etc.
Note that using this context manager can cause problems when used in code sections that access different resources.
Example:
```
def test_something(model_id, config_kwargs):
with hub_online_once(model_id):
model = ...from_pretrained(model_id)
self.do_something_specific_with_model(model)
```
It is assumed that `do_something_specific_with_model` is an abstract method that is implemented by several tests.
Imagine the first test simply does `model.generate([1,2,3])`. The second call from another test suite however uses
a tokenizer (`AutoTokenizer.from_pretrained(model_id)`) - this will fail since the first pass was online but didn't
use the tokenizer and we're now in offline mode and cannot fetch the tokenizer. The recommended workaround is to
extend the cache key (`model_id` passed to `hub_online_once` in this case) by something in case the tokenizer is
used, so that these tests don't share a cache pool with the tests that don't use a tokenizer.
It is best to avoid using this context manager in *yield* fixtures (normal fixtures are fine) as this is equivalent
to wrapping the whole test in the context manager without explicitly writing it out, leading to unexpected
`HF_HUB_OFFLINE` behavior in the test body.
"""
global _HUB_MODEL_ACCESSES
override = {}
try:
if model_id in _HUB_MODEL_ACCESSES:
override = {"HF_HUB_OFFLINE": "1"}
_HUB_MODEL_ACCESSES[model_id] += 1
else:
if model_id not in _HUB_MODEL_ACCESSES:
_HUB_MODEL_ACCESSES[model_id] = 0
with (
# strictly speaking it is not necessary to set the environment variable since most code that's out there
# is evaluating it at import time and we'd have to reload the modules for it to take effect. It's
# probably still a good idea to have it if there's some dynamic code that checks it.
mock.patch.dict(os.environ, override),
mock.patch("huggingface_hub.constants.HF_HUB_OFFLINE", override.get("HF_HUB_OFFLINE", False) == "1"),
mock.patch("transformers.utils.hub._is_offline_mode", override.get("HF_HUB_OFFLINE", False) == "1"),
):
yield
except Exception:
# in case of an error we have to assume that we didn't access the model properly from the hub
# for the first time, so the next call cannot be considered cached.
if _HUB_MODEL_ACCESSES.get(model_id) == 0:
del _HUB_MODEL_ACCESSES[model_id]
raise
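
Finally, a hedged sketch of the cache-key workaround described in the docstring (the test name and the
`"-with-tokenizer"` suffix are arbitrary, illustrative choices): extending the key keeps tokenizer-using
tests out of the cache pool of tests that only load the model, so their first pass still goes online for
the tokenizer files.

```
from transformers import AutoModelForCausalLM, AutoTokenizer

from .testing_utils import hub_online_once


def test_generate_with_tokenizer(model_id):
    # distinct cache key: the first (online) pass also fetches the tokenizer,
    # so later offline passes find everything in the local cache
    with hub_online_once(model_id + "-with-tokenizer"):
        model = AutoModelForCausalLM.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
    inputs = tokenizer("hello world", return_tensors="pt")
    model.generate(**inputs)
```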