mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
use context manager for path extension in torch.hub (#75786)
We are using the idiom
```py
sys.path.insert(0, path)
# do something
sys.path.remove(path)
```
three times in `torch.hub`. This is a textbook case for using a context manager. In addition, by using `try` / `finally` we can enforce the Python path is back in its original state even if the actual action raises an exception:
```py
import sys
path = "/tmp"
# PR
try:
sys.path.insert(0, path)
try:
# Any exception raised while performing the actual functionality
raise Exception
finally:
sys.path.remove(path)
except Exception:
assert path not in sys.path
# main
try:
sys.path.insert(0, path)
# Any exception raised while performing the actual functionality
raise Exception
sys.path.remove(path)
except Exception:
assert path in sys.path
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75786
Approved by: https://github.com/NicolasHug
This commit is contained in:
committed by
PyTorch MergeBot
parent
d85f3c8237
commit
fe80f190df
41
torch/hub.py
41
torch/hub.py
@ -1,3 +1,4 @@
|
||||
import contextlib
|
||||
import errno
|
||||
import hashlib
|
||||
import json
|
||||
@ -77,6 +78,15 @@ READ_DATA_CHUNK = 8192
|
||||
_hub_dir = None
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _add_to_sys_path(path):
|
||||
sys.path.insert(0, path)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
sys.path.remove(path)
|
||||
|
||||
|
||||
# Copied from tools/shared/module_loader to be included in torch package
|
||||
def _import_module(name, path):
|
||||
import importlib.util
|
||||
@ -394,12 +404,9 @@ def list(github, force_reload=False, skip_validation=False, trust_repo=None):
|
||||
repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "list", verbose=True,
|
||||
skip_validation=skip_validation)
|
||||
|
||||
sys.path.insert(0, repo_dir)
|
||||
|
||||
hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
|
||||
hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
|
||||
|
||||
sys.path.remove(repo_dir)
|
||||
with _add_to_sys_path(repo_dir):
|
||||
hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
|
||||
hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
|
||||
|
||||
# We take functions starts with '_' as internal helper functions
|
||||
entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')]
|
||||
@ -447,12 +454,9 @@ def help(github, model, force_reload=False, skip_validation=False, trust_repo=No
|
||||
repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "help", verbose=True,
|
||||
skip_validation=skip_validation)
|
||||
|
||||
sys.path.insert(0, repo_dir)
|
||||
|
||||
hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
|
||||
hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
|
||||
|
||||
sys.path.remove(repo_dir)
|
||||
with _add_to_sys_path(repo_dir):
|
||||
hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
|
||||
hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
|
||||
|
||||
entry = _load_entry_from_hubconf(hub_module, model)
|
||||
|
||||
@ -564,15 +568,12 @@ def _load_local(hubconf_dir, model, *args, **kwargs):
|
||||
>>> path = '/some/local/path/pytorch/vision'
|
||||
>>> model = _load_local(path, 'resnet50', weights='ResNet50_Weights.IMAGENET1K_V1')
|
||||
"""
|
||||
sys.path.insert(0, hubconf_dir)
|
||||
with _add_to_sys_path(hubconf_dir):
|
||||
hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF)
|
||||
hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
|
||||
|
||||
hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF)
|
||||
hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
|
||||
|
||||
entry = _load_entry_from_hubconf(hub_module, model)
|
||||
model = entry(*args, **kwargs)
|
||||
|
||||
sys.path.remove(hubconf_dir)
|
||||
entry = _load_entry_from_hubconf(hub_module, model)
|
||||
model = entry(*args, **kwargs)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user