!241 Fix the documentation for the rename of evaluation_strategy to eval_strategy
Merge pull request !241 from 幽若/master-0616
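In transformers 4.51.3 the `TrainingArguments` keyword is `eval_strategy`; the old `evaluation_strategy` spelling is what this MR removes from the docs. For example code that must run against both older and newer transformers releases, one option is to pick the keyword at runtime. A minimal, hedged sketch; the `make_training_args` helper is illustrative and not part of openMind:

```python
import inspect

from openmind import TrainingArguments


def make_training_args(output_dir: str, strategy: str = "epoch") -> TrainingArguments:
    """Build TrainingArguments with whichever evaluation-strategy keyword this version accepts."""
    params = inspect.signature(TrainingArguments.__init__).parameters
    key = "eval_strategy" if "eval_strategy" in params else "evaluation_strategy"
    return TrainingArguments(output_dir=output_dir, **{key: strategy})


training_args = make_training_args("test_trainer")
```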
@@ -55,7 +55,8 @@ small_eval_dataset = tokenized_datasets["validation"].shuffle(seed=42).select(ra
 from openmind import TrainingArguments, Trainer, metrics
 import numpy as np
 
-training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch")
+# In transformers 4.51.3, the evaluation_strategy parameter has been renamed to eval_strategy; see https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/training_args.py#L239
+training_args = TrainingArguments(output_dir="test_trainer", eval_strategy="epoch")
 
 def compute_metrics(eval_pred):
     logits, labels = eval_pred
@@ -259,7 +259,7 @@ openMind Library provides a `Trainer` class that implements the functionality needed to train a model.
     per_device_train_batch_size=4,
     per_device_eval_batch_size=4,
     num_train_epochs=10,
-    evaluation_strategy="epoch",
+    eval_strategy="epoch",
 )
 ```
 
@@ -15,6 +15,7 @@
 import importlib
 import importlib.metadata
 import sys
+from functools import wraps
 from typing import Any, Dict, List, Optional
 
 import torch
@@ -30,6 +31,7 @@ from transformers.utils import (
     is_torch_npu_available,
     is_torch_xpu_available,
 )
+from transformers.utils.quantization_config import QuantizationMethod
 
 
 def create_quantized_param_patch(
@@ -187,7 +189,6 @@ def is_bitsandbytes_available_patch():
 
 
 def validate_environment_patch(self, *args, **kwargs):
-
     if not is_accelerate_available():
         raise ImportError(
             f"Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`"
@@ -235,6 +236,57 @@ def validate_environment_patch(self, *args, **kwargs):
         )
 
 
+@wraps(torch.nn.Module.to)
+def to(self, *args, **kwargs):
+    # For BNB/GPTQ models, we prevent users from casting the model to another dtype to restrict unwanted behaviours.
+    # the correct API should be to load the model with the desired dtype directly through `from_pretrained`.
+    dtype_present_in_args = "dtype" in kwargs
+
+    if not dtype_present_in_args:
+        for arg in args:
+            if isinstance(arg, torch.dtype):
+                dtype_present_in_args = True
+                break
+
+    if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
+        raise ValueError("`.to` is not supported for HQQ-quantized models.")
+
+    if dtype_present_in_args and getattr(self, "quantization_method", None) == QuantizationMethod.QUARK:
+        raise ValueError("Casting a Quark quantized model to a new `dtype` is not supported.")
+
+    # Checks if the model has been loaded in 4-bit or 8-bit with BNB
+    if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+        if dtype_present_in_args:
+            raise ValueError(
+                "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
+                " desired `dtype` by passing the correct `torch_dtype` argument."
+            )
+
+        if getattr(self, "is_loaded_in_8bit", False):
+            raise ValueError(
+                "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
+                " model has already been set to the correct devices and casted to the correct `dtype`."
+            )
+        elif is_torch_npu_available():
+            if version.parse(importlib.metadata.version("bitsandbytes-npu-beta")) < version.parse("0.43.1"):
+                raise ValueError(
+                    "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+                )
+        elif version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
+            raise ValueError(
+                "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+            )
+    elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
+        if dtype_present_in_args:
+            raise ValueError(
+                "You cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired"
+                " `dtype` by passing the correct `torch_dtype` argument."
+            )
+    return super(PreTrainedModel, self).to(*args, **kwargs)
+
+
 def is_serializable_patch(self, safe_serialization=None):
     _is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes-npu-beta")) >= version.parse(
         "0.41.3"
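On Ascend NPU the added branch keys the version check on the `bitsandbytes-npu-beta` distribution rather than `bitsandbytes`. A standalone sketch of that check, assuming `version` is `packaging.version` as in transformers; the helper name is illustrative:

```python
import importlib.metadata

from packaging import version


def npu_bnb_supports_to(min_version: str = "0.43.1") -> bool:
    """Return True if the installed bitsandbytes-npu-beta release allows `.to()` on 4-bit models."""
    try:
        installed = importlib.metadata.version("bitsandbytes-npu-beta")
    except importlib.metadata.PackageNotFoundError:
        return False
    return version.parse(installed) >= version.parse(min_version)
```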
@@ -249,7 +301,6 @@ def is_serializable_patch(self, safe_serialization=None):
 
 
 def patch_bnb():
-
     for module in sys.modules.values():
         if hasattr(module, "is_bitsandbytes_available"):
             module.is_bitsandbytes_available = is_bitsandbytes_available_patch
@@ -267,3 +318,4 @@ def patch_bnb():
         validate_environment_patch,
     )
     setattr(transformers.quantizers.quantizer_bnb_4bit.Bnb4BitHfQuantizer, "is_serializable", is_serializable_patch)
+    setattr(transformers.modeling_utils.PreTrainedModel, "to", to)
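`patch_bnb()` installs the new `to` as a method of `PreTrainedModel` via `setattr`, so the `@wraps(torch.nn.Module.to)` wrapper replaces the inherited method for every loaded model. A toy sketch of the same pattern, independent of transformers and bitsandbytes; `ToyModel` and its `quantized` flag are illustrative only:

```python
from functools import wraps

import torch


class ToyModel(torch.nn.Module):
    """Stand-in for PreTrainedModel; the `quantized` flag is illustrative."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)
        self.quantized = True


@wraps(torch.nn.Module.to)
def guarded_to(self, *args, **kwargs):
    # Reject dtype casts on "quantized" models, mirroring the patched behaviour.
    dtype_requested = "dtype" in kwargs or any(isinstance(a, torch.dtype) for a in args)
    if getattr(self, "quantized", False) and dtype_requested:
        raise ValueError("Casting a quantized model to a new dtype is not supported.")
    # Everything else falls through to the original nn.Module.to.
    return super(ToyModel, self).to(*args, **kwargs)


# Register the override, as patch_bnb() does for PreTrainedModel.
setattr(ToyModel, "to", guarded_to)

model = ToyModel()
model.to("cpu")              # device moves still reach nn.Module.to
try:
    model.to(torch.float16)  # rejected by the guard
except ValueError as err:
    print(err)
```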