mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-02 14:34:54 +08:00
Fix torch.hub for new zipfile format. (#42333)
Summary: Fixes https://github.com/pytorch/pytorch/issues/42239 Pull Request resolved: https://github.com/pytorch/pytorch/pull/42333 Reviewed By: VitalyFedyunin Differential Revision: D23215210 Pulled By: ailzhang fbshipit-source-id: 161ead8b457c11655dd2cab5eecfd0edf7ae5c2b
This commit is contained in:
committed by
Facebook GitHub Bot
parent
dae2973fae
commit
51bab0877d
@ -571,6 +571,18 @@ class TestHub(TestCase):
|
||||
self.assertEqual(sum_of_state_dict(hub_model.state_dict()),
|
||||
SUM_OF_HUB_EXAMPLE)
|
||||
|
||||
# Test the default zipfile serialization format produced by >=1.6 release.
|
||||
@retry(URLError, tries=3, skip_after_retries=True)
|
||||
def test_load_zip_1_6_checkpoint(self):
|
||||
hub_model = hub.load(
|
||||
'ailzhang/torchhub_example',
|
||||
'mnist_zip_1_6',
|
||||
pretrained=True,
|
||||
verbose=False)
|
||||
self.assertEqual(sum_of_state_dict(hub_model.state_dict()),
|
||||
SUM_OF_HUB_EXAMPLE)
|
||||
|
||||
|
||||
def test_hub_dir(self):
|
||||
with tempfile.TemporaryDirectory('hub_dir') as dirname:
|
||||
torch.hub.set_dir(dirname)
|
||||
|
||||
Reference in New Issue
Block a user