use context manager for path extension in torch.hub (#75786)

We are using the idiom

```py
sys.path.insert(0, path)

# do something

sys.path.remove(path)
```

three times in `torch.hub`. This is a textbook case for using a context manager. In addition, by using `try` / `finally` we can enforce the Python path is back in its original state even if the actual action raises an exception:

```py
import sys

path = "/tmp"

# PR
try:
    sys.path.insert(0, path)
    try:
        # Any exception raised while performing the actual functionality
        raise Exception
    finally:
        sys.path.remove(path)
except Exception:
    assert path not in sys.path

# main
try:
    sys.path.insert(0, path)

    # Any exception raised while performing the actual functionality
    raise Exception

    sys.path.remove(path)
except Exception:
    assert path in sys.path
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75786
Approved by: https://github.com/NicolasHug
This commit is contained in:
Philip Meier
2023-01-09 07:08:35 +00:00
committed by PyTorch MergeBot
parent d85f3c8237
commit fe80f190df

View File

@ -1,3 +1,4 @@
import contextlib
import errno
import hashlib
import json
@ -77,6 +78,15 @@ READ_DATA_CHUNK = 8192
_hub_dir = None
@contextlib.contextmanager
def _add_to_sys_path(path):
sys.path.insert(0, path)
try:
yield
finally:
sys.path.remove(path)
# Copied from tools/shared/module_loader to be included in torch package
def _import_module(name, path):
import importlib.util
@ -394,12 +404,9 @@ def list(github, force_reload=False, skip_validation=False, trust_repo=None):
repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "list", verbose=True,
skip_validation=skip_validation)
sys.path.insert(0, repo_dir)
hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
sys.path.remove(repo_dir)
with _add_to_sys_path(repo_dir):
hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
# We take functions starts with '_' as internal helper functions
entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')]
@ -447,12 +454,9 @@ def help(github, model, force_reload=False, skip_validation=False, trust_repo=No
repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "help", verbose=True,
skip_validation=skip_validation)
sys.path.insert(0, repo_dir)
hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
sys.path.remove(repo_dir)
with _add_to_sys_path(repo_dir):
hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
entry = _load_entry_from_hubconf(hub_module, model)
@ -564,15 +568,12 @@ def _load_local(hubconf_dir, model, *args, **kwargs):
>>> path = '/some/local/path/pytorch/vision'
>>> model = _load_local(path, 'resnet50', weights='ResNet50_Weights.IMAGENET1K_V1')
"""
sys.path.insert(0, hubconf_dir)
with _add_to_sys_path(hubconf_dir):
hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF)
hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF)
hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
entry = _load_entry_from_hubconf(hub_module, model)
model = entry(*args, **kwargs)
sys.path.remove(hubconf_dir)
entry = _load_entry_from_hubconf(hub_module, model)
model = entry(*args, **kwargs)
return model