Consider hub_dir alongside TORCH_HOME env variable for storing hub models (#32844)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/31944
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32844

Differential Revision: D19747566

Pulled By: ailzhang

fbshipit-source-id: caca41a3a057d7d280d4783515aba2cc48c82012
This commit is contained in:
Edgar Andrés Margffoy Tuay
2020-02-05 15:32:37 -08:00
committed by Facebook Github Bot
parent 74ce3a032c
commit 1b746b95fb
2 changed files with 13 additions and 3 deletions

View File

@ -10,6 +10,7 @@ import torch
import torch.nn as nn
import torch.utils.data
import torch.cuda
from torch._six import PY2
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
import torch.hub as hub
from torch.autograd._functions.utils import check_onnx_broadcast
@ -562,6 +563,12 @@ class TestHub(TestCase):
self.assertEqual(sum_of_state_dict(hub_model.state_dict()),
SUM_OF_HUB_EXAMPLE)
@unittest.skipIf(PY2, "Requires python 3")
def test_hub_dir(self):
with tempfile.TemporaryDirectory('hub_dir') as dirname:
torch.hub.set_dir(dirname)
self.assertEqual(torch.hub._get_torch_home(), dirname)
class TestHipify(TestCase):
def test_import_hipify(self):

View File

@ -102,9 +102,12 @@ def _load_attr_from_module(module, func_name):
def _get_torch_home():
torch_home = os.path.expanduser(
os.getenv(ENV_TORCH_HOME,
os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch')))
torch_home = hub_dir
if torch_home is None:
torch_home = os.path.expanduser(
os.getenv(ENV_TORCH_HOME,
os.path.join(os.getenv(ENV_XDG_CACHE_HOME,
DEFAULT_CACHE_DIR), 'torch')))
return torch_home