Strictly type everything in .github and tools (#59117)
Summary: This PR greatly simplifies `mypy-strict.ini` by strictly typing everything in `.github` and `tools`, rather than picking and choosing only specific files in those two dirs. It also removes `warn_unused_ignores` from `mypy-strict.ini`, for reasons described in https://github.com/pytorch/pytorch/pull/56402#issuecomment-822743795: basically, that setting makes life more difficult depending on what libraries you have installed locally vs in CI (e.g. `ruamel`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59117

Test Plan:
```
flake8
mypy --config mypy-strict.ini
```

Reviewed By: malfet

Differential Revision: D28765386

Pulled By: samestep

fbshipit-source-id: 3e744e301c7a464f8a2a2428fcdbad534e231f2e
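The `warn_unused_ignores` point is worth unpacking, since it explains the many `# type: ignore[import]` comments added throughout this diff: whether such an ignore is "used" depends on which packages and stubs happen to be installed. A minimal sketch of the failure mode, assuming an environment where `ruamel.yaml` ships without type stubs:

```
# Without the comment, mypy errors wherever ruamel.yaml (or its stubs) is
# missing; with warn_unused_ignores = True, the same comment becomes an
# "unused ignore" error wherever mypy *can* resolve the package -- so no
# single spelling of this line passes both locally and in CI.
import ruamel.yaml  # type: ignore[import]
```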
Commit 737d920b21 (parent 6ff001c125), committed by Facebook GitHub Bot.

Changed files:
@@ -11,12 +11,12 @@ REPO_ROOT = Path(__file__).resolve().parent.parent.parent
 WORKFLOWS = REPO_ROOT / ".github" / "workflows"


-def concurrency_key(filename):
+def concurrency_key(filename: Path) -> str:
     workflow_name = filename.with_suffix("").name.replace("_", "-")
     return f"{workflow_name}-${{{{ github.event.pull_request.number || github.sha }}}}"


-def should_check(filename):
+def should_check(filename: Path) -> bool:
     with open(filename, "r") as f:
         content = f.read()

@@ -31,7 +31,7 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()

-    files = WORKFLOWS.glob("*.yml")
+    files = list(WORKFLOWS.glob("*.yml"))

     errors_found = False
     files = [f for f in files if should_check(f)]
.github/scripts/generate_pytorch_version.py (21 changed lines):

@@ -16,12 +16,12 @@ LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile("a0$")
 class NoGitTagException(Exception):
     pass


-def get_pytorch_root():
+def get_pytorch_root() -> Path:
     return Path(subprocess.check_output(
         ['git', 'rev-parse', '--show-toplevel']
     ).decode('ascii').strip())


-def get_tag():
+def get_tag() -> str:
     root = get_pytorch_root()
     # We're on a tag
     am_on_tag = (
@@ -46,7 +46,7 @@ def get_tag():
     tag = re.sub(TRAILING_RC_PATTERN, "", tag)
     return tag

-def get_base_version():
+def get_base_version() -> str:
     root = get_pytorch_root()
     dirty_version = open(root / 'version.txt', 'r').read().strip()
     # Strips trailing a0 from version.txt, not too sure why it's there in the
@@ -54,29 +54,34 @@ def get_base_version():
     return re.sub(LEGACY_BASE_VERSION_SUFFIX_PATTERN, "", dirty_version)

 class PytorchVersion:
-    def __init__(self, gpu_arch_type, gpu_arch_version, no_build_suffix):
+    def __init__(
+        self,
+        gpu_arch_type: str,
+        gpu_arch_version: str,
+        no_build_suffix: bool,
+    ) -> None:
         self.gpu_arch_type = gpu_arch_type
         self.gpu_arch_version = gpu_arch_version
         self.no_build_suffix = no_build_suffix

-    def get_post_build_suffix(self):
+    def get_post_build_suffix(self) -> str:
         if self.gpu_arch_type == "cuda":
             return f"+cu{self.gpu_arch_version.replace('.', '')}"
         return f"+{self.gpu_arch_type}{self.gpu_arch_version}"

-    def get_release_version(self):
+    def get_release_version(self) -> str:
         if not get_tag():
             raise NoGitTagException(
                 "Not on a git tag, are you sure you want a release version?"
             )
         return f"{get_tag()}{self.get_post_build_suffix()}"

-    def get_nightly_version(self):
+    def get_nightly_version(self) -> str:
         date_str = datetime.today().strftime('%Y%m%d')
         build_suffix = self.get_post_build_suffix()
         return f"{get_base_version()}.dev{date_str}{build_suffix}"

-def main():
+def main() -> None:
     parser = argparse.ArgumentParser(
         description="Generate pytorch version for binary builds"
     )
.github/scripts/lint_native_functions.py (6 changed lines):

@@ -14,19 +14,19 @@ is simply to make sure that there is *some* configuration of ruamel that can rou
 the YAML, not to be prescriptive about it.
 '''

-import ruamel.yaml
+import ruamel.yaml  # type: ignore[import]
 import difflib
 import sys
 from pathlib import Path
 from io import StringIO


-def fn(base):
+def fn(base: str) -> str:
     return str(base / Path("aten/src/ATen/native/native_functions.yaml"))

 with open(Path(__file__).parent.parent.parent / fn('.'), "r") as f:
     contents = f.read()

-yaml = ruamel.yaml.YAML()
+yaml = ruamel.yaml.YAML()  # type: ignore[attr-defined]
 yaml.preserve_quotes = True
 yaml.width = 1000
 yaml.boolean_representation = ['False', 'True']
.github/scripts/run_torchbench.py (6 changed lines):

@@ -31,7 +31,7 @@ direction: decrease
 timeout: 720
 tests:"""

-def gen_abtest_config(control: str, treatment: str, models: List[str]):
+def gen_abtest_config(control: str, treatment: str, models: List[str]) -> str:
     d = {}
     d["control"] = control
     d["treatment"] = treatment
@@ -43,7 +43,7 @@ def gen_abtest_config(control: str, treatment: str, models: List[str]):
     config = config + "\n"
     return config

-def deploy_torchbench_config(output_dir: str, config: str):
+def deploy_torchbench_config(output_dir: str, config: str) -> None:
     # Create test dir if needed
     pathlib.Path(output_dir).mkdir(exist_ok=True)
     # TorchBench config file name
@@ -71,7 +71,7 @@ def extract_models_from_pr(torchbench_path: str, prbody_file: str) -> List[str]:
         return []
     return model_list

-def run_torchbench(pytorch_path: str, torchbench_path: str, output_dir: str):
+def run_torchbench(pytorch_path: str, torchbench_path: str, output_dir: str) -> None:
     # Copy system environment so that we will not override
     env = dict(os.environ)
     command = ["python", "bisection.py", "--work-dir", output_dir,
@@ -492,7 +492,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
     "${TOOLS_PATH}/autograd/templates/python_linalg_functions.cpp"
     "${TOOLS_PATH}/autograd/templates/python_special_functions.cpp"
     "${TOOLS_PATH}/autograd/templates/variable_factories.h"
-    "${TOOLS_PATH}/autograd/templates/annotated_fn_args.py"
+    "${TOOLS_PATH}/autograd/templates/annotated_fn_args.py.in"
     "${TOOLS_PATH}/autograd/deprecated.yaml"
     "${TOOLS_PATH}/autograd/derivatives.yaml"
     "${TOOLS_PATH}/autograd/gen_autograd_functions.py"
@@ -5,8 +5,6 @@
 # this config file to be used to ENFORCE that people are using mypy on codegen
 # files.

-# For now, only code_template.py and benchmark utils Timer are covered this way
-
 [mypy]
 python_version = 3.6
 plugins = mypy_plugins/check_mypy_version.py
@@ -30,36 +28,14 @@ check_untyped_defs = True
 disallow_untyped_decorators = True
 no_implicit_optional = True
 warn_redundant_casts = True
-warn_unused_ignores = True
 warn_return_any = True
 implicit_reexport = False
 strict_equality = True

 files =
-    .github/scripts/generate_binary_build_matrix.py,
-    .github/scripts/generate_ci_workflows.py,
-    .github/scripts/parse_ref.py,
+    .github,
     benchmarks/instruction_counts,
-    tools/actions_local_runner.py,
-    tools/autograd/*.py,
-    tools/clang_tidy.py,
-    tools/codegen,
-    tools/explicit_ci_jobs.py,
-    tools/extract_scripts.py,
-    tools/mypy_wrapper.py,
-    tools/print_test_stats.py,
-    tools/pyi,
-    tools/stats_utils,
-    tools/test_history.py,
-    tools/test/test_actions_local_runner.py,
-    tools/test/test_extract_scripts.py,
-    tools/test/test_mypy_wrapper.py,
-    tools/test/test_test_history.py,
-    tools/test/test_trailing_newlines.py,
-    tools/test/test_translate_annotations.py,
-    tools/trailing_newlines.py,
-    tools/translate_annotations.py,
-    tools/vscode_settings.py,
+    tools,
     torch/testing/_internal/framework_utils.py,
     torch/utils/_pytree.py,
     torch/utils/benchmark/utils/common.py,
@@ -12,7 +12,7 @@ sys.path.append(os.path.realpath(os.path.join(
     'torch',
     'utils')))

-from hipify import hipify_python
+from hipify import hipify_python  # type: ignore[import]

 parser = argparse.ArgumentParser(description='Top-level script for HIPifying, filling in most common parameters')
 parser.add_argument(
@@ -115,7 +115,7 @@ ignores = [
 ]

 # Check if the compiler is hip-clang.
-def is_hip_clang():
+def is_hip_clang() -> bool:
     try:
         hip_path = os.getenv('HIP_PATH', '/opt/rocm/hip')
         return 'HIP_COMPILER=clang' in open(hip_path + '/lib/.hipInfo').read()
@@ -48,7 +48,7 @@ def gen_annotated(native_yaml_path: str, out: str, autograd_dir: str) -> None:

     template_path = os.path.join(autograd_dir, 'templates')
     fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
-    fm.write_with_template('annotated_fn_args.py', 'annotated_fn_args.py', lambda: {
+    fm.write_with_template('annotated_fn_args.py', 'annotated_fn_args.py.in', lambda: {
         'annotated_args': textwrap.indent('\n'.join(annotated_args), '  '),
     })
@@ -1,15 +1,16 @@
 import os
 from glob import glob
 import shutil
+from typing import Dict, Optional

 from .setup_helpers.env import IS_64BIT, IS_WINDOWS, check_negative_env_flag
-from .setup_helpers.cmake import USE_NINJA
+from .setup_helpers.cmake import USE_NINJA, CMake

-from setuptools import distutils
+from setuptools import distutils  # type: ignore[import]

-def _overlay_windows_vcvars(env):
+def _overlay_windows_vcvars(env: Dict[str, str]) -> Dict[str, str]:
     vc_arch = 'x64' if IS_64BIT else 'x86'
-    vc_env = distutils._msvccompiler._get_vc_env(vc_arch)
+    vc_env: Dict[str, str] = distutils._msvccompiler._get_vc_env(vc_arch)
     # Keys in `_get_vc_env` are always lowercase.
     # We turn them into uppercase before overlaying vcvars
     # because OS environ keys are always uppercase on Windows.
@@ -22,7 +23,7 @@ def _overlay_windows_vcvars(env):
     return vc_env


-def _create_build_env():
+def _create_build_env() -> Dict[str, str]:
     # XXX - our cmake file sometimes looks at the system environment
     # and not cmake flags!
     # you should NEVER add something to this list. It is bad practice to
@@ -44,7 +45,14 @@ def _create_build_env():
     return my_env


-def build_caffe2(version, cmake_python_library, build_python, rerun_cmake, cmake_only, cmake):
+def build_caffe2(
+    version: Optional[str],
+    cmake_python_library: Optional[str],
+    build_python: bool,
+    rerun_cmake: bool,
+    cmake_only: bool,
+    cmake: CMake,
+) -> None:
     my_env = _create_build_env()
     build_test = not check_negative_env_flag('BUILD_TEST')
     cmake.generate(version,
@@ -12,7 +12,9 @@ import asyncio
 import re
 import os
 import sys
-from clang_format_utils import get_and_check_clang_format, CLANG_FORMAT_PATH
+from typing import List, Set
+
+from .clang_format_utils import get_and_check_clang_format, CLANG_FORMAT_PATH

 # Allowlist of directories to check. All files that in that directory
 # (recursively) will be checked.
@@ -28,7 +30,7 @@ CLANG_FORMAT_ALLOWLIST = [
 CPP_FILE_REGEX = re.compile(".*\\.(h|cpp|cc|c|hpp)$")


-def get_allowlisted_files():
+def get_allowlisted_files() -> Set[str]:
    """
    Parse CLANG_FORMAT_ALLOWLIST and resolve all directories.
    Returns the set of allowlist cpp source files.
@@ -42,7 +44,11 @@ def get_allowlisted_files():
     return set(matches)


-async def run_clang_format_on_file(filename, semaphore, verbose=False):
+async def run_clang_format_on_file(
+    filename: str,
+    semaphore: asyncio.Semaphore,
+    verbose: bool = False,
+) -> None:
     """
     Run clang-format on the provided file.
     """
@@ -55,7 +61,11 @@ async def run_clang_format_on_file(filename, semaphore, verbose=False):
             print("Formatted {}".format(filename))


-async def file_clang_formatted_correctly(filename, semaphore, verbose=False):
+async def file_clang_formatted_correctly(
+    filename: str,
+    semaphore: asyncio.Semaphore,
+    verbose: bool = False,
+) -> bool:
     """
     Checks if a file is formatted correctly and returns True if so.
     """
@@ -80,7 +90,11 @@ async def file_clang_formatted_correctly(filename, semaphore, verbose=False):
     return ok


-async def run_clang_format(max_processes, diff=False, verbose=False):
+async def run_clang_format(
+    max_processes: int,
+    diff: bool = False,
+    verbose: bool = False,
+) -> bool:
     """
     Run clang-format to all files in CLANG_FORMAT_ALLOWLIST that match CPP_FILE_REGEX.
     """
@@ -114,7 +128,7 @@ async def run_clang_format(max_processes, diff=False, verbose=False):

     return ok

-def parse_args(args):
+def parse_args(args: List[str]) -> argparse.Namespace:
     """
     Parse and return command-line arguments.
     """
@@ -134,7 +148,7 @@ def parse_args(args):
     return parser.parse_args(args)


-def main(args):
+def main(args: List[str]) -> bool:
     # Parse arguments.
     options = parse_args(args)
     # Get clang-format and make sure it is the right binary and it is in the right place.
@@ -45,7 +45,7 @@ def compute_file_sha256(path: str) -> str:
     return hash.hexdigest()


-def report_download_progress(chunk_number, chunk_size, file_size):
+def report_download_progress(chunk_number: int, chunk_size: int, file_size: int) -> None:
     """
     Pretty printer for file download progress.
     """
@@ -55,7 +55,7 @@ def report_download_progress(chunk_number, chunk_size, file_size):
     sys.stdout.write("\r0% |{:<64}| {}%".format(bar, int(percent * 100)))


-def download_clang_format(path):
+def download_clang_format(path: str) -> bool:
     """
     Downloads a clang-format binary appropriate for the host platform and stores it at the given location.
     """
@@ -81,7 +81,7 @@ def download_clang_format(path):
     return True


-def get_and_check_clang_format(verbose=False):
+def get_and_check_clang_format(verbose: bool = False) -> bool:
     """
     Download a platform-appropriate clang-format binary if one doesn't already exist at the expected location and verify
     that it is the right binary by checking its SHA256 hash against the expected hash.
@@ -12,14 +12,18 @@ import argparse
 import yaml

 from collections import defaultdict
+from typing import Dict, List, Set


-def canonical_name(opname):
+def canonical_name(opname: str) -> str:
     # Skip the overload name part as it's not supported by code analyzer yet.
     return opname.split('.', 1)[0]


-def load_op_dep_graph(fname):
+DepGraph = Dict[str, Set[str]]
+
+
+def load_op_dep_graph(fname: str) -> DepGraph:
     with open(fname, 'r') as stream:
         result = defaultdict(set)
         for op in yaml.safe_load(stream):
@@ -27,10 +31,10 @@ def load_op_dep_graph(fname):
             for dep in op.get('depends', []):
                 dep_name = canonical_name(dep['name'])
                 result[op_name].add(dep_name)
-        return result
+        return dict(result)


-def load_root_ops(fname):
+def load_root_ops(fname: str) -> List[str]:
     result = []
     with open(fname, 'r') as stream:
         for op in yaml.safe_load(stream):
@@ -38,7 +42,11 @@ def load_root_ops(fname):
     return result


-def gen_transitive_closure(dep_graph, root_ops, train=False):
+def gen_transitive_closure(
+    dep_graph: DepGraph,
+    root_ops: List[str],
+    train: bool = False,
+) -> List[str]:
     result = set(root_ops)
     queue = root_ops[:]

@@ -64,7 +72,7 @@ def gen_transitive_closure(dep_graph, root_ops, train=False):

     return sorted(result)

-def gen_transitive_closure_str(dep_graph, root_ops):
+def gen_transitive_closure_str(dep_graph: DepGraph, root_ops: List[str]) -> str:
     return ' '.join(gen_transitive_closure(dep_graph, root_ops))
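The `gen_transitive_closure` signature above is the heart of this file: starting from `root_ops`, it walks `DepGraph` with a worklist and returns a sorted closure. Since the hunk truncates the body, here is a minimal self-contained sketch of that traversal (an illustration of the technique, not the file's exact code):

```
from typing import Dict, List, Set

DepGraph = Dict[str, Set[str]]

def transitive_closure(dep_graph: DepGraph, root_ops: List[str]) -> List[str]:
    result = set(root_ops)
    queue = root_ops[:]            # worklist seeded with the root ops
    while queue:
        name = queue.pop()
        for dep in dep_graph.get(name, set()):
            if dep not in result:  # visit each op at most once
                result.add(dep)
                queue.append(dep)
    return sorted(result)

# e.g. transitive_closure({"a": {"b"}, "b": {"c"}}, ["a"]) == ["a", "b", "c"]
```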
@@ -11,6 +11,7 @@ python -m tools.code_analyzer.op_deps_processor \

 import argparse
 import yaml
+from typing import Any, List

 from tools.codegen.code_template import CodeTemplate

@@ -46,12 +47,12 @@ DOT_OP_DEP = CodeTemplate("""\
 """)


-def load_op_deps(fname):
+def load_op_deps(fname: str) -> Any:
     with open(fname, 'r') as stream:
         return yaml.safe_load(stream)


-def process_base_ops(graph, base_ops):
+def process_base_ops(graph: Any, base_ops: List[str]) -> None:
     # remove base ops from all `depends` lists to compress the output graph
     for op in graph:
         op['depends'] = [
@@ -64,7 +65,13 @@ def process_base_ops(graph, base_ops):
         'depends': [{'name': name} for name in base_ops]})


-def convert(fname, graph, output_template, op_template, op_dep_template):
+def convert(
+    fname: str,
+    graph: Any,
+    output_template: CodeTemplate,
+    op_template: CodeTemplate,
+    op_dep_template: CodeTemplate,
+) -> None:
     ops = []
     for op in graph:
         op_name = op['name']
@@ -1,11 +1,11 @@
 from ..tool import clang_coverage
 from ..util.setting import CompilerType, Option, TestList, TestPlatform
 from ..util.utils import check_compiler_type
-from .init import detect_compiler_type
+from .init import detect_compiler_type  # type: ignore[attr-defined]
 from .run import clang_run, gcc_run


-def get_json_report(test_list: TestList, options: Option):
+def get_json_report(test_list: TestList, options: Option) -> None:
     cov_type = detect_compiler_type()
     check_compiler_type(cov_type)
     if cov_type == CompilerType.CLANG:
@@ -1,6 +1,6 @@
 import argparse
 import os
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, cast

 from ..util.setting import (
     JSON_FOLDER_BASE_DIR,
@@ -129,7 +129,7 @@ def empty_list_if_none(arg_interested_folder: Optional[List[str]]) -> List[str]:
     return arg_interested_folder


-def gcc_export_init():
+def gcc_export_init() -> None:
     remove_folder(JSON_FOLDER_BASE_DIR)
     create_folder(JSON_FOLDER_BASE_DIR)

@@ -161,7 +161,7 @@ def print_init_info() -> None:
     print_log("pytorch folder: ", get_pytorch_folder())
     print_log("cpp test binaries folder: ", get_oss_binary_folder(TestType.CPP))
     print_log("python test scripts folder: ", get_oss_binary_folder(TestType.PY))
-    print_log("compiler type: ", detect_compiler_type().value)
+    print_log("compiler type: ", cast(CompilerType, detect_compiler_type()).value)
     print_log(
         "llvm tool folder (only for clang, if you are using gcov please ignore it): ",
         get_llvm_tool_path(),
@@ -82,8 +82,7 @@ def get_gcda_files() -> List[str]:
         # TODO use glob
         # output = glob.glob(f"{folder_has_gcda}/**/*.gcda")
         output = subprocess.check_output(["find", folder_has_gcda, "-iname", "*.gcda"])
-        output = output.decode("utf-8").split("\n")
-        return output
+        return output.decode("utf-8").split("\n")
     else:
         return []
@@ -148,7 +148,7 @@ def export(test_list: TestList, platform_type: TestPlatform) -> None:
         binary_file = ""
         shared_library_list = []
         if platform_type == TestPlatform.FBCODE:
-            from caffe2.fb.code_coverage.tool.package.fbcode.utils import (
+            from caffe2.fb.code_coverage.tool.package.fbcode.utils import (  # type: ignore[import]
                 get_fbcode_binary_folder,
             )
@@ -10,11 +10,11 @@ class LlvmCoverageSegment(NamedTuple):
     is_gap_entry: Optional[int]

     @property
-    def has_coverage(self):
+    def has_coverage(self) -> bool:
         return self.segment_count > 0

     @property
-    def is_executable(self):
+    def is_executable(self) -> bool:
         return self.has_count > 0

 def get_coverage(
@@ -1,20 +1,22 @@
 import os
 import subprocess
-from typing import IO, Dict, List, Set
+from typing import IO, Dict, List, Set, Tuple

 from ..oss.utils import get_pytorch_folder
 from ..util.setting import SUMMARY_FOLDER_DIR, TestList, TestStatusType

+CoverageItem = Tuple[str, float, int, int]
+

-def key_by_percentage(x):
+def key_by_percentage(x: CoverageItem) -> float:
     return x[1]


-def key_by_name(x):
+def key_by_name(x: CoverageItem) -> str:
     return x[0]


-def is_intrested_file(file_path: str, interested_folders: List[str]):
+def is_intrested_file(file_path: str, interested_folders: List[str]) -> bool:
     if "cuda" in file_path:
         return False
     if "aten/gen_aten" in file_path or "aten/aten_" in file_path:
@@ -34,7 +36,7 @@ def is_this_type_of_tests(target_name: str, test_set_by_type: Set[str]) -> bool:


 def print_test_by_type(
-    tests: TestList, test_set_by_type: Set[str], type_name: str, summary_file: IO
+    tests: TestList, test_set_by_type: Set[str], type_name: str, summary_file: IO[str]
 ) -> None:

     print("Tests " + type_name + " to collect coverage:", file=summary_file)
@@ -49,7 +51,7 @@ def print_test_condition(
     tests_type: TestStatusType,
     interested_folders: List[str],
     coverage_only: List[str],
-    summary_file: IO,
+    summary_file: IO[str],
     summary_type: str,
 ) -> None:
     print_test_by_type(tests, tests_type["success"], "fully success", summary_file)
@@ -91,14 +93,8 @@ def line_oriented_report(
         "LINE SUMMARY",
     )
     for file_name in covered_lines:
-        if len(covered_lines[file_name]) == 0:
-            covered = {}
-        else:
-            covered = covered_lines[file_name]
-        if len(uncovered_lines[file_name]) == 0:
-            uncovered = {}
-        else:
-            uncovered = uncovered_lines[file_name]
+        covered = covered_lines[file_name]
+        uncovered = uncovered_lines[file_name]
         print(
             f"{file_name}\n  covered lines: {sorted(covered)}\n  unconvered lines:{sorted(uncovered)}",
             file=report_file,
@@ -106,7 +102,7 @@ def line_oriented_report(


 def print_file_summary(
-    covered_summary: int, total_summary: int, summary_file: IO
+    covered_summary: int, total_summary: int, summary_file: IO[str]
 ) -> float:
     # print summary first
     try:
@@ -124,10 +120,10 @@ def print_file_summary(

 def print_file_oriented_report(
     tests_type: TestStatusType,
-    coverage,
+    coverage: List[CoverageItem],
     covered_summary: int,
     total_summary: int,
-    summary_file: IO,
+    summary_file: IO[str],
     tests: TestList,
     interested_folders: List[str],
     coverage_only: List[str],
@@ -178,7 +174,7 @@ def file_oriented_report(
         except ZeroDivisionError:
             percentage = 0
         # store information in a list to be sorted
-        coverage.append([file_name, percentage, covered_count, total_count])
+        coverage.append((file_name, percentage, covered_count, total_count))
         # update summary
         covered_summary = covered_summary + covered_count
         total_summary = total_summary + total_count
@@ -202,7 +198,7 @@ def get_html_ignored_pattern() -> List[str]:
     return ["/usr/*", "*anaconda3/*", "*third_party/*"]


-def html_oriented_report():
+def html_oriented_report() -> None:
     # use lcov to generate the coverage report
     build_folder = os.path.join(get_pytorch_folder(), "build")
     coverage_info_file = os.path.join(SUMMARY_FOLDER_DIR, "coverage.info")
@@ -55,7 +55,7 @@ def transform_file_name(

 def is_intrested_file(
     file_path: str, interested_folders: List[str], platform: TestPlatform
-):
+) -> bool:
     ignored_patterns = ["cuda", "aten/gen_aten", "aten/aten_", "build/"]
     if any([pattern in file_path for pattern in ignored_patterns]):
         return False
@@ -12,12 +12,12 @@ def run_cpp_test(binary_file: str) -> None:
         print_error(f"Binary failed to run: {binary_file}")


-def get_tool_path_by_platform(platform: TestPlatform):
+def get_tool_path_by_platform(platform: TestPlatform) -> str:
     if platform == TestPlatform.FBCODE:
-        from caffe2.fb.code_coverage.tool.package.fbcode.utils import get_llvm_tool_path
+        from caffe2.fb.code_coverage.tool.package.fbcode.utils import get_llvm_tool_path  # type: ignore[import]

-        return get_llvm_tool_path()
+        return get_llvm_tool_path()  # type: ignore[no-any-return]
     else:
-        from ..oss.utils import get_llvm_tool_path
+        from ..oss.utils import get_llvm_tool_path  # type: ignore[no-redef]

-        return get_llvm_tool_path()
+        return get_llvm_tool_path()  # type: ignore[no-any-return]
@@ -2,7 +2,7 @@ import os
 import shutil
 import sys
 import time
-from typing import Any, Optional
+from typing import Any, NoReturn, Optional

 from .setting import (
     LOG_DIR,
@@ -71,7 +71,7 @@ def convert_to_relative_path(whole_path: str, base_path: str) -> str:
     return whole_path[len(base_path) + 1 :]


-def replace_extension(filename, ext):
+def replace_extension(filename: str, ext: str) -> str:
     return filename[: filename.rfind(".")] + ext

@@ -89,11 +89,11 @@ def get_raw_profiles_folder() -> str:

 def detect_compiler_type(platform: TestPlatform) -> CompilerType:
     if platform == TestPlatform.OSS:
-        from package.oss.utils import detect_compiler_type
+        from package.oss.utils import detect_compiler_type  # type: ignore[misc]

-        cov_type = detect_compiler_type()
+        cov_type = detect_compiler_type()  # type: ignore[call-arg]
     else:
-        from caffe2.fb.code_coverage.tool.package.fbcode.utils import (
+        from caffe2.fb.code_coverage.tool.package.fbcode.utils import (  # type: ignore[import]
             detect_compiler_type,
         )

@@ -138,7 +138,7 @@ def check_test_type(test_type: str, target: str) -> None:
         )


-def raise_no_test_found_exception(cpp_binary_folder: str, python_binary_folder: str):
+def raise_no_test_found_exception(cpp_binary_folder: str, python_binary_folder: str) -> NoReturn:
     raise RuntimeError(
         f"No cpp and python tests found in folder **{cpp_binary_folder} and **{python_binary_folder}**"
     )
@@ -1,4 +1,4 @@
-import setuptools
+import setuptools  # type: ignore[import]

 with open("README.md", "r", encoding="utf-8") as fh:
     long_description = fh.read()
@@ -8,9 +8,10 @@ compiled code has been executed. This means that even if the code chunk is merel
 marked as covered.
 '''

-from coverage import CoveragePlugin, CoverageData
+from coverage import CoveragePlugin, CoverageData  # type: ignore[import]
 from inspect import ismodule, isclass, ismethod, isfunction, iscode, getsourcefile, getsourcelines
 from time import time
+from typing import Any

 # All coverage stats resulting from this plug-in will be in a separate .coverage file that should be merged later with
 # `coverage combine`. The convention seems to be .coverage.dotted.suffix based on the following link:
@@ -18,17 +19,17 @@ from time import time
 cov_data = CoverageData(basename=f'.coverage.jit.{time()}')


-def is_not_builtin_class(obj):
+def is_not_builtin_class(obj: Any) -> bool:
     return isclass(obj) and not type(obj).__module__ == 'builtins'


-class JitPlugin(CoveragePlugin):
+class JitPlugin(CoveragePlugin):  # type: ignore[misc, no-any-unimported]
     '''
     dynamic_context is an overridden function that gives us access to every frame run during the coverage process. We
     look for when the function being run is `should_drop`, as all functions that get passed into `should_drop` will be
     compiled and thus should be marked as covered.
     '''
-    def dynamic_context(self, frame):
+    def dynamic_context(self, frame: Any) -> None:
         if frame.f_code.co_name == 'should_drop':
             obj = frame.f_locals['fn']
             # The many conditions in the if statement below are based on the accepted arguments to getsourcefile. Based
@@ -54,5 +55,5 @@ class JitPlugin(CoveragePlugin):
             cov_data.add_lines(line_data)
         super().dynamic_context(frame)

-def coverage_init(reg, options):
+def coverage_init(reg: Any, options: Any) -> None:
     reg.add_dynamic_context(JitPlugin())
@@ -18,14 +18,18 @@ RESOURCES = [
 ]


-def report_download_progress(chunk_number, chunk_size, file_size):
+def report_download_progress(
+    chunk_number: int,
+    chunk_size: int,
+    file_size: int,
+) -> None:
     if file_size != -1:
         percent = min(1, (chunk_number * chunk_size) / file_size)
         bar = '#' * int(64 * percent)
         sys.stdout.write('\r0% |{:<64}| {}%'.format(bar, int(percent * 100)))


-def download(destination_path, resource, quiet):
+def download(destination_path: str, resource: str, quiet: bool) -> None:
     if os.path.exists(destination_path):
         if not quiet:
             print('{} already exists, skipping ...'.format(destination_path))
@@ -48,7 +52,7 @@ def download(destination_path, resource, quiet):
             raise RuntimeError('Error downloading resource!')


-def unzip(zipped_path, quiet):
+def unzip(zipped_path: str, quiet: bool) -> None:
     unzipped_path = os.path.splitext(zipped_path)[0]
     if os.path.exists(unzipped_path):
         if not quiet:
@@ -61,7 +65,7 @@ def unzip(zipped_path, quiet):
         print('Unzipped {} ...'.format(zipped_path))


-def main():
+def main() -> None:
     parser = argparse.ArgumentParser(
         description='Download the MNIST dataset from the internet')
     parser.add_argument(
@@ -17,7 +17,7 @@ def get_test_case_times() -> Dict[str, float]:
     # an entry will be like ("test_doc_examples (__main__.TestTypeHints)" -> [values]))
     test_names_to_times: DefaultDict[str, List[float]] = defaultdict(list)
     for report in reports:
-        if report.get('format_version', 1) != 2:
+        if report.get('format_version', 1) != 2:  # type: ignore[misc]
             raise RuntimeError("S3 format currently handled is version 2 only")
         v2report = cast(Version2Report, report)
         for test_file in v2report['files'].values():
@@ -46,7 +46,7 @@ def export_slow_tests(filename: str) -> None:
         file.write('\n')


-def parse_args():
+def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(
         description='Export a JSON of slow test cases in PyTorch unit test suite')
     parser.add_argument(
@@ -61,7 +61,7 @@ def parse_args():
     return parser.parse_args()


-def main():
+def main() -> None:
     options = parse_args()
     export_slow_tests(options.filename)
@@ -14,7 +14,10 @@ import shutil
 import subprocess
 import sys
 import time
+from typing import (Awaitable, DefaultDict, Dict, List, Match, Optional, Set,
+                    cast)
+from typing_extensions import TypedDict

 help_msg = '''fast_nvcc [OPTION]... -- [NVCC_ARG]...

@@ -78,14 +81,14 @@ url_vars = f'{url_base}#keeping-intermediate-phase-files'
 re_tmp = r'(?<![\w\-/])(?:/tmp/)?(tmp[^ \"\'\\]+)'


-def fast_nvcc_warn(warning):
+def fast_nvcc_warn(warning: str) -> None:
     """
     Warn the user about something regarding fast_nvcc.
     """
     print(f'warning (fast_nvcc): {warning}', file=sys.stderr)


-def warn_if_windows():
+def warn_if_windows() -> None:
     """
     Warn the user that using fast_nvcc on Windows might not work.
     """
@@ -97,7 +100,7 @@ def warn_if_windows():
         fast_nvcc_warn(url_vars)


-def warn_if_tmpdir_flag(args):
+def warn_if_tmpdir_flag(args: List[str]) -> None:
     """
     Warn the user that using fast_nvcc with some flags might not work.
     """
@@ -121,11 +124,17 @@ def warn_if_tmpdir_flag(args):
             fast_nvcc_warn(f'{url_base}#{frag}')


-def nvcc_dryrun_data(binary, args):
+class DryunData(TypedDict):
+    env: Dict[str, str]
+    commands: List[str]
+    exit_code: int
+
+
+def nvcc_dryrun_data(binary: str, args: List[str]) -> DryunData:
     """
     Return parsed environment variables and commands from nvcc --dryrun.
     """
-    result = subprocess.run(
+    result = subprocess.run(  # type: ignore[call-overload]
         [binary, '--dryrun'] + args,
         capture_output=True,
         encoding='ascii',  # this is just a guess
@@ -148,7 +157,7 @@ def nvcc_dryrun_data(binary, args):
     return {'env': env, 'commands': commands, 'exit_code': result.returncode}


-def warn_if_tmpdir_set(env):
+def warn_if_tmpdir_set(env: Dict[str, str]) -> None:
     """
     Warn the user that setting TMPDIR with fast_nvcc might not work.
     """
@@ -157,7 +166,7 @@ def warn_if_tmpdir_set(env):
         fast_nvcc_warn(url_vars)


-def contains_non_executable(commands):
+def contains_non_executable(commands: List[str]) -> bool:
     for command in commands:
         # This is to deal with special command dry-run result from NVCC such as:
         # ```
@@ -170,7 +179,7 @@ def contains_non_executable(commands):
     return False


-def module_id_contents(command):
+def module_id_contents(command: List[str]) -> str:
     """
     Guess the contents of the .module_id file contained within command.
     """
@@ -187,7 +196,7 @@ def module_id_contents(command):
     return f'_{len(middle)}_{middle}_{suffix}'


-def unique_module_id_files(commands):
+def unique_module_id_files(commands: List[str]) -> List[str]:
     """
     Give each command its own .module_id filename instead of sharing.
     """
@@ -196,7 +205,7 @@ def unique_module_id_files(commands):
     for i, line in enumerate(commands):
         arr = []

-        def uniqueify(s):
+        def uniqueify(s: Match[str]) -> str:
             filename = re.sub(r'\-(\d+)', r'-\1-' + str(i), s.group(0))
             arr.append(filename)
             return filename
@@ -212,14 +221,19 @@ def unique_module_id_files(commands):
     return uniqueified


-def make_rm_force(commands):
+def make_rm_force(commands: List[str]) -> List[str]:
     """
     Add --force to all rm commands.
     """
     return [f'{c} --force' if c.startswith('rm ') else c for c in commands]


-def print_verbose_output(*, env, commands, filename):
+def print_verbose_output(
+    *,
+    env: Dict[str, str],
+    commands: List[List[str]],
+    filename: str,
+) -> None:
     """
     Human-readably write nvcc --dryrun data to stderr.
     """
@@ -234,21 +248,24 @@ def print_verbose_output(*, env, commands, filename):
             print(f'#{" "*len(prefix)}{part}', file=f)


-def straight_line_dependencies(commands):
+Graph = List[Set[int]]
+
+
+def straight_line_dependencies(commands: List[str]) -> Graph:
     """
     Return a straight-line dependency graph.
     """
     return [({i - 1} if i > 0 else set()) for i in range(len(commands))]


-def files_mentioned(command):
+def files_mentioned(command: str) -> List[str]:
     """
     Return fully-qualified names of all tmp files referenced by command.
     """
     return [f'/tmp/{match.group(1)}' for match in re.finditer(re_tmp, command)]


-def nvcc_data_dependencies(commands):
+def nvcc_data_dependencies(commands: List[str]) -> Graph:
     """
     Return a list of the set of dependencies for each command.
     """
@@ -261,8 +278,8 @@ def nvcc_data_dependencies(commands):
     # data dependency is sort of flipped, because the steps that use the
     # files generated by cicc need to wait for the fatbinary step to
     # finish first
-    tmp_files = {}
-    fatbins = collections.defaultdict(set)
+    tmp_files: Dict[str, int] = {}
+    fatbins: DefaultDict[int, Set[str]] = collections.defaultdict(set)
     graph = []
     for i, line in enumerate(commands):
         deps = set()
@@ -284,13 +301,13 @@ def nvcc_data_dependencies(commands):
     return graph


-def is_weakly_connected(graph):
+def is_weakly_connected(graph: Graph) -> bool:
     """
     Return true iff graph is weakly connected.
     """
     if not graph:
         return True
-    neighbors = [set() for _ in graph]
+    neighbors: List[Set[int]] = [set() for _ in graph]
     for node, predecessors in enumerate(graph):
         for pred in predecessors:
             neighbors[pred].add(node)
@@ -307,7 +324,7 @@ def is_weakly_connected(graph):
     return len(found) == len(graph)


-def warn_if_not_weakly_connected(graph):
+def warn_if_not_weakly_connected(graph: Graph) -> None:
     """
     Warn the user if the execution graph is not weakly connected.
     """
@@ -315,11 +332,16 @@ def warn_if_not_weakly_connected(graph):
         fast_nvcc_warn('execution graph is not (weakly) connected')


-def print_dot_graph(*, commands, graph, filename):
+def print_dot_graph(
+    *,
+    commands: List[List[str]],
+    graph: Graph,
+    filename: str,
+) -> None:
     """
     Print a DOT file displaying short versions of the commands in graph.
     """
-    def name(k):
+    def name(k: int) -> str:
         return f'"{k} {os.path.basename(commands[k][0])}"'
     with open(filename, 'w') as f:
         print('digraph {', file=f)
@@ -332,7 +354,24 @@ def print_dot_graph(*, commands, graph, filename):
         print('}', file=f)


-async def run_command(command, *, env, deps, gather_data, i, save):
+class Result(TypedDict, total=False):
+    exit_code: int
+    stdout: bytes
+    stderr: bytes
+    time: float
+    files: Dict[str, int]
+
+
+async def run_command(
+    command: str,
+    *,
+    env: Dict[str, str],
+    deps: Set[Awaitable[Result]],
+    gather_data: bool,
+    i: int,
+    save: Optional[str],
+) -> Result:
     """
     Run the command with the given env after waiting for deps.
     """
@@ -350,8 +389,8 @@ async def run_command(command, *, env, deps, gather_data, i, save):
         stderr=asyncio.subprocess.PIPE,
     )
     stdout, stderr = await proc.communicate()
-    code = proc.returncode
-    results = {'exit_code': code, 'stdout': stdout, 'stderr': stderr}
+    code = cast(int, proc.returncode)
+    results: Result = {'exit_code': code, 'stdout': stdout, 'stderr': stderr}
    if gather_data:
         t2 = time.monotonic()
         results['time'] = t2 - t1
@@ -371,14 +410,21 @@ async def run_command(command, *, env, deps, gather_data, i, save):
     return results


-async def run_graph(*, env, commands, graph, gather_data=False, save=None):
+async def run_graph(
+    *,
+    env: Dict[str, str],
+    commands: List[str],
+    graph: Graph,
+    gather_data: bool = False,
+    save: Optional[str] = None,
+) -> List[Result]:
     """
     Return outputs/errors (and optionally time/file info) from commands.
     """
-    tasks = []
+    tasks: List[Awaitable[Result]] = []
     for i, (command, indices) in enumerate(zip(commands, graph)):
         deps = {tasks[j] for j in indices}
-        tasks.append(asyncio.create_task(run_command(
+        tasks.append(asyncio.create_task(run_command(  # type: ignore[attr-defined]
             command,
             env=env,
             deps=deps,
@@ -389,7 +435,7 @@ async def run_graph(*, env, commands, graph, gather_data=False, save=None):
     return [await task for task in tasks]


-def print_command_outputs(command_results):
+def print_command_outputs(command_results: List[Result]) -> None:
     """
     Print captured stdout and stderr from commands.
     """
@@ -398,11 +444,16 @@ def print_command_outputs(command_results):
         sys.stderr.write(result.get('stderr', b'').decode('ascii'))


-def write_log_csv(command_parts, command_results, *, filename):
+def write_log_csv(
+    command_parts: List[List[str]],
+    command_results: List[Result],
+    *,
+    filename: str,
+) -> None:
     """
     Write a CSV file of the times and /tmp file sizes from each command.
     """
-    tmp_files = []
+    tmp_files: List[str] = []
     for result in command_results:
         tmp_files.extend(result.get('files', {}).keys())
     with open(filename, 'w', newline='') as csvfile:
@@ -415,7 +466,7 @@ def write_log_csv(command_parts, command_results, *, filename):
             writer.writerow({**row, **result.get('files', {})})


-def exit_code(results):
+def exit_code(results: List[Result]) -> int:
     """
     Aggregate individual exit codes into a single code.
     """
@@ -426,11 +477,18 @@ def exit_code(results):
     return 0


-def wrap_nvcc(args, config=default_config):
+def wrap_nvcc(
+    args: List[str],
+    config: argparse.Namespace = default_config,
+) -> int:
     return subprocess.call([config.nvcc] + args)


-def fast_nvcc(args, *, config=default_config):
+def fast_nvcc(
+    args: List[str],
+    *,
+    config: argparse.Namespace = default_config,
+) -> int:
     """
     Emulate the result of calling the given nvcc binary with args.

@@ -465,7 +523,7 @@ def fast_nvcc(args, *, config=default_config):
     )
     if config.sequential:
         graph = straight_line_dependencies(commands)
-    results = asyncio.run(run_graph(
+    results = asyncio.run(run_graph(  # type: ignore[attr-defined]
         env=env,
         commands=commands,
         graph=graph,
@@ -478,7 +536,7 @@ def fast_nvcc(args, *, config=default_config):
     return exit_code([dryrun_data] + results)


-def our_arg(arg):
+def our_arg(arg: str) -> bool:
     return arg != '--'
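One pattern in the hunk above deserves a note: `Result(TypedDict, total=False)` makes every declared key optional, which matches a dict whose `time` and `files` entries are only populated when `gather_data` is set. A small self-contained sketch of the semantics (illustrative names, not code from the diff):

```
from typing import Dict
from typing_extensions import TypedDict  # also in stdlib typing on Python 3.8+

class Stats(TypedDict, total=False):
    exit_code: int
    time: float
    files: Dict[str, int]

s: Stats = {'exit_code': 0}  # OK: keys may be omitted under total=False
s['time'] = 1.5              # OK: declared keys type-check on assignment
# s['rc'] = 1                # mypy error: 'rc' is not a declared key
```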
@@ -2,7 +2,7 @@ import os

 import sys

-from flake8.main import git
+from flake8.main import git  # type: ignore[import]

 if __name__ == '__main__':
     sys.exit(
@@ -1,5 +1,6 @@
-import gdb
+import gdb  # type: ignore[import]
 import textwrap
+from typing import Any

 class DisableBreakpoints:
     """
@@ -8,18 +9,18 @@ class DisableBreakpoints:
     commands
     """

-    def __enter__(self):
+    def __enter__(self) -> None:
         self.disabled_breakpoints = []
         for b in gdb.breakpoints():
             if b.enabled:
                 b.enabled = False
                 self.disabled_breakpoints.append(b)

-    def __exit__(self, etype, evalue, tb):
+    def __exit__(self, etype: Any, evalue: Any, tb: Any) -> None:
         for b in self.disabled_breakpoints:
             b.enabled = True

-class TensorRepr(gdb.Command):
+class TensorRepr(gdb.Command):  # type: ignore[misc, no-any-unimported]
     """
     Print a human readable representation of the given at::Tensor.
     Usage: torch-tensor-repr EXP
@@ -31,11 +32,11 @@ class TensorRepr(gdb.Command):
     """
     __doc__ = textwrap.dedent(__doc__).strip()

-    def __init__(self):
+    def __init__(self) -> None:
         gdb.Command.__init__(self, 'torch-tensor-repr',
                              gdb.COMMAND_USER, gdb.COMPLETE_EXPRESSION)

-    def invoke(self, args, from_tty):
+    def invoke(self, args: str, from_tty: bool) -> None:
         args = gdb.string_to_argv(args)
         if len(args) != 1:
             print('Usage: torch-tensor-repr EXP')
@@ -2,7 +2,7 @@ import argparse
 import os
 import subprocess
 from pathlib import Path
-from setuptools import distutils
+from setuptools import distutils  # type: ignore[import]
 from typing import Optional, Union

 def get_sha(pytorch_root: Union[str, Path]) -> str:
@@ -103,7 +103,7 @@ def write_selected_mobile_ops_with_all_dtypes(
         header_contents = "".join(body_parts)
         out_file.write(header_contents.encode("utf-8"))

-def main():
+def main() -> None:
     parser = argparse.ArgumentParser(
         description="Generate selected_mobile_ops.h for selective build."
     )
119
tools/nightly.py
119
tools/nightly.py
@ -40,10 +40,10 @@ import contextlib
|
||||
import subprocess
|
||||
from ast import literal_eval
|
||||
from argparse import ArgumentParser
|
||||
from typing import Dict, Optional, Iterator
|
||||
from typing import (Any, Callable, Dict, Generator, Iterable, Iterator, List,
|
||||
Optional, Sequence, Set, Tuple, TypeVar, cast)
|
||||
|
||||
|
||||
LOGGER = None
|
||||
LOGGER: Optional[logging.Logger] = None
|
||||
URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2"
|
||||
DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss"
|
||||
SHA1_RE = re.compile("([0-9a-fA-F]{40})")
|
||||
@ -133,7 +133,7 @@ def logging_rotate() -> None:
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def logging_manager(*, debug: bool = False) -> Iterator[None]:
|
||||
def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, None]:
|
||||
"""Setup logging. If a failure starts here we won't
|
||||
be able to save the user in a reasonable way.
|
||||
|
||||
@ -179,7 +179,7 @@ def logging_manager(*, debug: bool = False) -> Iterator[None]:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def check_in_repo():
|
||||
def check_in_repo() -> Optional[str]:
|
||||
"""Ensures that we are in the PyTorch repo."""
|
||||
if not os.path.isfile("setup.py"):
|
||||
return "Not in root-level PyTorch repo, no setup.py found"
|
||||
@ -187,12 +187,13 @@ def check_in_repo():
|
||||
s = f.read()
|
||||
if "PyTorch" not in s:
|
||||
return "Not in PyTorch repo, 'PyTorch' not found in setup.py"
|
||||
return None
|
||||
|
||||
|
||||
def check_branch(subcommand, branch):
|
||||
def check_branch(subcommand: str, branch: Optional[str]) -> Optional[str]:
|
||||
"""Checks that the branch name can be checked out."""
|
||||
if subcommand != "checkout":
|
||||
return
|
||||
return None
|
||||
# first make sure actual branch name was given
|
||||
if branch is None:
|
||||
return "Branch name to checkout must be supplied with '-b' option"
|
||||
@ -203,36 +204,44 @@ def check_branch(subcommand, branch):
|
||||
return "Need to have clean working tree to checkout!\n\n" + p.stdout
|
||||
# next check that the branch name doesn't already exist
|
||||
cmd = ["git", "show-ref", "--verify", "--quiet", "refs/heads/" + branch]
|
||||
p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False)
|
||||
p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False) # type: ignore[assignment]
|
||||
if not p.returncode:
|
||||
return f"Branch {branch!r} already exists"
|
||||
return None
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def timer(logger, prefix):
|
||||
def timer(logger: logging.Logger, prefix: str) -> Iterator[None]:
|
||||
"""Timed context manager"""
|
||||
start_time = time.time()
|
||||
yield
|
||||
logger.info(f"{prefix} took {time.time() - start_time:.3f} [s]")
|
||||
|
||||
|
||||
def timed(prefix):
|
||||
F = TypeVar('F', bound=Callable[..., Any])
|
||||
|
||||
|
||||
def timed(prefix: str) -> Callable[[F], F]:
|
||||
"""Decorator for timing functions"""
|
||||
|
||||
def dec(f):
|
||||
def dec(f: F) -> F:
|
||||
@functools.wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
global LOGGER
|
||||
LOGGER.info(prefix)
|
||||
with timer(LOGGER, prefix):
|
||||
logger = cast(logging.Logger, LOGGER)
|
||||
logger.info(prefix)
|
||||
with timer(logger, prefix):
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return cast(F, wrapper)
|
||||
|
||||
return dec
|
||||
|
||||
|
||||
def _make_channel_args(channels=("pytorch-nightly",), override_channels=False):
|
||||
def _make_channel_args(
|
||||
channels: Iterable[str] = ("pytorch-nightly",),
|
||||
override_channels: bool = False,
|
||||
) -> List[str]:
|
||||
args = []
|
||||
for channel in channels:
|
||||
args.append("--channel")
|
||||
@ -244,8 +253,11 @@ def _make_channel_args(channels=("pytorch-nightly",), override_channels=False):
|
||||
|
||||
@timed("Solving conda environment")
|
||||
def conda_solve(
|
||||
name=None, prefix=None, channels=("pytorch-nightly",), override_channels=False
|
||||
):
|
||||
name: Optional[str] = None,
|
||||
prefix: Optional[str] = None,
|
||||
channels: Iterable[str] = ("pytorch-nightly",),
|
||||
override_channels: bool = False,
|
||||
) -> Tuple[List[str], str, str, bool, List[str]]:
|
||||
"""Performs the conda solve and splits the deps from the package."""
|
||||
# compute what environment to use
|
||||
if prefix is not None:
|
||||
@ -299,7 +311,7 @@ def conda_solve(
|
||||
|
||||
|
||||
@timed("Installing dependencies")
|
||||
def deps_install(deps, existing_env, env_opts):
|
||||
def deps_install(deps: List[str], existing_env: bool, env_opts: List[str]) -> None:
|
||||
"""Install dependencies to deps environment"""
|
||||
if not existing_env:
|
||||
# first remove previous pytorch-deps env
|
||||
@ -312,7 +324,7 @@ def deps_install(deps, existing_env, env_opts):
|
||||
|
||||
|
||||
@timed("Installing pytorch nightly binaries")
|
||||
def pytorch_install(url):
|
||||
def pytorch_install(url: str) -> tempfile.TemporaryDirectory[str]:
|
||||
""""Install pytorch into a temporary directory"""
|
||||
pytdir = tempfile.TemporaryDirectory()
|
||||
cmd = ["conda", "create", "--yes", "--no-deps", "--prefix", pytdir.name, url]
|
||||
@ -320,7 +332,7 @@ def pytorch_install(url):
|
||||
return pytdir
|
||||
|
||||
|
||||
def _site_packages(dirname, platform):
|
||||
def _site_packages(dirname: str, platform: str) -> str:
|
||||
if platform.startswith("win"):
|
||||
template = os.path.join(dirname, "Lib", "site-packages")
|
||||
else:
|
||||
@ -329,7 +341,7 @@ def _site_packages(dirname, platform):
|
||||
return spdir
|
||||
|
||||
|
||||
def _ensure_commit(git_sha1):
|
||||
def _ensure_commit(git_sha1: str) -> None:
|
||||
"""Make sure that we actually have the commit locally"""
|
||||
cmd = ["git", "cat-file", "-e", git_sha1 + "^{commit}"]
|
||||
p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False)
|
||||
@ -341,7 +353,7 @@ def _ensure_commit(git_sha1):
|
||||
p = subprocess.run(cmd, check=True)
|
||||
|
||||
|
||||
def _nightly_version(spdir):
|
||||
def _nightly_version(spdir: str) -> str:
|
||||
# first get the git version from the installed module
|
||||
version_fname = os.path.join(spdir, "torch", "version.py")
|
||||
    with open(version_fname) as f:

@ -371,7 +383,7 @@ def _nightly_version(spdir):


@timed("Checking out nightly PyTorch")
def checkout_nightly_version(branch, spdir):
def checkout_nightly_version(branch: str, spdir: str) -> None:
    """Gets the nightly version and then checks it out."""
    nightly_version = _nightly_version(spdir)
    cmd = ["git", "checkout", "-b", branch, nightly_version]
@ -379,40 +391,40 @@ def checkout_nightly_version(branch, spdir):


@timed("Pulling nightly PyTorch")
def pull_nightly_version(spdir):
def pull_nightly_version(spdir: str) -> None:
    """Fetches the nightly version and then merges it."""
    nightly_version = _nightly_version(spdir)
    cmd = ["git", "merge", nightly_version]
    p = subprocess.run(cmd, check=True)


def _get_listing_linux(source_dir):
def _get_listing_linux(source_dir: str) -> List[str]:
    listing = glob.glob(os.path.join(source_dir, "*.so"))
    listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.so")))
    return listing


def _get_listing_osx(source_dir):
def _get_listing_osx(source_dir: str) -> List[str]:
    # oddly, these are .so files even on Mac
    listing = glob.glob(os.path.join(source_dir, "*.so"))
    listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dylib")))
    return listing


def _get_listing_win(source_dir):
def _get_listing_win(source_dir: str) -> List[str]:
    listing = glob.glob(os.path.join(source_dir, "*.pyd"))
    listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.lib")))
    listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dll")))
    return listing


def _glob_pyis(d):
def _glob_pyis(d: str) -> Set[str]:
    search = os.path.join(d, "**", "*.pyi")
    pyis = {os.path.relpath(p, d) for p in glob.iglob(search)}
    return pyis


def _find_missing_pyi(source_dir, target_dir):
def _find_missing_pyi(source_dir: str, target_dir: str) -> List[str]:
    source_pyis = _glob_pyis(source_dir)
    target_pyis = _glob_pyis(target_dir)
    missing_pyis = [os.path.join(source_dir, p) for p in (source_pyis - target_pyis)]
@ -420,7 +432,7 @@ def _find_missing_pyi(source_dir, target_dir):
    return missing_pyis


def _get_listing(source_dir, target_dir, platform):
def _get_listing(source_dir: str, target_dir: str, platform: str) -> List[str]:
    if platform.startswith("linux"):
        listing = _get_listing_linux(source_dir)
    elif platform.startswith("osx"):
@ -437,7 +449,7 @@ def _get_listing(source_dir, target_dir, platform):
    return listing


def _remove_existing(trg, is_dir):
def _remove_existing(trg: str, is_dir: bool) -> None:
    if os.path.exists(trg):
        if is_dir:
            shutil.rmtree(trg)
@ -445,7 +457,13 @@ def _remove_existing(trg, is_dir):
            os.remove(trg)


def _move_single(src, source_dir, target_dir, mover, verb):
def _move_single(
    src: str,
    source_dir: str,
    target_dir: str,
    mover: Callable[[str, str], None],
    verb: str,
) -> None:
    is_dir = os.path.isdir(src)
    relpath = os.path.relpath(src, source_dir)
    trg = os.path.join(target_dir, relpath)
@ -469,18 +487,18 @@ def _move_single(src, source_dir, target_dir, mover, verb):
        mover(src, trg)


def _copy_files(listing, source_dir, target_dir):
def _copy_files(listing: List[str], source_dir: str, target_dir: str) -> None:
    for src in listing:
        _move_single(src, source_dir, target_dir, shutil.copy2, "Copying")


def _link_files(listing, source_dir, target_dir):
def _link_files(listing: List[str], source_dir: str, target_dir: str) -> None:
    for src in listing:
        _move_single(src, source_dir, target_dir, os.link, "Linking")
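The `Callable[[str, str], None]` annotation on `mover` is the one non-obvious signature in this hunk: it lets `_copy_files` and `_link_files` pass a file-moving function as a first-class value while mypy checks each call site. A minimal standalone sketch of the pattern (the helper names here are illustrative, not from the PR):

```python
import shutil
from typing import Callable

def copy_mover(src: str, dst: str) -> None:
    # wrap shutil.copy2 so the declared return type is exactly None
    shutil.copy2(src, dst)

def apply_mover(mover: Callable[[str, str], None], src: str, dst: str) -> None:
    mover(src, dst)  # mypy verifies the call matches the declared signature

# apply_mover(copy_mover, "a.txt", "b.txt")
```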
@timed("Moving nightly files into repo")
|
||||
def move_nightly_files(spdir, platform):
|
||||
def move_nightly_files(spdir: str, platform: str) -> None:
|
||||
"""Moves PyTorch files from temporary installed location to repo."""
|
||||
# get file listing
|
||||
source_dir = os.path.join(spdir, "torch")
|
||||
@ -496,7 +514,7 @@ def move_nightly_files(spdir, platform):
|
||||
_copy_files(listing, source_dir, target_dir)
|
||||
|
||||
|
||||
def _available_envs():
|
||||
def _available_envs() -> Dict[str, str]:
|
||||
cmd = ["conda", "env", "list"]
|
||||
p = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
|
||||
lines = p.stdout.splitlines()
|
||||
@ -513,7 +531,7 @@ def _available_envs():
|
||||
|
||||
|
||||
@timed("Writing pytorch-nightly.pth")
|
||||
def write_pth(env_opts, platform):
|
||||
def write_pth(env_opts: List[str], platform: str) -> None:
|
||||
"""Writes Python path file for this dir."""
|
||||
env_type, env_dir = env_opts
|
||||
if env_type == "--name":
|
||||
@ -533,17 +551,16 @@ def write_pth(env_opts, platform):
|
||||
|
||||
|
||||
def install(
|
||||
subcommand="checkout",
|
||||
branch=None,
|
||||
name=None,
|
||||
prefix=None,
|
||||
channels=("pytorch-nightly",),
|
||||
override_channels=False,
|
||||
logger=None,
|
||||
):
|
||||
*,
|
||||
logger: logging.Logger,
|
||||
subcommand: str = "checkout",
|
||||
branch: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
prefix: Optional[str] = None,
|
||||
channels: Iterable[str] = ("pytorch-nightly",),
|
||||
override_channels: bool = False,
|
||||
) -> None:
|
||||
"""Development install of PyTorch"""
|
||||
global LOGGER
|
||||
logger = logger or LOGGER
|
||||
deps, pytorch, platform, existing_env, env_opts = conda_solve(
|
||||
name=name, prefix=prefix, channels=channels, override_channels=override_channels
|
||||
)
|
||||
@ -552,7 +569,7 @@ def install(
|
||||
pytdir = pytorch_install(pytorch)
|
||||
spdir = _site_packages(pytdir.name, platform)
|
||||
if subcommand == "checkout":
|
||||
checkout_nightly_version(branch, spdir)
|
||||
checkout_nightly_version(cast(str, branch), spdir)
|
||||
elif subcommand == "pull":
|
||||
pull_nightly_version(spdir)
|
||||
else:
|
||||
@ -566,7 +583,7 @@ def install(
|
||||
)
|
||||
|
||||
|
||||
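Two things happen in the `install` hunk: `logger` becomes a required keyword-only argument (so its type never needs to be Optional and the `global LOGGER` fallback can go away), and the `checkout` branch uses `typing.cast` because the CLI layer guarantees `branch` is set for that subcommand, an invariant mypy cannot see. A minimal sketch of the `cast` half under that assumption (names here are illustrative):

```python
from typing import Optional, cast

def run_subcommand(subcommand: str, branch: Optional[str]) -> str:
    if subcommand == "checkout":
        # cast() has no runtime effect; it only tells the checker that
        # branch is a str here, which the CLI layer is assumed to enforce.
        return "checkout -b " + cast(str, branch)
    return "pull"
```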
def make_parser():
def make_parser() -> ArgumentParser:
    p = ArgumentParser("nightly")
    # subcommands
    subcmd = p.add_subparsers(dest="subcmd", help="subcommand to execute")
@ -627,7 +644,7 @@ def make_parser():
    return p


def main(args=None):
def main(args: Optional[Sequence[str]] = None) -> None:
    """Main entry point"""
    global LOGGER
    p = make_parser()
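The `Optional[Sequence[str]]` signature on `main` is the standard typed entry-point convention for argparse scripts: `None` defers to `sys.argv[1:]`, while tests can pass an explicit list. A minimal sketch, assuming nothing beyond the standard library:

```python
import argparse
from typing import Optional, Sequence

def main(args: Optional[Sequence[str]] = None) -> None:
    p = argparse.ArgumentParser("nightly")
    p.add_argument("--branch")
    ns = p.parse_args(args)  # args=None falls back to sys.argv[1:]
    print(ns.branch)

# main(["--branch", "my-feature"])  # exercised directly in tests
```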
@ -16,8 +16,8 @@ try:
except ImportError:
    print("rich not found, for color output use 'pip install rich'")

def parse_junit_reports(path_to_reports: str) -> List[TestCase]:
    def parse_file(path: str) -> List[TestCase]:
def parse_junit_reports(path_to_reports: str) -> List[TestCase]:  # type: ignore[no-any-unimported]
    def parse_file(path: str) -> List[TestCase]:  # type: ignore[no-any-unimported]
        try:
            return convert_junit_to_testcases(JUnitXml.fromfile(path))
        except Exception as err:
@ -37,7 +37,7 @@ def parse_junit_reports(path_to_reports: str) -> List[TestCase]:
    return ret_xml


def convert_junit_to_testcases(xml: Union[JUnitXml, TestSuite]) -> List[TestCase]:
def convert_junit_to_testcases(xml: Union[JUnitXml, TestSuite]) -> List[TestCase]:  # type: ignore[no-any-unimported]
    testcases = []
    for item in xml:
        if isinstance(item, TestSuite):
@ -46,7 +46,7 @@ def convert_junit_to_testcases(xml: Union[JUnitXml, TestSuite]) -> List[TestCase
            testcases.append(item)
    return testcases

def render_tests(testcases: List[TestCase]) -> None:
def render_tests(testcases: List[TestCase]) -> None:  # type: ignore[no-any-unimported]
    num_passed = 0
    num_skipped = 0
    num_failed = 0
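These ignores exist because `junitparser` ships no type stubs: under strict mode, `--disallow-any-unimported` rejects any signature that mentions a class imported from an untyped package. A minimal sketch of the situation, assuming `junitparser` is installed without stubs:

```python
from junitparser import JUnitXml  # type: ignore[import]

# JUnitXml is Any to mypy, so mentioning it in a signature trips the
# no-any-unimported error code under strict settings; the per-line ignore
# silences only that code, keeping the rest of the checks active.
def load_report(path: str) -> JUnitXml:  # type: ignore[no-any-unimported]
    return JUnitXml.fromfile(path)
```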
@ -1,8 +1,9 @@
import os
import sys
from typing import Optional


def which(thefile):
def which(thefile: str) -> Optional[str]:
    path = os.environ.get("PATH", os.defpath).split(os.pathsep)
    for d in path:
        fname = os.path.join(d, thefile)
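`which` either finds an executable or falls through, so `Optional[str]` is the honest return type; strict mypy then forces every caller to narrow before use. A short sketch of both sides (the `ninja` lookup mirrors how `cmake.py` below uses it):

```python
import os
from typing import Optional

def which(thefile: str) -> Optional[str]:
    for d in os.environ.get("PATH", os.defpath).split(os.pathsep):
        fname = os.path.join(d, thefile)
        if os.path.isfile(fname) and os.access(fname, os.X_OK):
            return fname
    return None

ninja = which("ninja")
if ninja is not None:  # mypy narrows Optional[str] to str inside this branch
    print("found:", ninja)
```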
@ -8,14 +8,15 @@ import re
from subprocess import check_call, check_output, CalledProcessError
import sys
import sysconfig
from setuptools import distutils
from setuptools import distutils  # type: ignore[import]
from typing import IO, Any, Dict, List, Optional, Union

from . import which
from .env import (BUILD_DIR, IS_64BIT, IS_DARWIN, IS_WINDOWS, check_negative_env_flag)
from .numpy_ import USE_NUMPY, NUMPY_INCLUDE_DIR


def _mkdir_p(d):
def _mkdir_p(d: str) -> None:
    try:
        os.makedirs(d)
    except OSError:
@ -28,7 +29,11 @@ def _mkdir_p(d):
USE_NINJA = (not check_negative_env_flag('USE_NINJA') and
             which('ninja') is not None)

def convert_cmake_value_to_python_value(cmake_value, cmake_type):

CMakeValue = Optional[Union[bool, str]]


def convert_cmake_value_to_python_value(cmake_value: str, cmake_type: str) -> CMakeValue:
    r"""Convert a CMake value in a string form to a Python value.

    Args:
@ -52,7 +57,7 @@ def convert_cmake_value_to_python_value(cmake_value, cmake_type):
    else:  # Directly return the cmake_value.
        return cmake_value
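`CMakeValue` names the bool/str/None union once so the signatures below stay readable. A standalone sketch of how the alias composes (the cache contents here are made up for illustration):

```python
from typing import Dict, Optional, Union

CMakeValue = Optional[Union[bool, str]]

def convert(cmake_value: str, cmake_type: str) -> CMakeValue:
    # BOOL-typed entries become real Python booleans; everything else
    # stays a string, matching the conversion function above.
    if cmake_type.upper() == 'BOOL':
        return cmake_value.upper() in ('ON', '1', 'YES', 'TRUE', 'Y')
    return cmake_value

cache: Dict[str, CMakeValue] = {
    'BUILD_TEST': convert('ON', 'BOOL'),
    'CMAKE_BUILD_TYPE': convert('Release', 'STRING'),
}
```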
def get_cmake_cache_variables_from_file(cmake_cache_file):
def get_cmake_cache_variables_from_file(cmake_cache_file: IO[str]) -> Dict[str, CMakeValue]:
    r"""Gets values in CMakeCache.txt into a dictionary.

    Args:
@ -93,12 +98,12 @@ def get_cmake_cache_variables_from_file(cmake_cache_file):
class CMake:
    "Manages cmake."

    def __init__(self, build_dir=BUILD_DIR):
    def __init__(self, build_dir: str = BUILD_DIR) -> None:
        self._cmake_command = CMake._get_cmake_command()
        self.build_dir = build_dir

    @property
    def _cmake_cache_file(self):
    def _cmake_cache_file(self) -> str:
        r"""Returns the path to CMakeCache.txt.

        Returns:
@ -107,7 +112,7 @@ class CMake:
        return os.path.join(self.build_dir, 'CMakeCache.txt')

    @staticmethod
    def _get_cmake_command():
    def _get_cmake_command() -> str:
        "Returns cmake command."

        cmake_command = 'cmake'
@ -124,7 +129,7 @@ class CMake:
        raise RuntimeError('no cmake or cmake3 with version >= 3.5.0 found')

    @staticmethod
    def _get_version(cmd):
    def _get_version(cmd: str) -> Any:
        "Returns cmake version."

        for line in check_output([cmd, '--version']).decode('utf-8').split('\n'):
@ -132,7 +137,7 @@ class CMake:
            return distutils.version.LooseVersion(line.strip().split(' ')[2])
        raise RuntimeError('no version found')

    def run(self, args, env):
    def run(self, args: List[str], env: Dict[str, str]) -> None:
        "Executes cmake with arguments and an environment."

        command = [self._cmake_command] + args
@ -146,13 +151,13 @@ class CMake:
        sys.exit(1)

    @staticmethod
    def defines(args, **kwargs):
    def defines(args: List[str], **kwargs: CMakeValue) -> None:
        "Adds definitions to a cmake argument list."
        for key, value in sorted(kwargs.items()):
            if value is not None:
                args.append('-D{}={}'.format(key, value))
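The `**kwargs: CMakeValue` annotation is easy to misread: it types each keyword *value*, not the dict itself, so any number of `NAME=value` definitions can be passed while mypy checks every value against the alias. A self-contained sketch:

```python
from typing import List, Optional, Union

CMakeValue = Optional[Union[bool, str]]

def defines(args: List[str], **kwargs: CMakeValue) -> None:
    # inside the body, kwargs is Dict[str, CMakeValue]
    for key, value in sorted(kwargs.items()):
        if value is not None:
            args.append('-D{}={}'.format(key, value))

flags: List[str] = []
defines(flags, BUILD_PYTHON=True, CMAKE_BUILD_TYPE='Release', USE_FOO=None)
# flags == ['-DBUILD_PYTHON=True', '-DCMAKE_BUILD_TYPE=Release']
```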
    def get_cmake_cache_variables(self):
    def get_cmake_cache_variables(self) -> Dict[str, CMakeValue]:
        r"""Gets values in CMakeCache.txt into a dictionary.
        Returns:
          dict: A ``dict`` containing the value of cached CMake variables.
@ -160,7 +165,15 @@ class CMake:
        with open(self._cmake_cache_file) as f:
            return get_cmake_cache_variables_from_file(f)

    def generate(self, version, cmake_python_library, build_python, build_test, my_env, rerun):
    def generate(
        self,
        version: Optional[str],
        cmake_python_library: Optional[str],
        build_python: bool,
        build_test: bool,
        my_env: Dict[str, str],
        rerun: bool,
    ) -> None:
        "Runs cmake to generate native build files."

        if rerun and os.path.isfile(self._cmake_cache_file):
@ -215,7 +228,7 @@ class CMake:
        _mkdir_p(self.build_dir)

        # Store build options that are directly stored in environment variables
        build_options = {
        build_options: Dict[str, CMakeValue] = {
            # The default value cannot be easily obtained in CMakeLists.txt. We set it here.
            'CMAKE_PREFIX_PATH': sysconfig.get_path('purelib')
        }
@ -335,7 +348,7 @@ class CMake:
        args.append(base_dir)
        self.run(args, env=my_env)

    def build(self, my_env):
    def build(self, my_env: Dict[str, str]) -> None:
        "Runs cmake to build binaries."

        from .env import build_type
@ -3,6 +3,7 @@ import platform
import struct
import sys
from itertools import chain
from typing import Iterable, List, Optional, cast


IS_WINDOWS = (platform.system() == 'Windows')
@ -17,19 +18,19 @@ IS_64BIT = (struct.calcsize("P") == 8)
BUILD_DIR = 'build'


def check_env_flag(name, default=''):
def check_env_flag(name: str, default: str = '') -> bool:
    return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']


def check_negative_env_flag(name, default=''):
def check_negative_env_flag(name: str, default: str = '') -> bool:
    return os.getenv(name, default).upper() in ['OFF', '0', 'NO', 'FALSE', 'N']


def gather_paths(env_vars):
def gather_paths(env_vars: Iterable[str]) -> List[str]:
    return list(chain(*(os.getenv(v, '').split(os.pathsep) for v in env_vars)))


def lib_paths_from_base(base_path):
def lib_paths_from_base(base_path: str) -> List[str]:
    return [os.path.join(base_path, s) for s in ['lib/x64', 'lib', 'lib64']]


@ -49,7 +50,7 @@ class BuildType(object):

    """

    def __init__(self, cmake_build_type_env=None):
    def __init__(self, cmake_build_type_env: Optional[str] = None) -> None:
        if cmake_build_type_env is not None:
            self.build_type_string = cmake_build_type_env
            return
@ -63,19 +64,19 @@ class BuildType(object):
            # Normally it is anti-pattern to determine build type from CMAKE_BUILD_TYPE because it is not used for
            # multi-configuration build tools, such as Visual Studio and XCode. But since we always communicate with
            # CMake using CMAKE_BUILD_TYPE from our Python scripts, this is OK here.
            self.build_type_string = cmake_cache_vars['CMAKE_BUILD_TYPE']
            self.build_type_string = cast(str, cmake_cache_vars['CMAKE_BUILD_TYPE'])
        else:
            self.build_type_string = os.environ.get('CMAKE_BUILD_TYPE', 'Release')

    def is_debug(self):
    def is_debug(self) -> bool:
        "Checks Debug build."
        return self.build_type_string == 'Debug'

    def is_rel_with_deb_info(self):
    def is_rel_with_deb_info(self) -> bool:
        "Checks RelWithDebInfo build."
        return self.build_type_string == 'RelWithDebInfo'

    def is_release(self):
    def is_release(self) -> bool:
        "Checks Release build."
        return self.build_type_string == 'Release'
@ -4,9 +4,12 @@

import argparse
import os
from typing import Dict, Tuple, cast

Version = Tuple[int, int, int]


def parse_version(version: str) -> (int, int, int):
def parse_version(version: str) -> Version:
    """
    Parses a version string into (major, minor, patch) version numbers.

@ -24,10 +27,10 @@ def parse_version(version: str) -> (int, int, int):
            version_number_str = version[:i]
            break

    return tuple([int(n) for n in version_number_str.split(".")])
    return cast(Version, tuple([int(n) for n in version_number_str.split(".")]))
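This hunk fixes two distinct problems. The old `-> (int, int, int)` was never valid typing syntax (a literal tuple is not a type), and even with the `Version` alias in place, mypy infers `tuple(...)` over a comprehension as `Tuple[int, ...]` of unknown length, so a `cast` is needed to claim the fixed three-element shape. A minimal sketch:

```python
from typing import Tuple, cast

Version = Tuple[int, int, int]

def parse_version(version: str) -> Version:
    # mypy sees Tuple[int, ...] here; cast() asserts the three-element
    # shape that the input format is assumed to guarantee.
    return cast(Version, tuple(int(n) for n in version.split(".")[:3]))

assert parse_version("1.9.0") == (1, 9, 0)
```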
def apply_replacements(replacements, text):
def apply_replacements(replacements: Dict[str, str], text: str) -> str:
    """
    Applies the given replacements within the text.

@ -43,7 +46,7 @@ def apply_replacements(replacements, text):
    return text


def main(args):
def main(args: argparse.Namespace) -> None:
    with open(args.version_path) as f:
        version = f.read().strip()
    (major, minor, patch) = parse_version(version)
@ -2,12 +2,13 @@ import argparse
import os
import sys
import yaml
from typing import Any, List, Optional, cast

try:
    # use faster C loader if available
    from yaml import CSafeLoader as YamlLoader
except ImportError:
    from yaml import SafeLoader as YamlLoader
    from yaml import SafeLoader as YamlLoader  # type: ignore[misc]
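The fallback import needs an ignore because mypy treats the second `from yaml import ... as YamlLoader` as a redefinition of the name with an incompatible type (the two loader classes are distinct); `[misc]` is the code mypy assigned to that complaint at the time. The pattern in isolation, assuming PyYAML is installed:

```python
try:
    # fast C implementation, present only when PyYAML was built with libyaml
    from yaml import CSafeLoader as YamlLoader
except ImportError:
    # rebinding the name to a different class is a redefinition to mypy,
    # hence the narrowly scoped ignore
    from yaml import SafeLoader as YamlLoader  # type: ignore[misc]
```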
source_files = {'.py', '.cpp', '.h'}

@ -16,7 +17,7 @@ NATIVE_FUNCTIONS_PATH = 'aten/src/ATen/native/native_functions.yaml'

# TODO: This is a little inaccurate, because it will also pick
# up setup_helper scripts which don't affect code generation
def all_generator_source():
def all_generator_source() -> List[str]:
    r = []
    for directory, _, filenames in os.walk('tools'):
        for f in filenames:
@ -26,15 +27,15 @@ def all_generator_source():
    return sorted(r)


def generate_code(ninja_global=None,
                  declarations_path=None,
                  nn_path=None,
                  native_functions_path=None,
                  install_dir=None,
                  subset=None,
                  disable_autograd=False,
                  force_schema_registration=False,
                  operator_selector=None):
def generate_code(ninja_global: Optional[str] = None,
                  declarations_path: Optional[str] = None,
                  nn_path: Optional[str] = None,
                  native_functions_path: Optional[str] = None,
                  install_dir: Optional[str] = None,
                  subset: Optional[str] = None,
                  disable_autograd: bool = False,
                  force_schema_registration: bool = False,
                  operator_selector: Any = None) -> None:
    from tools.autograd.gen_autograd import gen_autograd, gen_autograd_python
    from tools.autograd.gen_annotated_fn_args import gen_annotated
    from tools.codegen.selective_build.selector import SelectiveBuilder
@ -86,7 +87,7 @@ def generate_code(ninja_global=None,

def get_selector_from_legacy_operator_selection_list(
        selected_op_list_path: str,
):
) -> Any:
    with open(selected_op_list_path, 'r') as f:
        # strip out the overload part
        # It's only for legacy config - do NOT copy this code!
@ -113,7 +114,10 @@ def get_selector_from_legacy_operator_selection_list(
    return selector


def get_selector(selected_op_list_path, operators_yaml_path):
def get_selector(
        selected_op_list_path: Optional[str],
        operators_yaml_path: Optional[str],
) -> Any:
    # cwrap depends on pyyaml, so we can't import it earlier
    root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    sys.path.insert(0, root)
@ -129,10 +133,10 @@ def get_selector(selected_op_list_path, operators_yaml_path):
    elif selected_op_list_path is not None:
        return get_selector_from_legacy_operator_selection_list(selected_op_list_path)
    else:
        return SelectiveBuilder.from_yaml_path(operators_yaml_path)
        return SelectiveBuilder.from_yaml_path(cast(str, operators_yaml_path))


def main():
def main() -> None:
    parser = argparse.ArgumentParser(description='Autogenerate code')
    parser.add_argument('--declarations-path')
    parser.add_argument('--native-functions-path')
@ -2,8 +2,11 @@
# for now, I have put it in one place but right now is copied out of cwrap

import copy
from typing import Any, Dict, Iterable, List, Union

def parse_arguments(args):
Arg = Dict[str, Any]

def parse_arguments(args: List[Union[str, Arg]]) -> List[Arg]:
    new_args = []
    for arg in args:
        # Simple arg declaration of form "<type> <name>"
@ -20,7 +23,10 @@ def parse_arguments(args):
    return new_args


def set_declaration_defaults(declaration):
Declaration = Dict[str, Any]


def set_declaration_defaults(declaration: Declaration) -> None:
    if 'schema_string' not in declaration:
        # This happens for legacy TH bindings like
        # _thnn_conv_depthwise2d_backward
@ -70,19 +76,26 @@ def set_declaration_defaults(declaration):
# TODO(zach): added option to remove keyword handling for C++ which cannot
# support it.

Option = Dict[str, Any]

def filter_unique_options(options, allow_kwarg, type_to_signature, remove_self):
    def exclude_arg(arg):
        return arg['type'] == 'CONSTANT'

    def exclude_arg_with_self_check(arg):
def filter_unique_options(
    options: Iterable[Option],
    allow_kwarg: bool,
    type_to_signature: Dict[str, str],
    remove_self: bool,
) -> List[Option]:
    def exclude_arg(arg: Arg) -> bool:
        return arg['type'] == 'CONSTANT'  # type: ignore[no-any-return]

    def exclude_arg_with_self_check(arg: Arg) -> bool:
        return exclude_arg(arg) or (remove_self and arg['name'] == 'self')

    def signature(option, kwarg_only_count):
        if kwarg_only_count == 0:
    def signature(option: Option, num_kwarg_only: int) -> str:
        if num_kwarg_only == 0:
            kwarg_only_count = None
        else:
            kwarg_only_count = -kwarg_only_count
            kwarg_only_count = -num_kwarg_only
        arg_signature = '#'.join(
            type_to_signature.get(arg['type'], arg['type'])
            for arg in option['arguments'][:kwarg_only_count]
@ -111,40 +124,40 @@ def filter_unique_options(options, allow_kwarg, type_to_signature, remove_self):
    return unique


def sort_by_number_of_args(declaration, reverse=True):
    def num_args(option):
def sort_by_number_of_args(declaration: Declaration, reverse: bool = True) -> None:
    def num_args(option: Option) -> int:
        return len(option['arguments'])
    declaration['options'].sort(key=num_args, reverse=reverse)


class Function(object):

    def __init__(self, name):
    def __init__(self, name: str) -> None:
        self.name = name
        self.arguments = []
        self.arguments: List['Argument'] = []

    def add_argument(self, arg):
    def add_argument(self, arg: 'Argument') -> None:
        assert isinstance(arg, Argument)
        self.arguments.append(arg)

    def __repr__(self):
    def __repr__(self) -> str:
        return self.name + '(' + ', '.join(a.__repr__() for a in self.arguments) + ')'


class Argument(object):

    def __init__(self, _type, name, is_optional):
    def __init__(self, _type: str, name: str, is_optional: bool):
        self.type = _type
        self.name = name
        self.is_optional = is_optional

    def __repr__(self):
    def __repr__(self) -> str:
        return self.type + ' ' + self.name


def parse_header(path):
def parse_header(path: str) -> List[Function]:
    with open(path, 'r') as f:
        lines = f.read().split('\n')
        lines: Iterable[Any] = f.read().split('\n')

    # Remove empty lines and preprocessor directives
    lines = filter(lambda l: l and not l.startswith('#'), lines)
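The `lines: Iterable[Any]` annotation in `parse_header` handles a rebinding problem: the name first holds a `List[str]` and is then reassigned to a `filter` iterator, so it is declared up front with a type wide enough for both bindings. The pattern in isolation:

```python
from typing import Any, Iterable, List

def significant_lines(text: str) -> List[str]:
    # declared as Iterable[Any] so the later rebinding to a filter object
    # does not change the variable's declared type
    lines: Iterable[Any] = text.split('\n')
    lines = filter(lambda l: l and not l.startswith('#'), lines)
    return list(lines)
```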
@ -1,6 +1,11 @@
def import_module(name, path):
from importlib.abc import Loader
from types import ModuleType
from typing import cast


def import_module(name: str, path: str) -> ModuleType:
    import importlib.util
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    cast(Loader, spec.loader).exec_module(module)
    return module
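The `cast(Loader, ...)` is needed because typeshed types `ModuleSpec.loader` as optional, while `spec_from_file_location` on a real file always produces a loader; the cast records that assumption for the checker without a runtime check. The function as typed, runnable on its own:

```python
import importlib.util
from importlib.abc import Loader
from types import ModuleType
from typing import cast

def import_module(name: str, path: str) -> ModuleType:
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    # spec.loader is Optional in the stubs; cast() asserts it is present,
    # which holds when path points at an actual Python file
    cast(Loader, spec.loader).exec_module(module)
    return module

# mod = import_module("my_mod", "/tmp/my_mod.py")  # hypothetical path
```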
@ -1,13 +1,19 @@
# -*- coding: utf-8 -*-
import unittest
from typing import Dict, List

from tools import print_test_stats
from tools.stats_utils.s3_stat_parser import (Commit, Report, ReportMetaMeta,
                                              Status, Version1Case,
                                              Version1Report, Version2Case,
                                              Version2Report)


def fakehash(char):
def fakehash(char: str) -> str:
    return char * 40


def dummy_meta_meta() -> print_test_stats.ReportMetaMeta:
def dummy_meta_meta() -> ReportMetaMeta:
    return {
        'build_pr': '',
        'build_tag': '',
@ -18,7 +24,14 @@ def dummy_meta_meta() -> print_test_stats.ReportMetaMeta:
    }


def makecase(name, seconds, *, errored=False, failed=False, skipped=False):
def makecase(
    name: str,
    seconds: float,
    *,
    errored: bool = False,
    failed: bool = False,
    skipped: bool = False,
) -> Version1Case:
    return {
        'name': name,
        'seconds': seconds,
@ -28,7 +41,7 @@ def makecase(name, seconds, *, errored=False, failed=False, skipped=False):
    }


def make_report_v1(tests) -> print_test_stats.Version1Report:
def make_report_v1(tests: Dict[str, List[Version1Case]]) -> Version1Report:
    suites = {
        suite_name: {
            'total_seconds': sum(case['seconds'] for case in cases),
@ -37,20 +50,20 @@ def make_report_v1(tests) -> print_test_stats.Version1Report:
        for suite_name, cases in tests.items()
    }
    return {
        **dummy_meta_meta(),
        **dummy_meta_meta(),  # type: ignore[misc]
        'total_seconds': sum(s['total_seconds'] for s in suites.values()),
        'suites': suites,
    }
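`Version1Report` and friends are TypedDicts, and at the time mypy could not verify `**`-unpacking one TypedDict into a literal of a larger one, so the spread line carries a targeted `[misc]` ignore rather than loosening the return type. A minimal sketch of the limitation, assuming Python 3.8+ for `typing.TypedDict`:

```python
from typing import TypedDict

class Meta(TypedDict):
    build_pr: str

class Report(Meta):
    total_seconds: float

def make_report(meta: Meta) -> Report:
    return {
        **meta,  # type: ignore[misc]  # TypedDict ** unpacking unsupported
        'total_seconds': 0.0,
    }
```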
def make_case_v2(seconds, status=None) -> print_test_stats.Version2Case:
def make_case_v2(seconds: float, status: Status = None) -> Version2Case:
    return {
        'seconds': seconds,
        'status': status,
    }


def make_report_v2(tests) -> print_test_stats.Version2Report:
def make_report_v2(tests: Dict[str, Dict[str, Dict[str, Version2Case]]]) -> Version2Report:
    files = {}
    for file_name, file_suites in tests.items():
        suites = {
@ -65,7 +78,7 @@ def make_report_v2(tests) -> print_test_stats.Version2Report:
        'total_seconds': sum(suite['total_seconds'] for suite in suites.values()),
    }
    return {
        **dummy_meta_meta(),
        **dummy_meta_meta(),  # type: ignore[misc]
        'format_version': 2,
        'total_seconds': sum(s['total_seconds'] for s in files.values()),
        'files': files,
@ -73,7 +86,7 @@ def make_report_v2(tests) -> print_test_stats.Version2Report:
maxDiff = None

class TestPrintTestStats(unittest.TestCase):
    version1_report: print_test_stats.Version1Report = make_report_v1({
    version1_report: Version1Report = make_report_v1({
        # input ordering of the suites is ignored
        'Grault': [
            # not printed: status same and time similar
@ -112,7 +125,7 @@ class TestPrintTestStats(unittest.TestCase):
        ],
    })

    version2_report: print_test_stats.Version2Report = make_report_v2(
    version2_report: Version2Report = make_report_v2(
        {
            'test_a': {
                'Grault': {
@ -149,7 +162,7 @@ class TestPrintTestStats(unittest.TestCase):
            }
        })

    def test_simplify(self):
    def test_simplify(self) -> None:
        self.assertEqual(
            {
                '': {
@ -222,10 +235,10 @@ class TestPrintTestStats(unittest.TestCase):
            print_test_stats.simplify(self.version2_report),
        )

    def test_analysis(self):
    def test_analysis(self) -> None:
        head_report = self.version1_report

        base_reports = {
        base_reports: Dict[Commit, List[Report]] = {
            # bbbb has no reports, so base is cccc instead
            fakehash('b'): [],
            fakehash('c'): [
@ -391,7 +404,7 @@ class TestPrintTestStats(unittest.TestCase):
            print_test_stats.anomalies(analysis),
        )

    def test_graph(self):
    def test_graph(self) -> None:
        # HEAD is on master
        self.assertEqual(
            '''\
@ -534,7 +547,7 @@ Commit graph (base is most recent master ancestor with at least one S3 report):
            )
        )

    def test_regression_info(self):
    def test_regression_info(self) -> None:
        self.assertEqual(
            '''\
----- Historic stats comparison result ------
@ -588,7 +601,7 @@ Added (across 1 suite) 1 test, totaling + 3.00s
            )
        )

    def test_regression_info_new_job(self):
    def test_regression_info_new_job(self) -> None:
        self.assertEqual(
            '''\
----- Historic stats comparison result ------
@ -3,6 +3,7 @@ import os
import inspect
import sys
import tempfile
from typing import Any, List, Optional, Tuple

# this arbitrary-looking assortment of functionality is provided here
# to have a central place for overrideable behavior. The motivating
@ -20,30 +21,33 @@ else:
else:
    torch_parent = os.path.dirname(os.path.dirname(__file__))

def get_file_path(*path_components):
def get_file_path(*path_components: str) -> str:
    return os.path.join(torch_parent, *path_components)


def get_file_path_2(*path_components):
def get_file_path_2(*path_components: str) -> str:
    return os.path.join(*path_components)


def get_writable_path(path):
def get_writable_path(path: str) -> str:
    if os.access(path, os.W_OK):
        return path
    return tempfile.mkdtemp(suffix=os.path.basename(path))


def prepare_multiprocessing_environment(path):
def prepare_multiprocessing_environment(path: str) -> None:
    pass


def resolve_library_path(path):
def resolve_library_path(path: str) -> str:
    return os.path.realpath(path)


def get_source_lines_and_file(obj, error_msg=None):
def get_source_lines_and_file(
    obj: Any,
    error_msg: Optional[str] = None,
) -> Tuple[List[str], int, Optional[str]]:
    """
    Wrapper around inspect.getsourcelines and inspect.getsourcefile.