TST Add missing configs to test_config.py (#2781)

The test_config.py tests were missing a few configs from recently added
PEFT methods. Those are now included. After adding those, it was
revealed that for C3A and trainable tokens, super().__post_init__() was
not being called. This is now done.
This commit is contained in:
Benjamin Bossan
2025-09-19 17:52:58 +02:00
committed by GitHub
parent 20a9829f76
commit b774fd901e
3 changed files with 28 additions and 4 deletions

View File

@ -123,6 +123,7 @@ class C3AConfig(PeftConfig):
)
def __post_init__(self):
super().__post_init__()
self.peft_type = PeftType.C3A
self.target_modules = (
set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules

View File

@ -85,4 +85,5 @@ class TrainableTokensConfig(PeftConfig):
)
def __post_init__(self):
super().__post_init__()
self.peft_type = PeftType.TRAINABLE_TOKENS

View File

@ -24,6 +24,8 @@ from peft import (
AdaLoraConfig,
AdaptionPromptConfig,
BOFTConfig,
BoneConfig,
C3AConfig,
FourierFTConfig,
HRAConfig,
IA3Config,
@ -31,6 +33,7 @@ from peft import (
LoHaConfig,
LoKrConfig,
LoraConfig,
MissConfig,
MultitaskPromptTuningConfig,
OFTConfig,
PeftConfig,
@ -41,9 +44,12 @@ from peft import (
PromptEncoderConfig,
PromptTuningConfig,
RoadConfig,
ShiraConfig,
TaskType,
TrainableTokensConfig,
VBLoRAConfig,
VeraConfig,
XLoraConfig,
)
@ -54,6 +60,8 @@ ALL_CONFIG_CLASSES = (
(AdaLoraConfig, {"total_step": 1}),
(AdaptionPromptConfig, {}),
(BOFTConfig, {}),
(BoneConfig, {}),
(C3AConfig, {}),
(FourierFTConfig, {}),
(HRAConfig, {}),
(IA3Config, {}),
@ -61,14 +69,18 @@ ALL_CONFIG_CLASSES = (
(LoHaConfig, {}),
(LoKrConfig, {}),
(LoraConfig, {}),
(MissConfig, {}),
(MultitaskPromptTuningConfig, {}),
(PolyConfig, {}),
(PrefixTuningConfig, {}),
(PromptEncoderConfig, {}),
(PromptTuningConfig, {}),
(RoadConfig, {}),
(ShiraConfig, {}),
(TrainableTokensConfig, {}),
(VeraConfig, {}),
(VBLoRAConfig, {}),
(XLoraConfig, {"hidden_size": 32, "adapters": {}}),
)
@ -399,8 +411,13 @@ class TestPeftConfig:
msg = f"Unexpected keyword arguments ['foobar', 'spam'] for class {config_class.__name__}, these are ignored."
config_from_pretrained = config_class.from_pretrained(tmp_path)
assert len(recwarn) == 1
assert recwarn.list[0].message.args[0].startswith(msg)
expected_num_warnings = 1
# TODO: remove once Bone is removed in v0.19.0
if config_class == BoneConfig:
expected_num_warnings = 2 # Bone has 1 more warning about it being deprecated
assert len(recwarn) == expected_num_warnings
assert recwarn.list[-1].message.args[0].startswith(msg)
assert "foo" not in config_from_pretrained.to_dict()
assert "spam" not in config_from_pretrained.to_dict()
assert config.to_dict() == config_from_pretrained.to_dict()
@ -429,8 +446,13 @@ class TestPeftConfig:
msg = f"Unexpected keyword arguments ['foobar', 'spam'] for class {config_class.__name__}, these are ignored."
config_from_pretrained = PeftConfig.from_pretrained(tmp_path) # <== use PeftConfig here
assert len(recwarn) == 1
assert recwarn.list[0].message.args[0].startswith(msg)
expected_num_warnings = 1
# TODO: remove once Bone is removed in v0.19.0
if config_class == BoneConfig:
expected_num_warnings = 2 # Bone has 1 more warning about it being deprecated
assert len(recwarn) == expected_num_warnings
assert recwarn.list[-1].message.args[0].startswith(msg)
assert "foo" not in config_from_pretrained.to_dict()
assert "spam" not in config_from_pretrained.to_dict()
assert config.to_dict() == config_from_pretrained.to_dict()