Fix torch.hub for new zipfile format. (#42333)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/42239

Pull Request resolved: https://github.com/pytorch/pytorch/pull/42333

Reviewed By: VitalyFedyunin

Differential Revision: D23215210

Pulled By: ailzhang

fbshipit-source-id: 161ead8b457c11655dd2cab5eecfd0edf7ae5c2b
This commit is contained in:
Ailing Zhang
2020-08-20 14:49:00 -07:00
committed by Facebook GitHub Bot
parent dae2973fae
commit 51bab0877d
2 changed files with 39 additions and 12 deletions

View File

@ -571,6 +571,18 @@ class TestHub(TestCase):
self.assertEqual(sum_of_state_dict(hub_model.state_dict()),
SUM_OF_HUB_EXAMPLE)
# Test the default zipfile serialization format produced by >=1.6 release.
@retry(URLError, tries=3, skip_after_retries=True)
def test_load_zip_1_6_checkpoint(self):
hub_model = hub.load(
'ailzhang/torchhub_example',
'mnist_zip_1_6',
pretrained=True,
verbose=False)
self.assertEqual(sum_of_state_dict(hub_model.state_dict()),
SUM_OF_HUB_EXAMPLE)
def test_hub_dir(self):
with tempfile.TemporaryDirectory('hub_dir') as dirname:
torch.hub.set_dir(dirname)