mirror of
https://github.com/huggingface/peft.git
synced 2025-10-20 15:33:48 +08:00
TST Add missing configs to test_config.py (#2781)
The test_config.py tests were missing a few configs from recently added PEFT methods. Those are now included. After adding those, it was revealed that for C3A and trainable tokens, super().__post_init__() was not being called. This is now done.
This commit is contained in:
@ -123,6 +123,7 @@ class C3AConfig(PeftConfig):
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.peft_type = PeftType.C3A
|
||||
self.target_modules = (
|
||||
set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
|
||||
|
@ -85,4 +85,5 @@ class TrainableTokensConfig(PeftConfig):
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.peft_type = PeftType.TRAINABLE_TOKENS
|
||||
|
@ -24,6 +24,8 @@ from peft import (
|
||||
AdaLoraConfig,
|
||||
AdaptionPromptConfig,
|
||||
BOFTConfig,
|
||||
BoneConfig,
|
||||
C3AConfig,
|
||||
FourierFTConfig,
|
||||
HRAConfig,
|
||||
IA3Config,
|
||||
@ -31,6 +33,7 @@ from peft import (
|
||||
LoHaConfig,
|
||||
LoKrConfig,
|
||||
LoraConfig,
|
||||
MissConfig,
|
||||
MultitaskPromptTuningConfig,
|
||||
OFTConfig,
|
||||
PeftConfig,
|
||||
@ -41,9 +44,12 @@ from peft import (
|
||||
PromptEncoderConfig,
|
||||
PromptTuningConfig,
|
||||
RoadConfig,
|
||||
ShiraConfig,
|
||||
TaskType,
|
||||
TrainableTokensConfig,
|
||||
VBLoRAConfig,
|
||||
VeraConfig,
|
||||
XLoraConfig,
|
||||
)
|
||||
|
||||
|
||||
@ -54,6 +60,8 @@ ALL_CONFIG_CLASSES = (
|
||||
(AdaLoraConfig, {"total_step": 1}),
|
||||
(AdaptionPromptConfig, {}),
|
||||
(BOFTConfig, {}),
|
||||
(BoneConfig, {}),
|
||||
(C3AConfig, {}),
|
||||
(FourierFTConfig, {}),
|
||||
(HRAConfig, {}),
|
||||
(IA3Config, {}),
|
||||
@ -61,14 +69,18 @@ ALL_CONFIG_CLASSES = (
|
||||
(LoHaConfig, {}),
|
||||
(LoKrConfig, {}),
|
||||
(LoraConfig, {}),
|
||||
(MissConfig, {}),
|
||||
(MultitaskPromptTuningConfig, {}),
|
||||
(PolyConfig, {}),
|
||||
(PrefixTuningConfig, {}),
|
||||
(PromptEncoderConfig, {}),
|
||||
(PromptTuningConfig, {}),
|
||||
(RoadConfig, {}),
|
||||
(ShiraConfig, {}),
|
||||
(TrainableTokensConfig, {}),
|
||||
(VeraConfig, {}),
|
||||
(VBLoRAConfig, {}),
|
||||
(XLoraConfig, {"hidden_size": 32, "adapters": {}}),
|
||||
)
|
||||
|
||||
|
||||
@ -399,8 +411,13 @@ class TestPeftConfig:
|
||||
msg = f"Unexpected keyword arguments ['foobar', 'spam'] for class {config_class.__name__}, these are ignored."
|
||||
config_from_pretrained = config_class.from_pretrained(tmp_path)
|
||||
|
||||
assert len(recwarn) == 1
|
||||
assert recwarn.list[0].message.args[0].startswith(msg)
|
||||
expected_num_warnings = 1
|
||||
# TODO: remove once Bone is removed in v0.19.0
|
||||
if config_class == BoneConfig:
|
||||
expected_num_warnings = 2 # Bone has 1 more warning about it being deprecated
|
||||
|
||||
assert len(recwarn) == expected_num_warnings
|
||||
assert recwarn.list[-1].message.args[0].startswith(msg)
|
||||
assert "foo" not in config_from_pretrained.to_dict()
|
||||
assert "spam" not in config_from_pretrained.to_dict()
|
||||
assert config.to_dict() == config_from_pretrained.to_dict()
|
||||
@ -429,8 +446,13 @@ class TestPeftConfig:
|
||||
msg = f"Unexpected keyword arguments ['foobar', 'spam'] for class {config_class.__name__}, these are ignored."
|
||||
config_from_pretrained = PeftConfig.from_pretrained(tmp_path) # <== use PeftConfig here
|
||||
|
||||
assert len(recwarn) == 1
|
||||
assert recwarn.list[0].message.args[0].startswith(msg)
|
||||
expected_num_warnings = 1
|
||||
# TODO: remove once Bone is removed in v0.19.0
|
||||
if config_class == BoneConfig:
|
||||
expected_num_warnings = 2 # Bone has 1 more warning about it being deprecated
|
||||
|
||||
assert len(recwarn) == expected_num_warnings
|
||||
assert recwarn.list[-1].message.args[0].startswith(msg)
|
||||
assert "foo" not in config_from_pretrained.to_dict()
|
||||
assert "spam" not in config_from_pretrained.to_dict()
|
||||
assert config.to_dict() == config_from_pretrained.to_dict()
|
||||
|
Reference in New Issue
Block a user