Support dataclass model configs (#2778)

LeRobot uses dataclasses to manage policy configs. If we want to
support LeRobot policy fine-tuning it'd be easiest to support
these configs in `get_model_config`.

While it is possible to fix this on LeRobot's side (add a to_dict implementation to the config classes) I think it'd be cleaner to support it on our side since the cost is relatively low and dataclasses are getting more popular anyway.

Thanks @xliu0105 for raising this issue and proposing a fix.
This commit is contained in:
githubnemo
2025-09-08 13:35:47 +02:00
committed by GitHub
parent 5d97453235
commit c81363bd4e
2 changed files with 18 additions and 0 deletions

View File

@ -14,6 +14,7 @@
from __future__ import annotations
import copy
import dataclasses
import os
import re
import textwrap
@ -879,6 +880,8 @@ class BaseTuner(nn.Module, ABC):
model_config = getattr(model, "config", DUMMY_MODEL_CONFIG)
if hasattr(model_config, "to_dict"):
model_config = model_config.to_dict()
elif dataclasses.is_dataclass(model_config):
model_config = dataclasses.asdict(model_config)
return model_config
def _get_tied_target_modules(self, model: nn.Module) -> list[str]:

View File

@ -14,6 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import re
import unittest
from copy import deepcopy
@ -1312,6 +1313,11 @@ class MockModelConfig:
return self.config
@dataclasses.dataclass
class MockModelDataclassConfig:
mock_key: str
class ModelWithConfig(nn.Module):
def __init__(self):
self.config = MockModelConfig()
@ -1322,6 +1328,11 @@ class ModelWithDictConfig(nn.Module):
self.config = MockModelConfig.config
class ModelWithDataclassConfig(nn.Module):
def __init__(self):
self.config = MockModelDataclassConfig(**MockModelConfig().to_dict())
class ModelWithNoConfig(nn.Module):
pass
@ -1339,6 +1350,10 @@ class TestBaseTunerGetModelConfig(unittest.TestCase):
config = BaseTuner.get_model_config(ModelWithNoConfig())
assert config == DUMMY_MODEL_CONFIG
def test_get_model_config_with_dataclass(self):
config = BaseTuner.get_model_config(ModelWithDataclassConfig())
assert config == MockModelConfig.config
class TestBaseTunerWarnForTiedEmbeddings:
model_id = "HuggingFaceH4/tiny-random-LlamaForCausalLM"