mirror of
https://github.com/huggingface/peft.git
synced 2025-10-20 15:33:48 +08:00
Check if PEFT triggers transformers FutureWarning or DeprecationWarning by converting these warnings into failures.
87 lines
3.6 KiB
Python
87 lines
3.6 KiB
Python
# Copyright 2023-present the HuggingFace Inc. team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import logging
|
|
import platform
|
|
import re
|
|
|
|
import pytest
|
|
|
|
|
|
def pytest_addoption(parser):
|
|
parser.addoption("--regression", action="store_true", default=False, help="run regression tests")
|
|
|
|
|
|
def pytest_configure(config):
|
|
config.addinivalue_line("markers", "regression: mark regression tests")
|
|
|
|
# Errors from transformers deprecations
|
|
logger = logging.getLogger("transformers")
|
|
|
|
class ErrorOnDeprecation(logging.Handler):
|
|
def emit(self, record):
|
|
msg = record.getMessage().lower()
|
|
if "deprecat" in msg or "future" in msg:
|
|
if "torch_dtype" not in msg:
|
|
# let's ignore the torch_dtype => dtype deprecation for now
|
|
raise AssertionError(f"**Transformers Deprecation**: {msg}")
|
|
|
|
# Add our handler
|
|
handler = ErrorOnDeprecation()
|
|
logger.addHandler(handler)
|
|
logger.setLevel(logging.WARNING)
|
|
|
|
|
|
def pytest_collection_modifyitems(config, items):
|
|
if config.getoption("--regression"):
|
|
return
|
|
|
|
skip_regression = pytest.mark.skip(reason="need --regression option to run regression tests")
|
|
for item in items:
|
|
if "regression" in item.keywords:
|
|
item.add_marker(skip_regression)
|
|
|
|
|
|
# TODO: remove this once support for PyTorch 2.2 (the latest one still supported by GitHub MacOS x86_64 runners) is
|
|
# dropped, or if MacOS is removed from the test matrix, see https://github.com/huggingface/peft/issues/2431.
|
|
# Note: the function name is fixed by the pytest plugin system, don't change it
|
|
@pytest.hookimpl(hookwrapper=True)
|
|
def pytest_runtest_makereport(item, call):
|
|
"""
|
|
Plug into the pytest test report generation to skip a specific MacOS failure caused by transformers.
|
|
|
|
The error was introduced by https://github.com/huggingface/transformers/pull/37785, which results in torch.load
|
|
failing when using torch < 2.6.
|
|
|
|
Since the MacOS x86 runners need to use an older torch version, those steps are necessary to get the CI green.
|
|
"""
|
|
outcome = yield
|
|
rep = outcome.get_result()
|
|
# ref:
|
|
# https://github.com/huggingface/transformers/blob/858ce6879a4aa7fa76a7c4e2ac20388e087ace26/src/transformers/utils/import_utils.py#L1418
|
|
error_msg = re.compile(r"Due to a serious vulnerability issue in `torch.load`")
|
|
|
|
# notes:
|
|
# - pytest uses hard-coded strings, we cannot import and use constants
|
|
# https://docs.pytest.org/en/stable/reference/reference.html#pytest.TestReport
|
|
# - errors can happen during call (running the test) but also setup (e.g. in fixtures)
|
|
if rep.failed and (rep.when in ("setup", "call")) and (platform.system() == "Darwin"):
|
|
exc_msg = str(call.excinfo.value)
|
|
if error_msg.search(exc_msg):
|
|
# turn this failure into an xfail:
|
|
rep.outcome = "skipped"
|
|
# for this attribute, see:
|
|
# https://github.com/pytest-dev/pytest/blob/bd6877e5874b50ee57d0f63b342a67298ee9a1c3/src/_pytest/reports.py#L266C5-L266C13
|
|
rep.wasxfail = "Error known to occur on MacOS with older torch versions, won't be fixed"
|