!241 Fix documentation: the `evaluation_strategy` argument has been renamed to `eval_strategy`

Merge pull request !241 from 幽若/master-0616
committed by i-robot on 2025-06-17 08:17:53 +00:00
parent ea3a0e073f
commit d6d874b0ec
3 changed files with 57 additions and 4 deletions


@@ -55,7 +55,8 @@ small_eval_dataset = tokenized_datasets["validation"].shuffle(seed=42).select(ra
 from openmind import TrainingArguments, Trainer, metrics
 import numpy as np
-training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch")
+# In transformers 4.51.3 the `evaluation_strategy` argument has been renamed to `eval_strategy`; see https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/training_args.py#L239
+training_args = TrainingArguments(output_dir="test_trainer", eval_strategy="epoch")
 def compute_metrics(eval_pred):
     logits, labels = eval_pred
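
If the docs need to keep working across transformers releases that straddle the rename, the keyword can also be chosen at runtime. A minimal sketch, not part of this PR, assuming the cut-over happened around transformers 4.41 (verify against your pin); the `make_training_args` helper is hypothetical:

```python
# Hypothetical compatibility helper: pick whichever keyword the installed
# transformers version accepts (`eval_strategy` replaced `evaluation_strategy`).
import transformers
from packaging import version

def make_training_args(output_dir: str, strategy: str = "epoch"):
    new_api = version.parse(transformers.__version__) >= version.parse("4.41.0")
    key = "eval_strategy" if new_api else "evaluation_strategy"
    return transformers.TrainingArguments(output_dir=output_dir, **{key: strategy})

training_args = make_training_args("test_trainer")
```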


@@ -259,7 +259,7 @@ The openMind Library provides a `Trainer` class that implements the functionality needed to train a model.
     per_device_train_batch_size=4,
     per_device_eval_batch_size=4,
     num_train_epochs=10,
-    evaluation_strategy="epoch",
+    eval_strategy="epoch",
 )
 ```


@@ -15,6 +15,7 @@
 import importlib
 import importlib.metadata
 import sys
+from functools import wraps
 from typing import Any, Dict, List, Optional
 import torch
@@ -30,6 +31,7 @@ from transformers.utils import (
     is_torch_npu_available,
     is_torch_xpu_available,
 )
+from transformers.utils.quantization_config import QuantizationMethod


 def create_quantized_param_patch(
@@ -187,7 +189,6 @@ def is_bitsandbytes_available_patch():
 def validate_environment_patch(self, *args, **kwargs):
     if not is_accelerate_available():
         raise ImportError(
             f"Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
@@ -235,6 +236,57 @@ def validate_environment_patch(self, *args, **kwargs):
         )

+@wraps(torch.nn.Module.to)
+def to(self, *args, **kwargs):
+    # For BNB/GPTQ models, we prevent users from casting the model to another dtype to restrict unwanted behaviours.
+    # The correct API should be to load the model with the desired dtype directly through `from_pretrained`.
+    dtype_present_in_args = "dtype" in kwargs
+
+    if not dtype_present_in_args:
+        for arg in args:
+            if isinstance(arg, torch.dtype):
+                dtype_present_in_args = True
+                break
+
+    if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
+        raise ValueError("`.to` is not supported for HQQ-quantized models.")
+
+    if dtype_present_in_args and getattr(self, "quantization_method", None) == QuantizationMethod.QUARK:
+        raise ValueError("Casting a Quark quantized model to a new `dtype` is not supported.")
+
+    # Checks if the model has been loaded in 4-bit or 8-bit with BNB
+    if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+        if dtype_present_in_args:
+            raise ValueError(
+                "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
+                " desired `dtype` by passing the correct `torch_dtype` argument."
+            )
+
+        if getattr(self, "is_loaded_in_8bit", False):
+            raise ValueError(
+                "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
+                " model has already been set to the correct devices and casted to the correct `dtype`."
+            )
+        elif is_torch_npu_available():
+            if version.parse(importlib.metadata.version("bitsandbytes-npu-beta")) < version.parse("0.43.1"):
+                raise ValueError(
+                    "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+                )
+        elif version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
+            raise ValueError(
+                "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+            )
+    elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
+        if dtype_present_in_args:
+            raise ValueError(
+                "You cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired"
+                " `dtype` by passing the correct `torch_dtype` argument."
+            )
+
+    return super(PreTrainedModel, self).to(*args, **kwargs)

 def is_serializable_patch(self, safe_serialization=None):
     _is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes-npu-beta")) >= version.parse(
         "0.41.3"
@@ -249,7 +301,6 @@ def is_serializable_patch(self, safe_serialization=None):

 def patch_bnb():
     for module in sys.modules.values():
         if hasattr(module, "is_bitsandbytes_available"):
             module.is_bitsandbytes_available = is_bitsandbytes_available_patch
@@ -267,3 +318,4 @@ def patch_bnb():
         validate_environment_patch,
     )
     setattr(transformers.quantizers.quantizer_bnb_4bit.Bnb4BitHfQuantizer, "is_serializable", is_serializable_patch)
+    setattr(transformers.modeling_utils.PreTrainedModel, "to", to)
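
The final added line is the new registration: assigning the module-level `to` onto `PreTrainedModel` makes every model instance, existing or future, route through the guard. A minimal self-contained sketch of this `setattr`-style monkey-patching (the `Dummy` class and `to_patch` name are illustrative, not from this PR):

```python
import torch

class Dummy(torch.nn.Module):
    pass

def to_patch(self, *args, **kwargs):
    # Simplified version of the guard above: refuse dtype casts outright.
    if "dtype" in kwargs or any(isinstance(a, torch.dtype) for a in args):
        raise ValueError("dtype casts are disabled for this model")
    return super(Dummy, self).to(*args, **kwargs)

setattr(Dummy, "to", to_patch)

m = Dummy()
m.to("cpu")  # device moves still work
try:
    m.to(torch.float16)
except ValueError as err:
    print(err)  # dtype casts are disabled for this model
```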