mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18598 ghimport-source-id: c74597e5e7437e94a43c163cee0639b20d0d0c6a Stack from [ghstack](https://github.com/ezyang/ghstack): * **#18598 Turn on F401: Unused import warning.** This was requested by someone at Facebook; this lint is turned on for Facebook by default. "Sure, why not." I had to noqa a number of imports in __init__. Hypothetically we're supposed to use __all__ in this case, but I was too lazy to fix it. Left for future work. Be careful! flake8-2 and flake8-3 behave differently with respect to import resolution for # type: comments. flake8-3 will report an import unused; flake8-2 will not. For now, I just noqa'd all these sites. All the changes were done by hand. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Differential Revision: D14687478 fbshipit-source-id: 30d532381e914091aadfa0d2a5a89404819663e3
171 lines
5.3 KiB
Python
import importlib
|
|
import os
|
|
import shutil
|
|
import sys
|
|
import zipfile
|
|
|
|
if sys.version_info[0] == 2:
|
|
from urlparse import urlparse
|
|
from urllib2 import urlopen # noqa f811
|
|
else:
|
|
from urllib.request import urlopen
|
|
from urllib.parse import urlparse # noqa: F401
|
|
|
|
# Branch used by ``load`` when the ``github`` argument names no tag/branch.
MASTER_BRANCH = 'master'

# Environment variable consulted for the hub directory when ``set_dir`` was not called.
ENV_TORCH_HUB_DIR = 'TORCH_HUB_DIR'
# Fallback hub directory when neither ``set_dir`` nor the environment variable is used.
DEFAULT_TORCH_HUB_DIR = '~/.torch/hub'

# Number of bytes read per chunk while downloading an archive.
READ_DATA_CHUNK = 8 * 1024

# Directory where downloaded repos are stored; ``None`` until resolved.
hub_dir = None
|
|
|
|
|
|
def _check_module_exists(name):
|
|
if sys.version_info >= (3, 4):
|
|
import importlib.util
|
|
return importlib.util.find_spec(name) is not None
|
|
elif sys.version_info >= (3, 3):
|
|
# Special case for python3.3
|
|
import importlib.find_loader
|
|
return importlib.find_loader(name) is not None
|
|
else:
|
|
# NB: imp doesn't handle hierarchical module names (names contains dots).
|
|
try:
|
|
import imp
|
|
imp.find_module(name)
|
|
except Exception:
|
|
return False
|
|
return True
|
|
|
|
|
|
def _remove_if_exists(path):
|
|
if os.path.exists(path):
|
|
if os.path.isfile(path):
|
|
os.remove(path)
|
|
else:
|
|
shutil.rmtree(path)
|
|
|
|
|
|
def _git_archive_link(repo, branch):
|
|
return 'https://github.com/' + repo + '/archive/' + branch + '.zip'
|
|
|
|
|
|
def _download_url_to_file(url, filename):
    """Download ``url`` and write its body to the local path ``filename``.

    A progress line is written to stderr. The response is always closed, and
    a partially written ``filename`` is removed if the download fails.

    Args:
        url: URL passed to ``urlopen``.
        filename: destination path; opened in ``'wb'`` mode, so any existing
            file is overwritten.

    Raises:
        Whatever ``urlopen`` or the read/write raises, re-raised after cleanup.
    """
    sys.stderr.write('Downloading: \"{}\" to {}\n'.format(url, filename))
    response = urlopen(url)
    try:
        with open(filename, 'wb') as f:
            # Stream in READ_DATA_CHUNK-sized pieces so the archive is never
            # held in memory as a whole.
            shutil.copyfileobj(response, f, READ_DATA_CHUNK)
    except BaseException:
        # Don't leave a truncated archive behind for the caller to unzip.
        if os.path.exists(filename):
            os.remove(filename)
        raise
    finally:
        response.close()
|
|
|
|
|
|
def _load_attr_from_module(module_name, func_name):
|
|
m = importlib.import_module(module_name)
|
|
# Check if callable is defined in the module
|
|
if func_name not in dir(m):
|
|
return None
|
|
return getattr(m, func_name)
|
|
|
|
|
|
def set_dir(d):
    r"""
    Set the local directory where downloaded models & weights are stored.

    If this is never called, ``load`` uses the directory named by the
    environment variable `TORCH_HUB_DIR` when it is set, and otherwise falls
    back to `~/.torch/hub`; the chosen directory is created if it is missing.

    Args:
        d: path to a local folder to save downloaded models & weights.
    """
    # Rebind the module-level setting that load() reads.
    global hub_dir
    hub_dir = d
|
|
|
|
|
|
def _setup_hub_dir():
    """Resolve the module-level ``hub_dir`` (expanding ``~``) and create it if missing."""
    global hub_dir
    if hub_dir is None:
        hub_dir = os.getenv(ENV_TORCH_HUB_DIR, DEFAULT_TORCH_HUB_DIR)
    # expanduser only rewrites a leading '~' and is a no-op otherwise.
    hub_dir = os.path.expanduser(hub_dir)
    if not os.path.exists(hub_dir):
        os.makedirs(hub_dir)


def _parse_repo_info(github):
    """Split ``"repo_owner/repo_name[:tag_name]"`` into ``(repo_info, repo_name, branch)``."""
    try:
        if ':' in github:
            repo_info, branch = github.split(':')
        else:
            repo_info, branch = github, MASTER_BRANCH
        _repo_owner, repo_name = repo_info.split('/')
    except ValueError:
        raise ValueError('Invalid github argument {!r}: expected '
                         '"repo_owner/repo_name[:tag_name]"'.format(github))
    return repo_info, repo_name, branch


def _get_cache_or_reload(repo_info, repo_name, branch, force_reload):
    """Return the local directory holding ``repo_info`` at ``branch``, downloading it unless cached."""
    url = _git_archive_link(repo_info, branch)
    # Branch names may contain '/', which must not create nested paths locally.
    safe_branch = branch.replace('/', '_')
    cached_file = os.path.join(hub_dir, safe_branch + '.zip')
    repo_dir = os.path.join(hub_dir, repo_name + '_' + safe_branch)

    if not force_reload and os.path.exists(repo_dir):
        sys.stderr.write('Using cache found in {}\n'.format(repo_dir))
        return repo_dir

    _remove_if_exists(cached_file)
    _download_url_to_file(url, cached_file)

    # The context manager closes the archive before it is deleted below.
    with zipfile.ZipFile(cached_file) as cached_zipfile:
        # Github uses '{repo_name}-{branch_name}' as the top-level folder name
        # (and e.g. renames repo-v1.x.x to repo-1.x.x), which is not
        # importable, so read the real name from the archive's first entry.
        extracted_repo_name = cached_zipfile.infolist()[0].filename
        extracted_repo = os.path.join(hub_dir, extracted_repo_name)
        _remove_if_exists(extracted_repo)
        cached_zipfile.extractall(hub_dir)

    _remove_if_exists(cached_file)
    _remove_if_exists(repo_dir)
    shutil.move(extracted_repo, repo_dir)  # rename the repo
    return repo_dir


def load(github, model, force_reload=False, *args, **kwargs):
    r"""
    Load a model from a github repo, with pretrained weights.

    Args:
        github: Required, a string with format "repo_owner/repo_name[:tag_name]" with an optional
            tag/branch. The default branch is `master` if not specified.
            Example: 'pytorch/vision[:hub]'
        model: Required, a string of entrypoint name defined in repo's hubconf.py
        force_reload: Optional, whether to discard the existing cache and force a fresh download.
            Default is `False`.
        *args: Optional, the corresponding args for callable `model`.
        **kwargs: Optional, the corresponding kwargs for callable `model`.

    Returns:
        Whatever the `model` entrypoint returns when called with `*args` and
        `**kwargs` (a single model with corresponding pretrained weights).

    Raises:
        ValueError: if `model` is not a string or `github` is malformed.
        RuntimeError: if hubconf's `dependencies` are not installed, or if
            `model` is missing from hubconf or is not callable.
    """
    if not isinstance(model, str):
        raise ValueError('Invalid input: model should be a string of function name')

    # Setup hub_dir to save downloaded files
    _setup_hub_dir()

    repo_info, repo_name, branch = _parse_repo_info(github)
    repo_dir = _get_cache_or_reload(repo_info, repo_name, branch, force_reload)

    # Make the Python interpreter aware of the repo, keeping a single entry
    # at the front so this repo's hubconf.py shadows any other on sys.path.
    if repo_dir in sys.path:
        sys.path.remove(repo_dir)
    sys.path.insert(0, repo_dir)
    # Drop any 'hubconf' cached by an earlier load() (possibly of another
    # repo) so the import below reads this repo's hubconf.py.
    sys.modules.pop('hubconf', None)

    dependencies = _load_attr_from_module('hubconf', 'dependencies')
    if dependencies is not None:
        missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)]
        if missing_deps:
            raise RuntimeError('Missing dependencies: {}'.format(', '.join(missing_deps)))

    func = _load_attr_from_module('hubconf', model)
    if func is None:
        raise RuntimeError('Cannot find callable {} in hubconf'.format(model))

    # Check if func is callable
    if not callable(func):
        raise RuntimeError('{} is not callable'.format(func))

    # Call the function
    return func(*args, **kwargs)
|