From c81363bd4e6b4f1f44ccdf9d86f16cdcf3f48683 Mon Sep 17 00:00:00 2001 From: githubnemo Date: Mon, 8 Sep 2025 13:35:47 +0200 Subject: [PATCH] Support dataclass model configs (#2778) LeRobot uses dataclasses to manage policy configs. If we want to support LeRobot policy fine-tuning it'd be easiest to support these configs in `get_model_config`. While it is possible to fix this on LeRobot's side (add a to_dict implementation to the config classes) I think it'd be cleaner to support it on our side since the cost is relatively low and dataclasses are getting more popular anyway. Thanks @xliu0105 for raising this issue and proposing a fix. --- src/peft/tuners/tuners_utils.py | 3 +++ tests/test_tuners_utils.py | 15 +++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index 62e2c515..94b27285 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -14,6 +14,7 @@ from __future__ import annotations import copy +import dataclasses import os import re import textwrap @@ -879,6 +880,8 @@ class BaseTuner(nn.Module, ABC): model_config = getattr(model, "config", DUMMY_MODEL_CONFIG) if hasattr(model_config, "to_dict"): model_config = model_config.to_dict() + elif dataclasses.is_dataclass(model_config): + model_config = dataclasses.asdict(model_config) return model_config def _get_tied_target_modules(self, model: nn.Module) -> list[str]: diff --git a/tests/test_tuners_utils.py b/tests/test_tuners_utils.py index 6a32a590..441ecff0 100644 --- a/tests/test_tuners_utils.py +++ b/tests/test_tuners_utils.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses import re import unittest from copy import deepcopy @@ -1312,6 +1313,11 @@ class MockModelConfig: return self.config +@dataclasses.dataclass +class MockModelDataclassConfig: + mock_key: str + + class ModelWithConfig(nn.Module): def __init__(self): self.config = MockModelConfig() @@ -1322,6 +1328,11 @@ class ModelWithDictConfig(nn.Module): self.config = MockModelConfig.config +class ModelWithDataclassConfig(nn.Module): + def __init__(self): + self.config = MockModelDataclassConfig(**MockModelConfig().to_dict()) + + class ModelWithNoConfig(nn.Module): pass @@ -1339,6 +1350,10 @@ class TestBaseTunerGetModelConfig(unittest.TestCase): config = BaseTuner.get_model_config(ModelWithNoConfig()) assert config == DUMMY_MODEL_CONFIG + def test_get_model_config_with_dataclass(self): + config = BaseTuner.get_model_config(ModelWithDataclassConfig()) + assert config == MockModelConfig.config + class TestBaseTunerWarnForTiedEmbeddings: model_id = "HuggingFaceH4/tiny-random-LlamaForCausalLM"