mirror of
https://github.com/huggingface/peft.git
synced 2025-10-20 15:33:48 +08:00
Support dataclass model configs (#2778)
LeRobot uses dataclasses to manage policy configs. If we want to support LeRobot policy fine-tuning it'd be easiest to support these configs in `get_model_config`. While it is possible to fix this on LeRobot's side (add a to_dict implementation to the config classes) I think it'd be cleaner to support it on our side since the cost is relatively low and dataclasses are getting more popular anyway. Thanks @xliu0105 for raising this issue and proposing a fix.
This commit is contained in:
@ -14,6 +14,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import os
|
||||
import re
|
||||
import textwrap
|
||||
@ -879,6 +880,8 @@ class BaseTuner(nn.Module, ABC):
|
||||
model_config = getattr(model, "config", DUMMY_MODEL_CONFIG)
|
||||
if hasattr(model_config, "to_dict"):
|
||||
model_config = model_config.to_dict()
|
||||
elif dataclasses.is_dataclass(model_config):
|
||||
model_config = dataclasses.asdict(model_config)
|
||||
return model_config
|
||||
|
||||
def _get_tied_target_modules(self, model: nn.Module) -> list[str]:
|
||||
|
@ -14,6 +14,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import dataclasses
|
||||
import re
|
||||
import unittest
|
||||
from copy import deepcopy
|
||||
@ -1312,6 +1313,11 @@ class MockModelConfig:
|
||||
return self.config
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MockModelDataclassConfig:
|
||||
mock_key: str
|
||||
|
||||
|
||||
class ModelWithConfig(nn.Module):
|
||||
def __init__(self):
|
||||
self.config = MockModelConfig()
|
||||
@ -1322,6 +1328,11 @@ class ModelWithDictConfig(nn.Module):
|
||||
self.config = MockModelConfig.config
|
||||
|
||||
|
||||
class ModelWithDataclassConfig(nn.Module):
|
||||
def __init__(self):
|
||||
self.config = MockModelDataclassConfig(**MockModelConfig().to_dict())
|
||||
|
||||
|
||||
class ModelWithNoConfig(nn.Module):
|
||||
pass
|
||||
|
||||
@ -1339,6 +1350,10 @@ class TestBaseTunerGetModelConfig(unittest.TestCase):
|
||||
config = BaseTuner.get_model_config(ModelWithNoConfig())
|
||||
assert config == DUMMY_MODEL_CONFIG
|
||||
|
||||
def test_get_model_config_with_dataclass(self):
|
||||
config = BaseTuner.get_model_config(ModelWithDataclassConfig())
|
||||
assert config == MockModelConfig.config
|
||||
|
||||
|
||||
class TestBaseTunerWarnForTiedEmbeddings:
|
||||
model_id = "HuggingFaceH4/tiny-random-LlamaForCausalLM"
|
||||
|
Reference in New Issue
Block a user