Fix failures when default is flipped for weights_only (#127627)

Tests on XLA shard not fixed yet but there is an issue here https://github.com/pytorch/xla/issues/7799

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127627
Approved by: https://github.com/albanD
ghstack dependencies: #132349
This commit is contained in:
Mikayla Gawarecki
2024-08-15 19:48:35 +00:00
committed by PyTorch MergeBot
parent c8ad5e37e8
commit d9576c9440
22 changed files with 135 additions and 78 deletions

View File

@ -657,7 +657,8 @@ class QuantizationTestCase(TestCase):
b = io.BytesIO()
torch.save(model_dict, b)
b.seek(0)
loaded_dict = torch.load(b)
# weights_only=False as we sometimes get a ScriptObect here (weird)
loaded_dict = torch.load(b, weights_only=False)
loaded_model.load_state_dict(loaded_dict)
ref_out = ref_model(*x)
load_out = loaded_model(*x)
@ -674,7 +675,8 @@ class QuantizationTestCase(TestCase):
b = io.BytesIO()
torch.save(ref_model, b)
b.seek(0)
loaded = torch.load(b)
# weights_only=False as this is legacy code that saves the model
loaded = torch.load(b, weights_only=False)
load_out = loaded(*x)
check_outputs(ref_out, load_out)