mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Strictly type everything in .github and tools (#59117)
Summary: This PR greatly simplifies `mypy-strict.ini` by strictly typing everything in `.github` and `tools`, rather than picking and choosing only specific files in those two dirs. It also removes `warn_unused_ignores` from `mypy-strict.ini`, for reasons described in https://github.com/pytorch/pytorch/pull/56402#issuecomment-822743795: basically, that setting makes life more difficult depending on what libraries you have installed locally vs in CI (e.g. `ruamel`). Pull Request resolved: https://github.com/pytorch/pytorch/pull/59117 Test Plan: ``` flake8 mypy --config mypy-strict.ini ``` Reviewed By: malfet Differential Revision: D28765386 Pulled By: samestep fbshipit-source-id: 3e744e301c7a464f8a2a2428fcdbad534e231f2e
This commit is contained in:
committed by
Facebook GitHub Bot
parent
6ff001c125
commit
737d920b21
119
tools/nightly.py
119
tools/nightly.py
@ -40,10 +40,10 @@ import contextlib
|
||||
import subprocess
|
||||
from ast import literal_eval
|
||||
from argparse import ArgumentParser
|
||||
from typing import Dict, Optional, Iterator
|
||||
from typing import (Any, Callable, Dict, Generator, Iterable, Iterator, List,
|
||||
Optional, Sequence, Set, Tuple, TypeVar, cast)
|
||||
|
||||
|
||||
LOGGER = None
|
||||
LOGGER: Optional[logging.Logger] = None
|
||||
URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2"
|
||||
DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss"
|
||||
SHA1_RE = re.compile("([0-9a-fA-F]{40})")
|
||||
@ -133,7 +133,7 @@ def logging_rotate() -> None:
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def logging_manager(*, debug: bool = False) -> Iterator[None]:
|
||||
def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, None]:
|
||||
"""Setup logging. If a failure starts here we won't
|
||||
be able to save the user in a reasonable way.
|
||||
|
||||
@ -179,7 +179,7 @@ def logging_manager(*, debug: bool = False) -> Iterator[None]:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def check_in_repo():
|
||||
def check_in_repo() -> Optional[str]:
|
||||
"""Ensures that we are in the PyTorch repo."""
|
||||
if not os.path.isfile("setup.py"):
|
||||
return "Not in root-level PyTorch repo, no setup.py found"
|
||||
@ -187,12 +187,13 @@ def check_in_repo():
|
||||
s = f.read()
|
||||
if "PyTorch" not in s:
|
||||
return "Not in PyTorch repo, 'PyTorch' not found in setup.py"
|
||||
return None
|
||||
|
||||
|
||||
def check_branch(subcommand, branch):
|
||||
def check_branch(subcommand: str, branch: Optional[str]) -> Optional[str]:
|
||||
"""Checks that the branch name can be checked out."""
|
||||
if subcommand != "checkout":
|
||||
return
|
||||
return None
|
||||
# first make sure actual branch name was given
|
||||
if branch is None:
|
||||
return "Branch name to checkout must be supplied with '-b' option"
|
||||
@ -203,36 +204,44 @@ def check_branch(subcommand, branch):
|
||||
return "Need to have clean working tree to checkout!\n\n" + p.stdout
|
||||
# next check that the branch name doesn't already exist
|
||||
cmd = ["git", "show-ref", "--verify", "--quiet", "refs/heads/" + branch]
|
||||
p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False)
|
||||
p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False) # type: ignore[assignment]
|
||||
if not p.returncode:
|
||||
return f"Branch {branch!r} already exists"
|
||||
return None
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def timer(logger, prefix):
|
||||
def timer(logger: logging.Logger, prefix: str) -> Iterator[None]:
|
||||
"""Timed context manager"""
|
||||
start_time = time.time()
|
||||
yield
|
||||
logger.info(f"{prefix} took {time.time() - start_time:.3f} [s]")
|
||||
|
||||
|
||||
def timed(prefix):
|
||||
F = TypeVar('F', bound=Callable[..., Any])
|
||||
|
||||
|
||||
def timed(prefix: str) -> Callable[[F], F]:
|
||||
"""Decorator for timing functions"""
|
||||
|
||||
def dec(f):
|
||||
def dec(f: F) -> F:
|
||||
@functools.wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
global LOGGER
|
||||
LOGGER.info(prefix)
|
||||
with timer(LOGGER, prefix):
|
||||
logger = cast(logging.Logger, LOGGER)
|
||||
logger.info(prefix)
|
||||
with timer(logger, prefix):
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return cast(F, wrapper)
|
||||
|
||||
return dec
|
||||
|
||||
|
||||
def _make_channel_args(channels=("pytorch-nightly",), override_channels=False):
|
||||
def _make_channel_args(
|
||||
channels: Iterable[str] = ("pytorch-nightly",),
|
||||
override_channels: bool = False,
|
||||
) -> List[str]:
|
||||
args = []
|
||||
for channel in channels:
|
||||
args.append("--channel")
|
||||
@ -244,8 +253,11 @@ def _make_channel_args(channels=("pytorch-nightly",), override_channels=False):
|
||||
|
||||
@timed("Solving conda environment")
|
||||
def conda_solve(
|
||||
name=None, prefix=None, channels=("pytorch-nightly",), override_channels=False
|
||||
):
|
||||
name: Optional[str] = None,
|
||||
prefix: Optional[str] = None,
|
||||
channels: Iterable[str] = ("pytorch-nightly",),
|
||||
override_channels: bool = False,
|
||||
) -> Tuple[List[str], str, str, bool, List[str]]:
|
||||
"""Performs the conda solve and splits the deps from the package."""
|
||||
# compute what environment to use
|
||||
if prefix is not None:
|
||||
@ -299,7 +311,7 @@ def conda_solve(
|
||||
|
||||
|
||||
@timed("Installing dependencies")
|
||||
def deps_install(deps, existing_env, env_opts):
|
||||
def deps_install(deps: List[str], existing_env: bool, env_opts: List[str]) -> None:
|
||||
"""Install dependencies to deps environment"""
|
||||
if not existing_env:
|
||||
# first remove previous pytorch-deps env
|
||||
@ -312,7 +324,7 @@ def deps_install(deps, existing_env, env_opts):
|
||||
|
||||
|
||||
@timed("Installing pytorch nightly binaries")
|
||||
def pytorch_install(url):
|
||||
def pytorch_install(url: str) -> tempfile.TemporaryDirectory[str]:
|
||||
""""Install pytorch into a temporary directory"""
|
||||
pytdir = tempfile.TemporaryDirectory()
|
||||
cmd = ["conda", "create", "--yes", "--no-deps", "--prefix", pytdir.name, url]
|
||||
@ -320,7 +332,7 @@ def pytorch_install(url):
|
||||
return pytdir
|
||||
|
||||
|
||||
def _site_packages(dirname, platform):
|
||||
def _site_packages(dirname: str, platform: str) -> str:
|
||||
if platform.startswith("win"):
|
||||
template = os.path.join(dirname, "Lib", "site-packages")
|
||||
else:
|
||||
@ -329,7 +341,7 @@ def _site_packages(dirname, platform):
|
||||
return spdir
|
||||
|
||||
|
||||
def _ensure_commit(git_sha1):
|
||||
def _ensure_commit(git_sha1: str) -> None:
|
||||
"""Make sure that we actually have the commit locally"""
|
||||
cmd = ["git", "cat-file", "-e", git_sha1 + "^{commit}"]
|
||||
p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False)
|
||||
@ -341,7 +353,7 @@ def _ensure_commit(git_sha1):
|
||||
p = subprocess.run(cmd, check=True)
|
||||
|
||||
|
||||
def _nightly_version(spdir):
|
||||
def _nightly_version(spdir: str) -> str:
|
||||
# first get the git version from the installed module
|
||||
version_fname = os.path.join(spdir, "torch", "version.py")
|
||||
with open(version_fname) as f:
|
||||
@ -371,7 +383,7 @@ def _nightly_version(spdir):
|
||||
|
||||
|
||||
@timed("Checking out nightly PyTorch")
|
||||
def checkout_nightly_version(branch, spdir):
|
||||
def checkout_nightly_version(branch: str, spdir: str) -> None:
|
||||
"""Get's the nightly version and then checks it out."""
|
||||
nightly_version = _nightly_version(spdir)
|
||||
cmd = ["git", "checkout", "-b", branch, nightly_version]
|
||||
@ -379,40 +391,40 @@ def checkout_nightly_version(branch, spdir):
|
||||
|
||||
|
||||
@timed("Pulling nightly PyTorch")
|
||||
def pull_nightly_version(spdir):
|
||||
def pull_nightly_version(spdir: str) -> None:
|
||||
"""Fetches the nightly version and then merges it ."""
|
||||
nightly_version = _nightly_version(spdir)
|
||||
cmd = ["git", "merge", nightly_version]
|
||||
p = subprocess.run(cmd, check=True)
|
||||
|
||||
|
||||
def _get_listing_linux(source_dir):
|
||||
def _get_listing_linux(source_dir: str) -> List[str]:
|
||||
listing = glob.glob(os.path.join(source_dir, "*.so"))
|
||||
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.so")))
|
||||
return listing
|
||||
|
||||
|
||||
def _get_listing_osx(source_dir):
|
||||
def _get_listing_osx(source_dir: str) -> List[str]:
|
||||
# oddly, these are .so files even on Mac
|
||||
listing = glob.glob(os.path.join(source_dir, "*.so"))
|
||||
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dylib")))
|
||||
return listing
|
||||
|
||||
|
||||
def _get_listing_win(source_dir):
|
||||
def _get_listing_win(source_dir: str) -> List[str]:
|
||||
listing = glob.glob(os.path.join(source_dir, "*.pyd"))
|
||||
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.lib")))
|
||||
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dll")))
|
||||
return listing
|
||||
|
||||
|
||||
def _glob_pyis(d):
|
||||
def _glob_pyis(d: str) -> Set[str]:
|
||||
search = os.path.join(d, "**", "*.pyi")
|
||||
pyis = {os.path.relpath(p, d) for p in glob.iglob(search)}
|
||||
return pyis
|
||||
|
||||
|
||||
def _find_missing_pyi(source_dir, target_dir):
|
||||
def _find_missing_pyi(source_dir: str, target_dir: str) -> List[str]:
|
||||
source_pyis = _glob_pyis(source_dir)
|
||||
target_pyis = _glob_pyis(target_dir)
|
||||
missing_pyis = [os.path.join(source_dir, p) for p in (source_pyis - target_pyis)]
|
||||
@ -420,7 +432,7 @@ def _find_missing_pyi(source_dir, target_dir):
|
||||
return missing_pyis
|
||||
|
||||
|
||||
def _get_listing(source_dir, target_dir, platform):
|
||||
def _get_listing(source_dir: str, target_dir: str, platform: str) -> List[str]:
|
||||
if platform.startswith("linux"):
|
||||
listing = _get_listing_linux(source_dir)
|
||||
elif platform.startswith("osx"):
|
||||
@ -437,7 +449,7 @@ def _get_listing(source_dir, target_dir, platform):
|
||||
return listing
|
||||
|
||||
|
||||
def _remove_existing(trg, is_dir):
|
||||
def _remove_existing(trg: str, is_dir: bool) -> None:
|
||||
if os.path.exists(trg):
|
||||
if is_dir:
|
||||
shutil.rmtree(trg)
|
||||
@ -445,7 +457,13 @@ def _remove_existing(trg, is_dir):
|
||||
os.remove(trg)
|
||||
|
||||
|
||||
def _move_single(src, source_dir, target_dir, mover, verb):
|
||||
def _move_single(
|
||||
src: str,
|
||||
source_dir: str,
|
||||
target_dir: str,
|
||||
mover: Callable[[str, str], None],
|
||||
verb: str,
|
||||
) -> None:
|
||||
is_dir = os.path.isdir(src)
|
||||
relpath = os.path.relpath(src, source_dir)
|
||||
trg = os.path.join(target_dir, relpath)
|
||||
@ -469,18 +487,18 @@ def _move_single(src, source_dir, target_dir, mover, verb):
|
||||
mover(src, trg)
|
||||
|
||||
|
||||
def _copy_files(listing, source_dir, target_dir):
|
||||
def _copy_files(listing: List[str], source_dir: str, target_dir: str) -> None:
|
||||
for src in listing:
|
||||
_move_single(src, source_dir, target_dir, shutil.copy2, "Copying")
|
||||
|
||||
|
||||
def _link_files(listing, source_dir, target_dir):
|
||||
def _link_files(listing: List[str], source_dir: str, target_dir: str) -> None:
|
||||
for src in listing:
|
||||
_move_single(src, source_dir, target_dir, os.link, "Linking")
|
||||
|
||||
|
||||
@timed("Moving nightly files into repo")
|
||||
def move_nightly_files(spdir, platform):
|
||||
def move_nightly_files(spdir: str, platform: str) -> None:
|
||||
"""Moves PyTorch files from temporary installed location to repo."""
|
||||
# get file listing
|
||||
source_dir = os.path.join(spdir, "torch")
|
||||
@ -496,7 +514,7 @@ def move_nightly_files(spdir, platform):
|
||||
_copy_files(listing, source_dir, target_dir)
|
||||
|
||||
|
||||
def _available_envs():
|
||||
def _available_envs() -> Dict[str, str]:
|
||||
cmd = ["conda", "env", "list"]
|
||||
p = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
|
||||
lines = p.stdout.splitlines()
|
||||
@ -513,7 +531,7 @@ def _available_envs():
|
||||
|
||||
|
||||
@timed("Writing pytorch-nightly.pth")
|
||||
def write_pth(env_opts, platform):
|
||||
def write_pth(env_opts: List[str], platform: str) -> None:
|
||||
"""Writes Python path file for this dir."""
|
||||
env_type, env_dir = env_opts
|
||||
if env_type == "--name":
|
||||
@ -533,17 +551,16 @@ def write_pth(env_opts, platform):
|
||||
|
||||
|
||||
def install(
|
||||
subcommand="checkout",
|
||||
branch=None,
|
||||
name=None,
|
||||
prefix=None,
|
||||
channels=("pytorch-nightly",),
|
||||
override_channels=False,
|
||||
logger=None,
|
||||
):
|
||||
*,
|
||||
logger: logging.Logger,
|
||||
subcommand: str = "checkout",
|
||||
branch: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
prefix: Optional[str] = None,
|
||||
channels: Iterable[str] = ("pytorch-nightly",),
|
||||
override_channels: bool = False,
|
||||
) -> None:
|
||||
"""Development install of PyTorch"""
|
||||
global LOGGER
|
||||
logger = logger or LOGGER
|
||||
deps, pytorch, platform, existing_env, env_opts = conda_solve(
|
||||
name=name, prefix=prefix, channels=channels, override_channels=override_channels
|
||||
)
|
||||
@ -552,7 +569,7 @@ def install(
|
||||
pytdir = pytorch_install(pytorch)
|
||||
spdir = _site_packages(pytdir.name, platform)
|
||||
if subcommand == "checkout":
|
||||
checkout_nightly_version(branch, spdir)
|
||||
checkout_nightly_version(cast(str, branch), spdir)
|
||||
elif subcommand == "pull":
|
||||
pull_nightly_version(spdir)
|
||||
else:
|
||||
@ -566,7 +583,7 @@ def install(
|
||||
)
|
||||
|
||||
|
||||
def make_parser():
|
||||
def make_parser() -> ArgumentParser:
|
||||
p = ArgumentParser("nightly")
|
||||
# subcommands
|
||||
subcmd = p.add_subparsers(dest="subcmd", help="subcommand to execute")
|
||||
@ -627,7 +644,7 @@ def make_parser():
|
||||
return p
|
||||
|
||||
|
||||
def main(args=None):
|
||||
def main(args: Optional[Sequence[str]] = None) -> None:
|
||||
"""Main entry point"""
|
||||
global LOGGER
|
||||
p = make_parser()
|
||||
|
Reference in New Issue
Block a user