[codemod][lint][fbcode/c*] Enable BLACK by default
Test Plan: manual inspection & sandcastle

Reviewed By: zertosh

Differential Revision: D30279364

fbshipit-source-id: c1ed77dfe43a3bde358f92737cd5535ae5d13c9a
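As a minimal sketch of what this codemod applies (the exact fbcode invocation is not shown in this commit), Black can be driven programmatically through its public string-formatting API, black.format_str with a black.Mode configuration. The sample input line is taken from the diff below; everything else in the snippet is illustrative only, and it assumes the black package is installed.

    # Sketch, not part of this commit: reproduce Black's default
    # formatting (quote normalization) on one line from the diff below.
    # Assumes the `black` package is installed.
    import black

    src = "HASH_REGEX = re.compile(r'-([a-f0-9]*)\\.')\n"
    formatted = black.format_str(src, mode=black.Mode())
    print(formatted)  # HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")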
commit b004307252
parent aac3c7bd06
committed by Facebook GitHub Bot

torch/hub.py (199 changed lines)
--- a/torch/hub.py
+++ b/torch/hub.py
@@ -6,24 +6,31 @@ import re
 import shutil
 import sys
 import tempfile
-import torch
 import warnings
 import zipfile
-
-from urllib.request import urlopen, Request
 from urllib.parse import urlparse  # noqa: F401
+from urllib.request import urlopen, Request
+
+import torch
 
 try:
-    from tqdm.auto import tqdm  # automatically select proper tqdm submodule if available
+    from tqdm.auto import (
+        tqdm,
+    )  # automatically select proper tqdm submodule if available
 except ImportError:
     try:
         from tqdm import tqdm
     except ImportError:
         # fake tqdm if it's not installed
         class tqdm(object):  # type: ignore[no-redef]
-
-            def __init__(self, total=None, disable=False,
-                         unit=None, unit_scale=None, unit_divisor=None):
+            def __init__(
+                self,
+                total=None,
+                disable=False,
+                unit=None,
+                unit_scale=None,
+                unit_divisor=None,
+            ):
                 self.total = total
                 self.disable = disable
                 self.n = 0
@@ -37,7 +44,9 @@ except ImportError:
                 if self.total is None:
                     sys.stderr.write("\r{0:.1f} bytes".format(self.n))
                 else:
-                    sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
+                    sys.stderr.write(
+                        "\r{0:.1f}%".format(100 * self.n / float(self.total))
+                    )
                 sys.stderr.flush()
 
             def close(self):
@@ -50,18 +59,19 @@ except ImportError:
                 if self.disable:
                     return
 
-                sys.stderr.write('\n')
+                sys.stderr.write("\n")
 
 
 # matches bfd8deac from resnet18-bfd8deac.pth
-HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
+HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")
 
-MASTER_BRANCH = 'master'
-ENV_GITHUB_TOKEN = 'GITHUB_TOKEN'
-ENV_TORCH_HOME = 'TORCH_HOME'
-ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
-DEFAULT_CACHE_DIR = '~/.cache'
-VAR_DEPENDENCY = 'dependencies'
-MODULE_HUBCONF = 'hubconf.py'
+MASTER_BRANCH = "master"
+ENV_GITHUB_TOKEN = "GITHUB_TOKEN"
+ENV_TORCH_HOME = "TORCH_HOME"
+ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
+DEFAULT_CACHE_DIR = "~/.cache"
+VAR_DEPENDENCY = "dependencies"
+MODULE_HUBCONF = "hubconf.py"
 READ_DATA_CHUNK = 8192
 _hub_dir = None
@@ -70,6 +80,7 @@ _hub_dir = None
 def import_module(name, path):
     import importlib.util
     from importlib.abc import Loader
+
     spec = importlib.util.spec_from_file_location(name, path)
     module = importlib.util.module_from_spec(spec)
     assert isinstance(spec.loader, Loader)
@@ -86,7 +97,9 @@ def _remove_if_exists(path):
 
 
 def _git_archive_link(repo_owner, repo_name, branch):
-    return 'https://github.com/{}/{}/archive/{}.zip'.format(repo_owner, repo_name, branch)
+    return "https://github.com/{}/{}/archive/{}.zip".format(
+        repo_owner, repo_name, branch
+    )
 
 
 def _load_attr_from_module(module, func_name):
@@ -98,50 +111,55 @@ def _load_attr_from_module(module, func_name):
 
 def _get_torch_home():
     torch_home = os.path.expanduser(
-        os.getenv(ENV_TORCH_HOME,
-                  os.path.join(os.getenv(ENV_XDG_CACHE_HOME,
-                                         DEFAULT_CACHE_DIR), 'torch')))
+        os.getenv(
+            ENV_TORCH_HOME,
+            os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch"),
+        )
+    )
     return torch_home
 
 
 def _parse_repo_info(github):
     branch = MASTER_BRANCH
-    if ':' in github:
-        repo_info, branch = github.split(':')
+    if ":" in github:
+        repo_info, branch = github.split(":")
     else:
         repo_info = github
-    repo_owner, repo_name = repo_info.split('/')
+    repo_owner, repo_name = repo_info.split("/")
     return repo_owner, repo_name, branch
 
 
 def _read_url(url):
     with urlopen(url) as r:
-        return r.read().decode(r.headers.get_content_charset('utf-8'))
+        return r.read().decode(r.headers.get_content_charset("utf-8"))
 
 
 def _validate_not_a_forked_repo(repo_owner, repo_name, branch):
     # Use urlopen to avoid depending on local git.
-    headers = {'Accept': 'application/vnd.github.v3+json'}
+    headers = {"Accept": "application/vnd.github.v3+json"}
     token = os.environ.get(ENV_GITHUB_TOKEN)
     if token is not None:
-        headers['Authorization'] = f'token {token}'
+        headers["Authorization"] = f"token {token}"
     for url_prefix in (
-            f'https://api.github.com/repos/{repo_owner}/{repo_name}/branches',
-            f'https://api.github.com/repos/{repo_owner}/{repo_name}/tags'):
+        f"https://api.github.com/repos/{repo_owner}/{repo_name}/branches",
+        f"https://api.github.com/repos/{repo_owner}/{repo_name}/tags",
+    ):
         page = 0
         while True:
             page += 1
-            url = f'{url_prefix}?per_page=100&page={page}'
+            url = f"{url_prefix}?per_page=100&page={page}"
             response = json.loads(_read_url(Request(url, headers=headers)))
             # Empty response means no more data to process
             if not response:
                 break
             for br in response:
-                if br['name'] == branch or br['commit']['sha'].startswith(branch):
+                if br["name"] == branch or br["commit"]["sha"].startswith(branch):
                     return
 
-    raise ValueError(f'Cannot find {branch} in https://github.com/{repo_owner}/{repo_name}. '
-                     'If it\'s a commit from a forked repo, please call hub.load() with forked repo directly.')
+    raise ValueError(
+        f"Cannot find {branch} in https://github.com/{repo_owner}/{repo_name}. "
+        "If it's a commit from a forked repo, please call hub.load() with forked repo directly."
+    )
 
 
 def _get_cache_or_reload(github, force_reload, verbose=True, skip_validation=False):
@@ -155,28 +173,28 @@ def _get_cache_or_reload(github, force_reload, verbose=True, skip_validation=False):
     # this causes confusion with path on both Linux and Windows.
     # Backslash is not allowed in Github branch name so no need to
     # to worry about it.
-    normalized_br = branch.replace('/', '_')
+    normalized_br = branch.replace("/", "_")
     # Github renames folder repo-v1.x.x to repo-1.x.x
     # We don't know the repo name before downloading the zip file
     # and inspect name from it.
     # To check if cached repo exists, we need to normalize folder names.
-    repo_dir = os.path.join(hub_dir, '_'.join([repo_owner, repo_name, normalized_br]))
+    repo_dir = os.path.join(hub_dir, "_".join([repo_owner, repo_name, normalized_br]))
 
     use_cache = (not force_reload) and os.path.exists(repo_dir)
 
     if use_cache:
         if verbose:
-            sys.stderr.write('Using cache found in {}\n'.format(repo_dir))
+            sys.stderr.write("Using cache found in {}\n".format(repo_dir))
     else:
         # Validate the tag/branch is from the original repo instead of a forked repo
         if not skip_validation:
             _validate_not_a_forked_repo(repo_owner, repo_name, branch)
 
-        cached_file = os.path.join(hub_dir, normalized_br + '.zip')
+        cached_file = os.path.join(hub_dir, normalized_br + ".zip")
         _remove_if_exists(cached_file)
 
         url = _git_archive_link(repo_owner, repo_name, branch)
-        sys.stderr.write('Downloading: \"{}\" to {}\n'.format(url, cached_file))
+        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
         download_url_to_file(url, cached_file, progress=False)
 
         with zipfile.ZipFile(cached_file) as cached_zipfile:
@@ -195,6 +213,7 @@ def _get_cache_or_reload(github, force_reload, verbose=True, skip_validation=False):
 
 def _check_module_exists(name):
     import importlib.util
+
     return importlib.util.find_spec(name) is not None
 
 
@@ -204,12 +223,14 @@ def _check_dependencies(m):
     if dependencies is not None:
         missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)]
         if len(missing_deps):
-            raise RuntimeError('Missing dependencies: {}'.format(', '.join(missing_deps)))
+            raise RuntimeError(
+                "Missing dependencies: {}".format(", ".join(missing_deps))
+            )
 
 
 def _load_entry_from_hubconf(m, model):
     if not isinstance(model, str):
-        raise ValueError('Invalid input: model should be a string of function name')
+        raise ValueError("Invalid input: model should be a string of function name")
 
     # Note that if a missing dependency is imported at top level of hubconf, it will
     # throw before this function. It's a chicken and egg situation where we have to
@@ -220,7 +241,7 @@ def _load_entry_from_hubconf(m, model):
     func = _load_attr_from_module(m, model)
 
     if func is None or not callable(func):
-        raise RuntimeError('Cannot find callable {} in hubconf'.format(model))
+        raise RuntimeError("Cannot find callable {} in hubconf".format(model))
 
     return func
 
@@ -236,12 +257,12 @@ def get_dir():
     variable is not set.
     """
     # Issue warning to move data if old env is set
-    if os.getenv('TORCH_HUB'):
-        warnings.warn('TORCH_HUB is deprecated, please use env TORCH_HOME instead')
+    if os.getenv("TORCH_HUB"):
+        warnings.warn("TORCH_HUB is deprecated, please use env TORCH_HOME instead")
 
     if _hub_dir is not None:
         return _hub_dir
-    return os.path.join(_get_torch_home(), 'hub')
+    return os.path.join(_get_torch_home(), "hub")
 
 
 def set_dir(d):
@@ -273,16 +294,22 @@ def list(github, force_reload=False, skip_validation=False):
     Example:
         >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
     """
-    repo_dir = _get_cache_or_reload(github, force_reload, verbose=True, skip_validation=skip_validation)
+    repo_dir = _get_cache_or_reload(
+        github, force_reload, verbose=True, skip_validation=skip_validation
+    )
 
     sys.path.insert(0, repo_dir)
 
-    hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)
+    hub_module = import_module(MODULE_HUBCONF, repo_dir + "/" + MODULE_HUBCONF)
 
     sys.path.remove(repo_dir)
 
     # We take functions starts with '_' as internal helper functions
-    entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')]
+    entrypoints = [
+        f
+        for f in dir(hub_module)
+        if callable(getattr(hub_module, f)) and not f.startswith("_")
+    ]
 
     return entrypoints
 
@@ -303,11 +330,13 @@ def help(github, model, force_reload=False, skip_validation=False):
     Example:
         >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
     """
-    repo_dir = _get_cache_or_reload(github, force_reload, verbose=True, skip_validation=skip_validation)
+    repo_dir = _get_cache_or_reload(
+        github, force_reload, verbose=True, skip_validation=skip_validation
+    )
 
     sys.path.insert(0, repo_dir)
 
-    hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)
+    hub_module = import_module(MODULE_HUBCONF, repo_dir + "/" + MODULE_HUBCONF)
 
     sys.path.remove(repo_dir)
 
@@ -367,17 +396,20 @@ def load(repo_or_dir, model, *args, **kwargs):
         >>> path = '/some/local/path/pytorch/vision'
         >>> model = torch.hub.load(path, 'resnet50', pretrained=True)
     """
-    source = kwargs.pop('source', 'github').lower()
-    force_reload = kwargs.pop('force_reload', False)
-    verbose = kwargs.pop('verbose', True)
-    skip_validation = kwargs.pop('skip_validation', False)
+    source = kwargs.pop("source", "github").lower()
+    force_reload = kwargs.pop("force_reload", False)
+    verbose = kwargs.pop("verbose", True)
+    skip_validation = kwargs.pop("skip_validation", False)
 
-    if source not in ('github', 'local'):
+    if source not in ("github", "local"):
         raise ValueError(
-            f'Unknown source: "{source}". Allowed values: "github" | "local".')
+            f'Unknown source: "{source}". Allowed values: "github" | "local".'
+        )
 
-    if source == 'github':
-        repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, verbose, skip_validation)
+    if source == "github":
+        repo_or_dir = _get_cache_or_reload(
+            repo_or_dir, force_reload, verbose, skip_validation
+        )
 
     model = _load_local(repo_or_dir, model, *args, **kwargs)
     return model
@@ -436,7 +468,7 @@ def download_url_to_file(url, dst, hash_prefix=None, progress=True):
     req = Request(url, headers={"User-Agent": "torch.hub"})
     u = urlopen(req)
     meta = u.info()
-    if hasattr(meta, 'getheaders'):
+    if hasattr(meta, "getheaders"):
         content_length = meta.getheaders("Content-Length")
     else:
         content_length = meta.get_all("Content-Length")
@@ -453,8 +485,13 @@ def download_url_to_file(url, dst, hash_prefix=None, progress=True):
     try:
         if hash_prefix is not None:
             sha256 = hashlib.sha256()
-        with tqdm(total=file_size, disable=not progress,
-                  unit='B', unit_scale=True, unit_divisor=1024) as pbar:
+        with tqdm(
+            total=file_size,
+            disable=not progress,
+            unit="B",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as pbar:
             while True:
                 buffer = u.read(8192)
                 if len(buffer) == 0:
@@ -467,9 +504,12 @@ def download_url_to_file(url, dst, hash_prefix=None, progress=True):
         f.close()
         if hash_prefix is not None:
             digest = sha256.hexdigest()
-            if digest[:len(hash_prefix)] != hash_prefix:
-                raise RuntimeError('invalid hash value (expected "{}", got "{}")'
-                                   .format(hash_prefix, digest))
+            if digest[: len(hash_prefix)] != hash_prefix:
+                raise RuntimeError(
+                    'invalid hash value (expected "{}", got "{}")'.format(
+                        hash_prefix, digest
+                    )
+                )
         shutil.move(f.name, dst)
     finally:
         f.close()
@@ -478,9 +518,11 @@ def download_url_to_file(url, dst, hash_prefix=None, progress=True):
 
 
 def _download_url_to_file(url, dst, hash_prefix=None, progress=True):
-    warnings.warn('torch.hub._download_url_to_file has been renamed to\
+    warnings.warn(
+        "torch.hub._download_url_to_file has been renamed to\
             torch.hub.download_url_to_file to be a public API,\
-            _download_url_to_file will be removed in after 1.3 release')
+            _download_url_to_file will be removed in after 1.3 release"
+    )
     download_url_to_file(url, dst, hash_prefix, progress)
 
 
@@ -495,23 +537,32 @@ def _is_legacy_zip_format(filename):
 
 
 def _legacy_zip_load(filename, model_dir, map_location):
-    warnings.warn('Falling back to the old format < 1.6. This support will be '
-                  'deprecated in favor of default zipfile format introduced in 1.6. '
-                  'Please redo torch.save() to save it in the new zipfile format.')
+    warnings.warn(
+        "Falling back to the old format < 1.6. This support will be "
+        "deprecated in favor of default zipfile format introduced in 1.6. "
+        "Please redo torch.save() to save it in the new zipfile format."
+    )
     # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
     # We deliberately don't handle tarfile here since our legacy serialization format was in tar.
     # E.g. resnet18-5c106cde.pth which is widely used.
     with zipfile.ZipFile(filename) as f:
         members = f.infolist()
         if len(members) != 1:
-            raise RuntimeError('Only one file(not dir) is allowed in the zipfile')
+            raise RuntimeError("Only one file(not dir) is allowed in the zipfile")
         f.extractall(model_dir)
         extraced_name = members[0].filename
         extracted_file = os.path.join(model_dir, extraced_name)
     return torch.load(extracted_file, map_location=map_location)
 
 
-def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None):
+def load_state_dict_from_url(
+    url,
+    model_dir=None,
+    map_location=None,
+    progress=True,
+    check_hash=False,
+    file_name=None,
+):
     r"""Loads the Torch serialized object at the given URL.
 
     If downloaded file is a zip file, it will be automatically
@@ -540,12 +591,14 @@ def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None):
 
     """
    # Issue warning to move data if old env is set
-    if os.getenv('TORCH_MODEL_ZOO'):
-        warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
+    if os.getenv("TORCH_MODEL_ZOO"):
+        warnings.warn(
+            "TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead"
+        )
 
     if model_dir is None:
         hub_dir = get_dir()
-        model_dir = os.path.join(hub_dir, 'checkpoints')
+        model_dir = os.path.join(hub_dir, "checkpoints")
 
     try:
         os.makedirs(model_dir)
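The call-wrapping pattern that recurs in the hunks above (the tqdm(...), _get_cache_or_reload(...), and load_state_dict_from_url(...) rewrites) falls out of Black's default 88-column limit: a call that cannot fit on one line is exploded to one argument per line with a trailing comma. A sketch under the same assumptions as the earlier snippet:

    # Sketch, not from this commit: a call that overflows Black's default
    # 88-column limit is exploded one argument per line, as in the
    # tqdm(...) rewrite in the hunk at -453,8 above.
    import black

    src = (
        "with tqdm(total=file_size, disable=not progress,\n"
        "          unit='B', unit_scale=True, unit_divisor=1024) as pbar:\n"
        "    pass\n"
    )
    print(black.format_str(src, mode=black.Mode()))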