Mirror of https://github.com/huggingface/transformers.git, synced 2025-10-30 08:34:36 +08:00
Fix: loading DBRX back from saved path (#35728)
* fix dtype as dict for some models + add test
* add comment in tests
Committed by GitHub
parent 3613f568cd
commit b764c20b09
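What the fix and the new test exercise: for composite-config models, `from_pretrained` accepts `torch_dtype` as a dict keyed by sub-config, and a checkpoint saved with per-module dtypes must load back correctly. A minimal sketch of the dict form, assuming the Llava model used in the test below; the checkpoint path and the `text_config`/`vision_config`/`""` keys are illustrative assumptions, not taken from this commit:

import torch
from transformers import LlavaForConditionalGeneration

# Placeholder: path or hub id of a tiny Llava checkpoint (the test's TINY_LLAVA constant).
TINY_LLAVA = "path/to/tiny-llava"

# torch_dtype as a dict: one entry per sub-config, "" covering the remaining top-level weights.
model = LlavaForConditionalGeneration.from_pretrained(
    TINY_LLAVA,
    torch_dtype={"text_config": "float16", "vision_config": "float32", "": "float32"},
)
assert model.language_model.dtype == torch.float16
assert model.vision_tower.dtype == torch.float32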
@@ -466,13 +466,14 @@ class ModelUtilsTest(TestCasePlus):
     def test_model_from_config_torch_dtype_composite(self):
         """
         Test that from_pretrained works with torch_dtype being as a dict per each sub-config in composite config
         Tiny-Llava has saved auto dtype as `torch.float32` for all modules.
         """
         # should be able to set torch_dtype as a simple string and the model loads it correctly
         model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float32")
         self.assertEqual(model.language_model.dtype, torch.float32)
         self.assertEqual(model.vision_tower.dtype, torch.float32)

         model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float16")
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype=torch.float16)
         self.assertEqual(model.language_model.dtype, torch.float16)
         self.assertEqual(model.vision_tower.dtype, torch.float16)
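For the scenario in the commit title, loading a composite model back from a locally saved path, a minimal round-trip sketch (using the test's Llava model rather than DBRX; the checkpoint placeholder is hypothetical):

import tempfile

import torch
from transformers import LlavaForConditionalGeneration

TINY_LLAVA = "path/to/tiny-llava"  # placeholder for a tiny composite-config checkpoint

# Load with an explicit dtype, then save; the config written out carries the dtype info.
model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype=torch.float16)

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir)
    # Loading back from the saved path is what the fix targets: it must not fail when the
    # saved config stores dtype information per sub-config.
    reloaded = LlavaForConditionalGeneration.from_pretrained(tmp_dir, torch_dtype="auto")
    print(reloaded.language_model.dtype, reloaded.vision_tower.dtype)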