mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Consider hub_dir alongside TORCH_HOME env variable for storing hub models (#32844)
Summary: Fixes https://github.com/pytorch/pytorch/issues/31944 Pull Request resolved: https://github.com/pytorch/pytorch/pull/32844 Differential Revision: D19747566 Pulled By: ailzhang fbshipit-source-id: caca41a3a057d7d280d4783515aba2cc48c82012
This commit is contained in:
committed by
Facebook Github Bot
parent
74ce3a032c
commit
1b746b95fb
@ -10,6 +10,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.data
|
||||
import torch.cuda
|
||||
from torch._six import PY2
|
||||
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
|
||||
import torch.hub as hub
|
||||
from torch.autograd._functions.utils import check_onnx_broadcast
|
||||
@ -562,6 +563,12 @@ class TestHub(TestCase):
|
||||
self.assertEqual(sum_of_state_dict(hub_model.state_dict()),
|
||||
SUM_OF_HUB_EXAMPLE)
|
||||
|
||||
@unittest.skipIf(PY2, "Requires python 3")
|
||||
def test_hub_dir(self):
|
||||
with tempfile.TemporaryDirectory('hub_dir') as dirname:
|
||||
torch.hub.set_dir(dirname)
|
||||
self.assertEqual(torch.hub._get_torch_home(), dirname)
|
||||
|
||||
|
||||
class TestHipify(TestCase):
|
||||
def test_import_hipify(self):
|
||||
|
@ -102,9 +102,12 @@ def _load_attr_from_module(module, func_name):
|
||||
|
||||
|
||||
def _get_torch_home():
|
||||
torch_home = os.path.expanduser(
|
||||
os.getenv(ENV_TORCH_HOME,
|
||||
os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch')))
|
||||
torch_home = hub_dir
|
||||
if torch_home is None:
|
||||
torch_home = os.path.expanduser(
|
||||
os.getenv(ENV_TORCH_HOME,
|
||||
os.path.join(os.getenv(ENV_XDG_CACHE_HOME,
|
||||
DEFAULT_CACHE_DIR), 'torch')))
|
||||
return torch_home
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user