Strictly type everything in .github and tools (#59117)

Summary:
This PR greatly simplifies `mypy-strict.ini` by strictly typing everything in `.github` and `tools`, rather than picking and choosing only specific files in those two dirs. It also removes `warn_unused_ignores` from `mypy-strict.ini`, for reasons described in https://github.com/pytorch/pytorch/pull/56402#issuecomment-822743795: basically, that setting makes life more difficult depending on what libraries you have installed locally vs in CI (e.g. `ruamel`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59117

Test Plan:
```
flake8
mypy --config mypy-strict.ini
```

Reviewed By: malfet

Differential Revision: D28765386

Pulled By: samestep

fbshipit-source-id: 3e744e301c7a464f8a2a2428fcdbad534e231f2e
This commit is contained in:
Sam Estep
2021-06-07 14:48:29 -07:00
committed by Facebook GitHub Bot
parent 6ff001c125
commit 737d920b21
43 changed files with 463 additions and 312 deletions

View File

@ -40,10 +40,10 @@ import contextlib
import subprocess
from ast import literal_eval
from argparse import ArgumentParser
from typing import Dict, Optional, Iterator
from typing import (Any, Callable, Dict, Generator, Iterable, Iterator, List,
Optional, Sequence, Set, Tuple, TypeVar, cast)
LOGGER = None
LOGGER: Optional[logging.Logger] = None
URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2"
DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss"
SHA1_RE = re.compile("([0-9a-fA-F]{40})")
@ -133,7 +133,7 @@ def logging_rotate() -> None:
@contextlib.contextmanager
def logging_manager(*, debug: bool = False) -> Iterator[None]:
def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, None]:
"""Setup logging. If a failure starts here we won't
be able to save the user in a reasonable way.
@ -179,7 +179,7 @@ def logging_manager(*, debug: bool = False) -> Iterator[None]:
sys.exit(1)
def check_in_repo():
def check_in_repo() -> Optional[str]:
"""Ensures that we are in the PyTorch repo."""
if not os.path.isfile("setup.py"):
return "Not in root-level PyTorch repo, no setup.py found"
@ -187,12 +187,13 @@ def check_in_repo():
s = f.read()
if "PyTorch" not in s:
return "Not in PyTorch repo, 'PyTorch' not found in setup.py"
return None
def check_branch(subcommand, branch):
def check_branch(subcommand: str, branch: Optional[str]) -> Optional[str]:
"""Checks that the branch name can be checked out."""
if subcommand != "checkout":
return
return None
# first make sure actual branch name was given
if branch is None:
return "Branch name to checkout must be supplied with '-b' option"
@ -203,36 +204,44 @@ def check_branch(subcommand, branch):
return "Need to have clean working tree to checkout!\n\n" + p.stdout
# next check that the branch name doesn't already exist
cmd = ["git", "show-ref", "--verify", "--quiet", "refs/heads/" + branch]
p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False)
p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False) # type: ignore[assignment]
if not p.returncode:
return f"Branch {branch!r} already exists"
return None
@contextlib.contextmanager
def timer(logger, prefix):
def timer(logger: logging.Logger, prefix: str) -> Iterator[None]:
"""Timed context manager"""
start_time = time.time()
yield
logger.info(f"{prefix} took {time.time() - start_time:.3f} [s]")
def timed(prefix):
F = TypeVar('F', bound=Callable[..., Any])
def timed(prefix: str) -> Callable[[F], F]:
"""Decorator for timing functions"""
def dec(f):
def dec(f: F) -> F:
@functools.wraps(f)
def wrapper(*args, **kwargs):
def wrapper(*args: Any, **kwargs: Any) -> Any:
global LOGGER
LOGGER.info(prefix)
with timer(LOGGER, prefix):
logger = cast(logging.Logger, LOGGER)
logger.info(prefix)
with timer(logger, prefix):
return f(*args, **kwargs)
return wrapper
return cast(F, wrapper)
return dec
def _make_channel_args(channels=("pytorch-nightly",), override_channels=False):
def _make_channel_args(
channels: Iterable[str] = ("pytorch-nightly",),
override_channels: bool = False,
) -> List[str]:
args = []
for channel in channels:
args.append("--channel")
@ -244,8 +253,11 @@ def _make_channel_args(channels=("pytorch-nightly",), override_channels=False):
@timed("Solving conda environment")
def conda_solve(
name=None, prefix=None, channels=("pytorch-nightly",), override_channels=False
):
name: Optional[str] = None,
prefix: Optional[str] = None,
channels: Iterable[str] = ("pytorch-nightly",),
override_channels: bool = False,
) -> Tuple[List[str], str, str, bool, List[str]]:
"""Performs the conda solve and splits the deps from the package."""
# compute what environment to use
if prefix is not None:
@ -299,7 +311,7 @@ def conda_solve(
@timed("Installing dependencies")
def deps_install(deps, existing_env, env_opts):
def deps_install(deps: List[str], existing_env: bool, env_opts: List[str]) -> None:
"""Install dependencies to deps environment"""
if not existing_env:
# first remove previous pytorch-deps env
@ -312,7 +324,7 @@ def deps_install(deps, existing_env, env_opts):
@timed("Installing pytorch nightly binaries")
def pytorch_install(url):
def pytorch_install(url: str) -> tempfile.TemporaryDirectory[str]:
""""Install pytorch into a temporary directory"""
pytdir = tempfile.TemporaryDirectory()
cmd = ["conda", "create", "--yes", "--no-deps", "--prefix", pytdir.name, url]
@ -320,7 +332,7 @@ def pytorch_install(url):
return pytdir
def _site_packages(dirname, platform):
def _site_packages(dirname: str, platform: str) -> str:
if platform.startswith("win"):
template = os.path.join(dirname, "Lib", "site-packages")
else:
@ -329,7 +341,7 @@ def _site_packages(dirname, platform):
return spdir
def _ensure_commit(git_sha1):
def _ensure_commit(git_sha1: str) -> None:
"""Make sure that we actually have the commit locally"""
cmd = ["git", "cat-file", "-e", git_sha1 + "^{commit}"]
p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False)
@ -341,7 +353,7 @@ def _ensure_commit(git_sha1):
p = subprocess.run(cmd, check=True)
def _nightly_version(spdir):
def _nightly_version(spdir: str) -> str:
# first get the git version from the installed module
version_fname = os.path.join(spdir, "torch", "version.py")
with open(version_fname) as f:
@ -371,7 +383,7 @@ def _nightly_version(spdir):
@timed("Checking out nightly PyTorch")
def checkout_nightly_version(branch, spdir):
def checkout_nightly_version(branch: str, spdir: str) -> None:
"""Get's the nightly version and then checks it out."""
nightly_version = _nightly_version(spdir)
cmd = ["git", "checkout", "-b", branch, nightly_version]
@ -379,40 +391,40 @@ def checkout_nightly_version(branch, spdir):
@timed("Pulling nightly PyTorch")
def pull_nightly_version(spdir):
def pull_nightly_version(spdir: str) -> None:
"""Fetches the nightly version and then merges it ."""
nightly_version = _nightly_version(spdir)
cmd = ["git", "merge", nightly_version]
p = subprocess.run(cmd, check=True)
def _get_listing_linux(source_dir):
def _get_listing_linux(source_dir: str) -> List[str]:
listing = glob.glob(os.path.join(source_dir, "*.so"))
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.so")))
return listing
def _get_listing_osx(source_dir):
def _get_listing_osx(source_dir: str) -> List[str]:
# oddly, these are .so files even on Mac
listing = glob.glob(os.path.join(source_dir, "*.so"))
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dylib")))
return listing
def _get_listing_win(source_dir):
def _get_listing_win(source_dir: str) -> List[str]:
listing = glob.glob(os.path.join(source_dir, "*.pyd"))
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.lib")))
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dll")))
return listing
def _glob_pyis(d):
def _glob_pyis(d: str) -> Set[str]:
search = os.path.join(d, "**", "*.pyi")
pyis = {os.path.relpath(p, d) for p in glob.iglob(search)}
return pyis
def _find_missing_pyi(source_dir, target_dir):
def _find_missing_pyi(source_dir: str, target_dir: str) -> List[str]:
source_pyis = _glob_pyis(source_dir)
target_pyis = _glob_pyis(target_dir)
missing_pyis = [os.path.join(source_dir, p) for p in (source_pyis - target_pyis)]
@ -420,7 +432,7 @@ def _find_missing_pyi(source_dir, target_dir):
return missing_pyis
def _get_listing(source_dir, target_dir, platform):
def _get_listing(source_dir: str, target_dir: str, platform: str) -> List[str]:
if platform.startswith("linux"):
listing = _get_listing_linux(source_dir)
elif platform.startswith("osx"):
@ -437,7 +449,7 @@ def _get_listing(source_dir, target_dir, platform):
return listing
def _remove_existing(trg, is_dir):
def _remove_existing(trg: str, is_dir: bool) -> None:
if os.path.exists(trg):
if is_dir:
shutil.rmtree(trg)
@ -445,7 +457,13 @@ def _remove_existing(trg, is_dir):
os.remove(trg)
def _move_single(src, source_dir, target_dir, mover, verb):
def _move_single(
src: str,
source_dir: str,
target_dir: str,
mover: Callable[[str, str], None],
verb: str,
) -> None:
is_dir = os.path.isdir(src)
relpath = os.path.relpath(src, source_dir)
trg = os.path.join(target_dir, relpath)
@ -469,18 +487,18 @@ def _move_single(src, source_dir, target_dir, mover, verb):
mover(src, trg)
def _copy_files(listing, source_dir, target_dir):
def _copy_files(listing: List[str], source_dir: str, target_dir: str) -> None:
for src in listing:
_move_single(src, source_dir, target_dir, shutil.copy2, "Copying")
def _link_files(listing, source_dir, target_dir):
def _link_files(listing: List[str], source_dir: str, target_dir: str) -> None:
for src in listing:
_move_single(src, source_dir, target_dir, os.link, "Linking")
@timed("Moving nightly files into repo")
def move_nightly_files(spdir, platform):
def move_nightly_files(spdir: str, platform: str) -> None:
"""Moves PyTorch files from temporary installed location to repo."""
# get file listing
source_dir = os.path.join(spdir, "torch")
@ -496,7 +514,7 @@ def move_nightly_files(spdir, platform):
_copy_files(listing, source_dir, target_dir)
def _available_envs():
def _available_envs() -> Dict[str, str]:
cmd = ["conda", "env", "list"]
p = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
lines = p.stdout.splitlines()
@ -513,7 +531,7 @@ def _available_envs():
@timed("Writing pytorch-nightly.pth")
def write_pth(env_opts, platform):
def write_pth(env_opts: List[str], platform: str) -> None:
"""Writes Python path file for this dir."""
env_type, env_dir = env_opts
if env_type == "--name":
@ -533,17 +551,16 @@ def write_pth(env_opts, platform):
def install(
subcommand="checkout",
branch=None,
name=None,
prefix=None,
channels=("pytorch-nightly",),
override_channels=False,
logger=None,
):
*,
logger: logging.Logger,
subcommand: str = "checkout",
branch: Optional[str] = None,
name: Optional[str] = None,
prefix: Optional[str] = None,
channels: Iterable[str] = ("pytorch-nightly",),
override_channels: bool = False,
) -> None:
"""Development install of PyTorch"""
global LOGGER
logger = logger or LOGGER
deps, pytorch, platform, existing_env, env_opts = conda_solve(
name=name, prefix=prefix, channels=channels, override_channels=override_channels
)
@ -552,7 +569,7 @@ def install(
pytdir = pytorch_install(pytorch)
spdir = _site_packages(pytdir.name, platform)
if subcommand == "checkout":
checkout_nightly_version(branch, spdir)
checkout_nightly_version(cast(str, branch), spdir)
elif subcommand == "pull":
pull_nightly_version(spdir)
else:
@ -566,7 +583,7 @@ def install(
)
def make_parser():
def make_parser() -> ArgumentParser:
p = ArgumentParser("nightly")
# subcommands
subcmd = p.add_subparsers(dest="subcmd", help="subcommand to execute")
@ -627,7 +644,7 @@ def make_parser():
return p
def main(args=None):
def main(args: Optional[Sequence[str]] = None) -> None:
"""Main entry point"""
global LOGGER
p = make_parser()