Fix: loading DBRX back from saved path (#35728)

* fix dtype as dict for some models + add test

* add comment in tests
This commit is contained in:
Raushan Turganbay
2025-01-28 11:38:45 +01:00
committed by GitHub
parent 3613f568cd
commit b764c20b09
4 changed files with 13 additions and 4 deletions

View File

@ -466,13 +466,14 @@ class ModelUtilsTest(TestCasePlus):
def test_model_from_config_torch_dtype_composite(self):
"""
Test that from_pretrained works with torch_dtype being as a dict per each sub-config in composite config
Tiny-Llava has saved auto dtype as `torch.float32` for all modules.
"""
# should be able to set torch_dtype as a simple string and the model loads it correctly
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float32")
self.assertEqual(model.language_model.dtype, torch.float32)
self.assertEqual(model.vision_tower.dtype, torch.float32)
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float16")
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype=torch.float16)
self.assertEqual(model.language_model.dtype, torch.float16)
self.assertEqual(model.vision_tower.dtype, torch.float16)