Strictly type everything in .github and tools (#59117)

Summary:
This PR greatly simplifies `mypy-strict.ini` by strictly typing everything in `.github` and `tools`, rather than picking and choosing only specific files in those two dirs. It also removes `warn_unused_ignores` from `mypy-strict.ini`, for reasons described in https://github.com/pytorch/pytorch/pull/56402#issuecomment-822743795: basically, that setting makes life more difficult depending on what libraries you have installed locally vs in CI (e.g. `ruamel`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59117

Test Plan:
```
flake8
mypy --config mypy-strict.ini
```

Reviewed By: malfet

Differential Revision: D28765386

Pulled By: samestep

fbshipit-source-id: 3e744e301c7a464f8a2a2428fcdbad534e231f2e
This commit is contained in:
Sam Estep
2021-06-07 14:48:29 -07:00
committed by Facebook GitHub Bot
parent 6ff001c125
commit 737d920b21
43 changed files with 463 additions and 312 deletions

View File

@ -11,12 +11,12 @@ REPO_ROOT = Path(__file__).resolve().parent.parent.parent
WORKFLOWS = REPO_ROOT / ".github" / "workflows"
def concurrency_key(filename):
def concurrency_key(filename: Path) -> str:
workflow_name = filename.with_suffix("").name.replace("_", "-")
return f"{workflow_name}-${{{{ github.event.pull_request.number || github.sha }}}}"
def should_check(filename):
def should_check(filename: Path) -> bool:
with open(filename, "r") as f:
content = f.read()
@ -31,7 +31,7 @@ if __name__ == "__main__":
)
args = parser.parse_args()
files = WORKFLOWS.glob("*.yml")
files = list(WORKFLOWS.glob("*.yml"))
errors_found = False
files = [f for f in files if should_check(f)]

View File

@ -16,12 +16,12 @@ LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile("a0$")
class NoGitTagException(Exception):
pass
def get_pytorch_root():
def get_pytorch_root() -> Path:
return Path(subprocess.check_output(
['git', 'rev-parse', '--show-toplevel']
).decode('ascii').strip())
def get_tag():
def get_tag() -> str:
root = get_pytorch_root()
# We're on a tag
am_on_tag = (
@ -46,7 +46,7 @@ def get_tag():
tag = re.sub(TRAILING_RC_PATTERN, "", tag)
return tag
def get_base_version():
def get_base_version() -> str:
root = get_pytorch_root()
dirty_version = open(root / 'version.txt', 'r').read().strip()
# Strips trailing a0 from version.txt, not too sure why it's there in the
@ -54,29 +54,34 @@ def get_base_version():
return re.sub(LEGACY_BASE_VERSION_SUFFIX_PATTERN, "", dirty_version)
class PytorchVersion:
def __init__(self, gpu_arch_type, gpu_arch_version, no_build_suffix):
def __init__(
self,
gpu_arch_type: str,
gpu_arch_version: str,
no_build_suffix: bool,
) -> None:
self.gpu_arch_type = gpu_arch_type
self.gpu_arch_version = gpu_arch_version
self.no_build_suffix = no_build_suffix
def get_post_build_suffix(self):
def get_post_build_suffix(self) -> str:
if self.gpu_arch_type == "cuda":
return f"+cu{self.gpu_arch_version.replace('.', '')}"
return f"+{self.gpu_arch_type}{self.gpu_arch_version}"
def get_release_version(self):
def get_release_version(self) -> str:
if not get_tag():
raise NoGitTagException(
"Not on a git tag, are you sure you want a release version?"
)
return f"{get_tag()}{self.get_post_build_suffix()}"
def get_nightly_version(self):
def get_nightly_version(self) -> str:
date_str = datetime.today().strftime('%Y%m%d')
build_suffix = self.get_post_build_suffix()
return f"{get_base_version()}.dev{date_str}{build_suffix}"
def main():
def main() -> None:
parser = argparse.ArgumentParser(
description="Generate pytorch version for binary builds"
)

View File

@ -14,19 +14,19 @@ is simply to make sure that there is *some* configuration of ruamel that can rou
the YAML, not to be prescriptive about it.
'''
import ruamel.yaml
import ruamel.yaml # type: ignore[import]
import difflib
import sys
from pathlib import Path
from io import StringIO
def fn(base):
def fn(base: str) -> str:
return str(base / Path("aten/src/ATen/native/native_functions.yaml"))
with open(Path(__file__).parent.parent.parent / fn('.'), "r") as f:
contents = f.read()
yaml = ruamel.yaml.YAML()
yaml = ruamel.yaml.YAML() # type: ignore[attr-defined]
yaml.preserve_quotes = True
yaml.width = 1000
yaml.boolean_representation = ['False', 'True']

View File

@ -31,7 +31,7 @@ direction: decrease
timeout: 720
tests:"""
def gen_abtest_config(control: str, treatment: str, models: List[str]):
def gen_abtest_config(control: str, treatment: str, models: List[str]) -> str:
d = {}
d["control"] = control
d["treatment"] = treatment
@ -43,7 +43,7 @@ def gen_abtest_config(control: str, treatment: str, models: List[str]):
config = config + "\n"
return config
def deploy_torchbench_config(output_dir: str, config: str):
def deploy_torchbench_config(output_dir: str, config: str) -> None:
# Create test dir if needed
pathlib.Path(output_dir).mkdir(exist_ok=True)
# TorchBench config file name
@ -71,7 +71,7 @@ def extract_models_from_pr(torchbench_path: str, prbody_file: str) -> List[str]:
return []
return model_list
def run_torchbench(pytorch_path: str, torchbench_path: str, output_dir: str):
def run_torchbench(pytorch_path: str, torchbench_path: str, output_dir: str) -> None:
# Copy system environment so that we will not override
env = dict(os.environ)
command = ["python", "bisection.py", "--work-dir", output_dir,

View File

@ -492,7 +492,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
"${TOOLS_PATH}/autograd/templates/python_linalg_functions.cpp"
"${TOOLS_PATH}/autograd/templates/python_special_functions.cpp"
"${TOOLS_PATH}/autograd/templates/variable_factories.h"
"${TOOLS_PATH}/autograd/templates/annotated_fn_args.py"
"${TOOLS_PATH}/autograd/templates/annotated_fn_args.py.in"
"${TOOLS_PATH}/autograd/deprecated.yaml"
"${TOOLS_PATH}/autograd/derivatives.yaml"
"${TOOLS_PATH}/autograd/gen_autograd_functions.py"

View File

@ -5,8 +5,6 @@
# this config file to be used to ENFORCE that people are using mypy on codegen
# files.
# For now, only code_template.py and benchmark utils Timer are covered this way
[mypy]
python_version = 3.6
plugins = mypy_plugins/check_mypy_version.py
@ -30,36 +28,14 @@ check_untyped_defs = True
disallow_untyped_decorators = True
no_implicit_optional = True
warn_redundant_casts = True
warn_unused_ignores = True
warn_return_any = True
implicit_reexport = False
strict_equality = True
files =
.github/scripts/generate_binary_build_matrix.py,
.github/scripts/generate_ci_workflows.py,
.github/scripts/parse_ref.py,
.github,
benchmarks/instruction_counts,
tools/actions_local_runner.py,
tools/autograd/*.py,
tools/clang_tidy.py,
tools/codegen,
tools/explicit_ci_jobs.py,
tools/extract_scripts.py,
tools/mypy_wrapper.py,
tools/print_test_stats.py,
tools/pyi,
tools/stats_utils,
tools/test_history.py,
tools/test/test_actions_local_runner.py,
tools/test/test_extract_scripts.py,
tools/test/test_mypy_wrapper.py,
tools/test/test_test_history.py,
tools/test/test_trailing_newlines.py,
tools/test/test_translate_annotations.py,
tools/trailing_newlines.py,
tools/translate_annotations.py,
tools/vscode_settings.py,
tools,
torch/testing/_internal/framework_utils.py,
torch/utils/_pytree.py,
torch/utils/benchmark/utils/common.py,

View File

@ -12,7 +12,7 @@ sys.path.append(os.path.realpath(os.path.join(
'torch',
'utils')))
from hipify import hipify_python
from hipify import hipify_python # type: ignore[import]
parser = argparse.ArgumentParser(description='Top-level script for HIPifying, filling in most common parameters')
parser.add_argument(
@ -115,7 +115,7 @@ ignores = [
]
# Check if the compiler is hip-clang.
def is_hip_clang():
def is_hip_clang() -> bool:
try:
hip_path = os.getenv('HIP_PATH', '/opt/rocm/hip')
return 'HIP_COMPILER=clang' in open(hip_path + '/lib/.hipInfo').read()

View File

@ -48,7 +48,7 @@ def gen_annotated(native_yaml_path: str, out: str, autograd_dir: str) -> None:
template_path = os.path.join(autograd_dir, 'templates')
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
fm.write_with_template('annotated_fn_args.py', 'annotated_fn_args.py', lambda: {
fm.write_with_template('annotated_fn_args.py', 'annotated_fn_args.py.in', lambda: {
'annotated_args': textwrap.indent('\n'.join(annotated_args), ' '),
})

View File

@ -1,15 +1,16 @@
import os
from glob import glob
import shutil
from typing import Dict, Optional
from .setup_helpers.env import IS_64BIT, IS_WINDOWS, check_negative_env_flag
from .setup_helpers.cmake import USE_NINJA
from .setup_helpers.cmake import USE_NINJA, CMake
from setuptools import distutils
from setuptools import distutils # type: ignore[import]
def _overlay_windows_vcvars(env):
def _overlay_windows_vcvars(env: Dict[str, str]) -> Dict[str, str]:
vc_arch = 'x64' if IS_64BIT else 'x86'
vc_env = distutils._msvccompiler._get_vc_env(vc_arch)
vc_env: Dict[str, str] = distutils._msvccompiler._get_vc_env(vc_arch)
# Keys in `_get_vc_env` are always lowercase.
# We turn them into uppercase before overlaying vcvars
# because OS environ keys are always uppercase on Windows.
@ -22,7 +23,7 @@ def _overlay_windows_vcvars(env):
return vc_env
def _create_build_env():
def _create_build_env() -> Dict[str, str]:
# XXX - our cmake file sometimes looks at the system environment
# and not cmake flags!
# you should NEVER add something to this list. It is bad practice to
@ -44,7 +45,14 @@ def _create_build_env():
return my_env
def build_caffe2(version, cmake_python_library, build_python, rerun_cmake, cmake_only, cmake):
def build_caffe2(
version: Optional[str],
cmake_python_library: Optional[str],
build_python: bool,
rerun_cmake: bool,
cmake_only: bool,
cmake: CMake,
) -> None:
my_env = _create_build_env()
build_test = not check_negative_env_flag('BUILD_TEST')
cmake.generate(version,

View File

@ -12,7 +12,9 @@ import asyncio
import re
import os
import sys
from clang_format_utils import get_and_check_clang_format, CLANG_FORMAT_PATH
from typing import List, Set
from .clang_format_utils import get_and_check_clang_format, CLANG_FORMAT_PATH
# Allowlist of directories to check. All files that in that directory
# (recursively) will be checked.
@ -28,7 +30,7 @@ CLANG_FORMAT_ALLOWLIST = [
CPP_FILE_REGEX = re.compile(".*\\.(h|cpp|cc|c|hpp)$")
def get_allowlisted_files():
def get_allowlisted_files() -> Set[str]:
"""
Parse CLANG_FORMAT_ALLOWLIST and resolve all directories.
Returns the set of allowlist cpp source files.
@ -42,7 +44,11 @@ def get_allowlisted_files():
return set(matches)
async def run_clang_format_on_file(filename, semaphore, verbose=False):
async def run_clang_format_on_file(
filename: str,
semaphore: asyncio.Semaphore,
verbose: bool = False,
) -> None:
"""
Run clang-format on the provided file.
"""
@ -55,7 +61,11 @@ async def run_clang_format_on_file(filename, semaphore, verbose=False):
print("Formatted {}".format(filename))
async def file_clang_formatted_correctly(filename, semaphore, verbose=False):
async def file_clang_formatted_correctly(
filename: str,
semaphore: asyncio.Semaphore,
verbose: bool = False,
) -> bool:
"""
Checks if a file is formatted correctly and returns True if so.
"""
@ -80,7 +90,11 @@ async def file_clang_formatted_correctly(filename, semaphore, verbose=False):
return ok
async def run_clang_format(max_processes, diff=False, verbose=False):
async def run_clang_format(
max_processes: int,
diff: bool = False,
verbose: bool = False,
) -> bool:
"""
Run clang-format to all files in CLANG_FORMAT_ALLOWLIST that match CPP_FILE_REGEX.
"""
@ -114,7 +128,7 @@ async def run_clang_format(max_processes, diff=False, verbose=False):
return ok
def parse_args(args):
def parse_args(args: List[str]) -> argparse.Namespace:
"""
Parse and return command-line arguments.
"""
@ -134,7 +148,7 @@ def parse_args(args):
return parser.parse_args(args)
def main(args):
def main(args: List[str]) -> bool:
# Parse arguments.
options = parse_args(args)
# Get clang-format and make sure it is the right binary and it is in the right place.

View File

@ -45,7 +45,7 @@ def compute_file_sha256(path: str) -> str:
return hash.hexdigest()
def report_download_progress(chunk_number, chunk_size, file_size):
def report_download_progress(chunk_number: int, chunk_size: int, file_size: int) -> None:
"""
Pretty printer for file download progress.
"""
@ -55,7 +55,7 @@ def report_download_progress(chunk_number, chunk_size, file_size):
sys.stdout.write("\r0% |{:<64}| {}%".format(bar, int(percent * 100)))
def download_clang_format(path):
def download_clang_format(path: str) -> bool:
"""
Downloads a clang-format binary appropriate for the host platform and stores it at the given location.
"""
@ -81,7 +81,7 @@ def download_clang_format(path):
return True
def get_and_check_clang_format(verbose=False):
def get_and_check_clang_format(verbose: bool = False) -> bool:
"""
Download a platform-appropriate clang-format binary if one doesn't already exist at the expected location and verify
that it is the right binary by checking its SHA256 hash against the expected hash.

View File

@ -12,14 +12,18 @@ import argparse
import yaml
from collections import defaultdict
from typing import Dict, List, Set
def canonical_name(opname):
def canonical_name(opname: str) -> str:
# Skip the overload name part as it's not supported by code analyzer yet.
return opname.split('.', 1)[0]
def load_op_dep_graph(fname):
DepGraph = Dict[str, Set[str]]
def load_op_dep_graph(fname: str) -> DepGraph:
with open(fname, 'r') as stream:
result = defaultdict(set)
for op in yaml.safe_load(stream):
@ -27,10 +31,10 @@ def load_op_dep_graph(fname):
for dep in op.get('depends', []):
dep_name = canonical_name(dep['name'])
result[op_name].add(dep_name)
return result
return dict(result)
def load_root_ops(fname):
def load_root_ops(fname: str) -> List[str]:
result = []
with open(fname, 'r') as stream:
for op in yaml.safe_load(stream):
@ -38,7 +42,11 @@ def load_root_ops(fname):
return result
def gen_transitive_closure(dep_graph, root_ops, train=False):
def gen_transitive_closure(
dep_graph: DepGraph,
root_ops: List[str],
train: bool = False,
) -> List[str]:
result = set(root_ops)
queue = root_ops[:]
@ -64,7 +72,7 @@ def gen_transitive_closure(dep_graph, root_ops, train=False):
return sorted(result)
def gen_transitive_closure_str(dep_graph, root_ops):
def gen_transitive_closure_str(dep_graph: DepGraph, root_ops: List[str]) -> str:
return ' '.join(gen_transitive_closure(dep_graph, root_ops))

View File

@ -11,6 +11,7 @@ python -m tools.code_analyzer.op_deps_processor \
import argparse
import yaml
from typing import Any, List
from tools.codegen.code_template import CodeTemplate
@ -46,12 +47,12 @@ DOT_OP_DEP = CodeTemplate("""\
""")
def load_op_deps(fname):
def load_op_deps(fname: str) -> Any:
with open(fname, 'r') as stream:
return yaml.safe_load(stream)
def process_base_ops(graph, base_ops):
def process_base_ops(graph: Any, base_ops: List[str]) -> None:
# remove base ops from all `depends` lists to compress the output graph
for op in graph:
op['depends'] = [
@ -64,7 +65,13 @@ def process_base_ops(graph, base_ops):
'depends': [{'name': name} for name in base_ops]})
def convert(fname, graph, output_template, op_template, op_dep_template):
def convert(
fname: str,
graph: Any,
output_template: CodeTemplate,
op_template: CodeTemplate,
op_dep_template: CodeTemplate,
) -> None:
ops = []
for op in graph:
op_name = op['name']

View File

@ -1,11 +1,11 @@
from ..tool import clang_coverage
from ..util.setting import CompilerType, Option, TestList, TestPlatform
from ..util.utils import check_compiler_type
from .init import detect_compiler_type
from .init import detect_compiler_type # type: ignore[attr-defined]
from .run import clang_run, gcc_run
def get_json_report(test_list: TestList, options: Option):
def get_json_report(test_list: TestList, options: Option) -> None:
cov_type = detect_compiler_type()
check_compiler_type(cov_type)
if cov_type == CompilerType.CLANG:

View File

@ -1,6 +1,6 @@
import argparse
import os
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, cast
from ..util.setting import (
JSON_FOLDER_BASE_DIR,
@ -129,7 +129,7 @@ def empty_list_if_none(arg_interested_folder: Optional[List[str]]) -> List[str]:
return arg_interested_folder
def gcc_export_init():
def gcc_export_init() -> None:
remove_folder(JSON_FOLDER_BASE_DIR)
create_folder(JSON_FOLDER_BASE_DIR)
@ -161,7 +161,7 @@ def print_init_info() -> None:
print_log("pytorch folder: ", get_pytorch_folder())
print_log("cpp test binaries folder: ", get_oss_binary_folder(TestType.CPP))
print_log("python test scripts folder: ", get_oss_binary_folder(TestType.PY))
print_log("compiler type: ", detect_compiler_type().value)
print_log("compiler type: ", cast(CompilerType, detect_compiler_type()).value)
print_log(
"llvm tool folder (only for clang, if you are using gcov please ignore it): ",
get_llvm_tool_path(),

View File

@ -82,8 +82,7 @@ def get_gcda_files() -> List[str]:
# TODO use glob
# output = glob.glob(f"{folder_has_gcda}/**/*.gcda")
output = subprocess.check_output(["find", folder_has_gcda, "-iname", "*.gcda"])
output = output.decode("utf-8").split("\n")
return output
return output.decode("utf-8").split("\n")
else:
return []

View File

@ -148,7 +148,7 @@ def export(test_list: TestList, platform_type: TestPlatform) -> None:
binary_file = ""
shared_library_list = []
if platform_type == TestPlatform.FBCODE:
from caffe2.fb.code_coverage.tool.package.fbcode.utils import (
from caffe2.fb.code_coverage.tool.package.fbcode.utils import ( # type: ignore[import]
get_fbcode_binary_folder,
)

View File

@ -10,11 +10,11 @@ class LlvmCoverageSegment(NamedTuple):
is_gap_entry: Optional[int]
@property
def has_coverage(self):
def has_coverage(self) -> bool:
return self.segment_count > 0
@property
def is_executable(self):
def is_executable(self) -> bool:
return self.has_count > 0
def get_coverage(

View File

@ -1,20 +1,22 @@
import os
import subprocess
from typing import IO, Dict, List, Set
from typing import IO, Dict, List, Set, Tuple
from ..oss.utils import get_pytorch_folder
from ..util.setting import SUMMARY_FOLDER_DIR, TestList, TestStatusType
CoverageItem = Tuple[str, float, int, int]
def key_by_percentage(x):
def key_by_percentage(x: CoverageItem) -> float:
return x[1]
def key_by_name(x):
def key_by_name(x: CoverageItem) -> str:
return x[0]
def is_intrested_file(file_path: str, interested_folders: List[str]):
def is_intrested_file(file_path: str, interested_folders: List[str]) -> bool:
if "cuda" in file_path:
return False
if "aten/gen_aten" in file_path or "aten/aten_" in file_path:
@ -34,7 +36,7 @@ def is_this_type_of_tests(target_name: str, test_set_by_type: Set[str]) -> bool:
def print_test_by_type(
tests: TestList, test_set_by_type: Set[str], type_name: str, summary_file: IO
tests: TestList, test_set_by_type: Set[str], type_name: str, summary_file: IO[str]
) -> None:
print("Tests " + type_name + " to collect coverage:", file=summary_file)
@ -49,7 +51,7 @@ def print_test_condition(
tests_type: TestStatusType,
interested_folders: List[str],
coverage_only: List[str],
summary_file: IO,
summary_file: IO[str],
summary_type: str,
) -> None:
print_test_by_type(tests, tests_type["success"], "fully success", summary_file)
@ -91,14 +93,8 @@ def line_oriented_report(
"LINE SUMMARY",
)
for file_name in covered_lines:
if len(covered_lines[file_name]) == 0:
covered = {}
else:
covered = covered_lines[file_name]
if len(uncovered_lines[file_name]) == 0:
uncovered = {}
else:
uncovered = uncovered_lines[file_name]
covered = covered_lines[file_name]
uncovered = uncovered_lines[file_name]
print(
f"{file_name}\n covered lines: {sorted(covered)}\n unconvered lines:{sorted(uncovered)}",
file=report_file,
@ -106,7 +102,7 @@ def line_oriented_report(
def print_file_summary(
covered_summary: int, total_summary: int, summary_file: IO
covered_summary: int, total_summary: int, summary_file: IO[str]
) -> float:
# print summary first
try:
@ -124,10 +120,10 @@ def print_file_summary(
def print_file_oriented_report(
tests_type: TestStatusType,
coverage,
coverage: List[CoverageItem],
covered_summary: int,
total_summary: int,
summary_file: IO,
summary_file: IO[str],
tests: TestList,
interested_folders: List[str],
coverage_only: List[str],
@ -178,7 +174,7 @@ def file_oriented_report(
except ZeroDivisionError:
percentage = 0
# store information in a list to be sorted
coverage.append([file_name, percentage, covered_count, total_count])
coverage.append((file_name, percentage, covered_count, total_count))
# update summary
covered_summary = covered_summary + covered_count
total_summary = total_summary + total_count
@ -202,7 +198,7 @@ def get_html_ignored_pattern() -> List[str]:
return ["/usr/*", "*anaconda3/*", "*third_party/*"]
def html_oriented_report():
def html_oriented_report() -> None:
# use lcov to generate the coverage report
build_folder = os.path.join(get_pytorch_folder(), "build")
coverage_info_file = os.path.join(SUMMARY_FOLDER_DIR, "coverage.info")

View File

@ -55,7 +55,7 @@ def transform_file_name(
def is_intrested_file(
file_path: str, interested_folders: List[str], platform: TestPlatform
):
) -> bool:
ignored_patterns = ["cuda", "aten/gen_aten", "aten/aten_", "build/"]
if any([pattern in file_path for pattern in ignored_patterns]):
return False

View File

@ -12,12 +12,12 @@ def run_cpp_test(binary_file: str) -> None:
print_error(f"Binary failed to run: {binary_file}")
def get_tool_path_by_platform(platform: TestPlatform):
def get_tool_path_by_platform(platform: TestPlatform) -> str:
if platform == TestPlatform.FBCODE:
from caffe2.fb.code_coverage.tool.package.fbcode.utils import get_llvm_tool_path
from caffe2.fb.code_coverage.tool.package.fbcode.utils import get_llvm_tool_path # type: ignore[import]
return get_llvm_tool_path()
return get_llvm_tool_path() # type: ignore[no-any-return]
else:
from ..oss.utils import get_llvm_tool_path
from ..oss.utils import get_llvm_tool_path # type: ignore[no-redef]
return get_llvm_tool_path()
return get_llvm_tool_path() # type: ignore[no-any-return]

View File

@ -2,7 +2,7 @@ import os
import shutil
import sys
import time
from typing import Any, Optional
from typing import Any, NoReturn, Optional
from .setting import (
LOG_DIR,
@ -71,7 +71,7 @@ def convert_to_relative_path(whole_path: str, base_path: str) -> str:
return whole_path[len(base_path) + 1 :]
def replace_extension(filename, ext):
def replace_extension(filename: str, ext: str) -> str:
return filename[: filename.rfind(".")] + ext
@ -89,11 +89,11 @@ def get_raw_profiles_folder() -> str:
def detect_compiler_type(platform: TestPlatform) -> CompilerType:
if platform == TestPlatform.OSS:
from package.oss.utils import detect_compiler_type
from package.oss.utils import detect_compiler_type # type: ignore[misc]
cov_type = detect_compiler_type()
cov_type = detect_compiler_type() # type: ignore[call-arg]
else:
from caffe2.fb.code_coverage.tool.package.fbcode.utils import (
from caffe2.fb.code_coverage.tool.package.fbcode.utils import ( # type: ignore[import]
detect_compiler_type,
)
@ -138,7 +138,7 @@ def check_test_type(test_type: str, target: str) -> None:
)
def raise_no_test_found_exception(cpp_binary_folder: str, python_binary_folder: str):
def raise_no_test_found_exception(cpp_binary_folder: str, python_binary_folder: str) -> NoReturn:
raise RuntimeError(
f"No cpp and python tests found in folder **{cpp_binary_folder} and **{python_binary_folder}**"
)

View File

@ -1,4 +1,4 @@
import setuptools
import setuptools # type: ignore[import]
with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()

View File

@ -8,9 +8,10 @@ compiled code has been executed. This means that even if the code chunk is merel
marked as covered.
'''
from coverage import CoveragePlugin, CoverageData
from coverage import CoveragePlugin, CoverageData # type: ignore[import]
from inspect import ismodule, isclass, ismethod, isfunction, iscode, getsourcefile, getsourcelines
from time import time
from typing import Any
# All coverage stats resulting from this plug-in will be in a separate .coverage file that should be merged later with
# `coverage combine`. The convention seems to be .coverage.dotted.suffix based on the following link:
@ -18,17 +19,17 @@ from time import time
cov_data = CoverageData(basename=f'.coverage.jit.{time()}')
def is_not_builtin_class(obj):
def is_not_builtin_class(obj: Any) -> bool:
return isclass(obj) and not type(obj).__module__ == 'builtins'
class JitPlugin(CoveragePlugin):
class JitPlugin(CoveragePlugin): # type: ignore[misc, no-any-unimported]
'''
dynamic_context is an overridden function that gives us access to every frame run during the coverage process. We
look for when the function being run is `should_drop`, as all functions that get passed into `should_drop` will be
compiled and thus should be marked as covered.
'''
def dynamic_context(self, frame):
def dynamic_context(self, frame: Any) -> None:
if frame.f_code.co_name == 'should_drop':
obj = frame.f_locals['fn']
# The many conditions in the if statement below are based on the accepted arguments to getsourcefile. Based
@ -54,5 +55,5 @@ class JitPlugin(CoveragePlugin):
cov_data.add_lines(line_data)
super().dynamic_context(frame)
def coverage_init(reg, options):
def coverage_init(reg: Any, options: Any) -> None:
reg.add_dynamic_context(JitPlugin())

View File

@ -18,14 +18,18 @@ RESOURCES = [
]
def report_download_progress(chunk_number, chunk_size, file_size):
def report_download_progress(
chunk_number: int,
chunk_size: int,
file_size: int,
) -> None:
if file_size != -1:
percent = min(1, (chunk_number * chunk_size) / file_size)
bar = '#' * int(64 * percent)
sys.stdout.write('\r0% |{:<64}| {}%'.format(bar, int(percent * 100)))
def download(destination_path, resource, quiet):
def download(destination_path: str, resource: str, quiet: bool) -> None:
if os.path.exists(destination_path):
if not quiet:
print('{} already exists, skipping ...'.format(destination_path))
@ -48,7 +52,7 @@ def download(destination_path, resource, quiet):
raise RuntimeError('Error downloading resource!')
def unzip(zipped_path, quiet):
def unzip(zipped_path: str, quiet: bool) -> None:
unzipped_path = os.path.splitext(zipped_path)[0]
if os.path.exists(unzipped_path):
if not quiet:
@ -61,7 +65,7 @@ def unzip(zipped_path, quiet):
print('Unzipped {} ...'.format(zipped_path))
def main():
def main() -> None:
parser = argparse.ArgumentParser(
description='Download the MNIST dataset from the internet')
parser.add_argument(

View File

@ -17,7 +17,7 @@ def get_test_case_times() -> Dict[str, float]:
# an entry will be like ("test_doc_examples (__main__.TestTypeHints)" -> [values]))
test_names_to_times: DefaultDict[str, List[float]] = defaultdict(list)
for report in reports:
if report.get('format_version', 1) != 2:
if report.get('format_version', 1) != 2: # type: ignore[misc]
raise RuntimeError("S3 format currently handled is version 2 only")
v2report = cast(Version2Report, report)
for test_file in v2report['files'].values():
@ -46,7 +46,7 @@ def export_slow_tests(filename: str) -> None:
file.write('\n')
def parse_args():
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description='Export a JSON of slow test cases in PyTorch unit test suite')
parser.add_argument(
@ -61,7 +61,7 @@ def parse_args():
return parser.parse_args()
def main():
def main() -> None:
options = parse_args()
export_slow_tests(options.filename)

View File

@ -14,7 +14,10 @@ import shutil
import subprocess
import sys
import time
from typing import (Awaitable, DefaultDict, Dict, List, Match, Optional, Set,
cast)
from typing_extensions import TypedDict
help_msg = '''fast_nvcc [OPTION]... -- [NVCC_ARG]...
@ -78,14 +81,14 @@ url_vars = f'{url_base}#keeping-intermediate-phase-files'
re_tmp = r'(?<![\w\-/])(?:/tmp/)?(tmp[^ \"\'\\]+)'
def fast_nvcc_warn(warning):
def fast_nvcc_warn(warning: str) -> None:
"""
Warn the user about something regarding fast_nvcc.
"""
print(f'warning (fast_nvcc): {warning}', file=sys.stderr)
def warn_if_windows():
def warn_if_windows() -> None:
"""
Warn the user that using fast_nvcc on Windows might not work.
"""
@ -97,7 +100,7 @@ def warn_if_windows():
fast_nvcc_warn(url_vars)
def warn_if_tmpdir_flag(args):
def warn_if_tmpdir_flag(args: List[str]) -> None:
"""
Warn the user that using fast_nvcc with some flags might not work.
"""
@ -121,11 +124,17 @@ def warn_if_tmpdir_flag(args):
fast_nvcc_warn(f'{url_base}#{frag}')
def nvcc_dryrun_data(binary, args):
class DryunData(TypedDict):
env: Dict[str, str]
commands: List[str]
exit_code: int
def nvcc_dryrun_data(binary: str, args: List[str]) -> DryunData:
"""
Return parsed environment variables and commands from nvcc --dryrun.
"""
result = subprocess.run(
result = subprocess.run( # type: ignore[call-overload]
[binary, '--dryrun'] + args,
capture_output=True,
encoding='ascii', # this is just a guess
@ -148,7 +157,7 @@ def nvcc_dryrun_data(binary, args):
return {'env': env, 'commands': commands, 'exit_code': result.returncode}
def warn_if_tmpdir_set(env):
def warn_if_tmpdir_set(env: Dict[str, str]) -> None:
"""
Warn the user that setting TMPDIR with fast_nvcc might not work.
"""
@ -157,7 +166,7 @@ def warn_if_tmpdir_set(env):
fast_nvcc_warn(url_vars)
def contains_non_executable(commands):
def contains_non_executable(commands: List[str]) -> bool:
for command in commands:
# This is to deal with special command dry-run result from NVCC such as:
# ```
@ -170,7 +179,7 @@ def contains_non_executable(commands):
return False
def module_id_contents(command):
def module_id_contents(command: List[str]) -> str:
"""
Guess the contents of the .module_id file contained within command.
"""
@ -187,7 +196,7 @@ def module_id_contents(command):
return f'_{len(middle)}_{middle}_{suffix}'
def unique_module_id_files(commands):
def unique_module_id_files(commands: List[str]) -> List[str]:
"""
Give each command its own .module_id filename instead of sharing.
"""
@ -196,7 +205,7 @@ def unique_module_id_files(commands):
for i, line in enumerate(commands):
arr = []
def uniqueify(s):
def uniqueify(s: Match[str]) -> str:
filename = re.sub(r'\-(\d+)', r'-\1-' + str(i), s.group(0))
arr.append(filename)
return filename
@ -212,14 +221,19 @@ def unique_module_id_files(commands):
return uniqueified
def make_rm_force(commands):
def make_rm_force(commands: List[str]) -> List[str]:
"""
Add --force to all rm commands.
"""
return [f'{c} --force' if c.startswith('rm ') else c for c in commands]
def print_verbose_output(*, env, commands, filename):
def print_verbose_output(
*,
env: Dict[str, str],
commands: List[List[str]],
filename: str,
) -> None:
"""
Human-readably write nvcc --dryrun data to stderr.
"""
@ -234,21 +248,24 @@ def print_verbose_output(*, env, commands, filename):
print(f'#{" "*len(prefix)}{part}', file=f)
def straight_line_dependencies(commands):
Graph = List[Set[int]]
def straight_line_dependencies(commands: List[str]) -> Graph:
"""
Return a straight-line dependency graph.
"""
return [({i - 1} if i > 0 else set()) for i in range(len(commands))]
def files_mentioned(command):
def files_mentioned(command: str) -> List[str]:
"""
Return fully-qualified names of all tmp files referenced by command.
"""
return [f'/tmp/{match.group(1)}' for match in re.finditer(re_tmp, command)]
def nvcc_data_dependencies(commands):
def nvcc_data_dependencies(commands: List[str]) -> Graph:
"""
Return a list of the set of dependencies for each command.
"""
@ -261,8 +278,8 @@ def nvcc_data_dependencies(commands):
# data dependency is sort of flipped, because the steps that use the
# files generated by cicc need to wait for the fatbinary step to
# finish first
tmp_files = {}
fatbins = collections.defaultdict(set)
tmp_files: Dict[str, int] = {}
fatbins: DefaultDict[int, Set[str]] = collections.defaultdict(set)
graph = []
for i, line in enumerate(commands):
deps = set()
@ -284,13 +301,13 @@ def nvcc_data_dependencies(commands):
return graph
def is_weakly_connected(graph):
def is_weakly_connected(graph: Graph) -> bool:
"""
Return true iff graph is weakly connected.
"""
if not graph:
return True
neighbors = [set() for _ in graph]
neighbors: List[Set[int]] = [set() for _ in graph]
for node, predecessors in enumerate(graph):
for pred in predecessors:
neighbors[pred].add(node)
@ -307,7 +324,7 @@ def is_weakly_connected(graph):
return len(found) == len(graph)
def warn_if_not_weakly_connected(graph):
def warn_if_not_weakly_connected(graph: Graph) -> None:
"""
Warn the user if the execution graph is not weakly connected.
"""
@ -315,11 +332,16 @@ def warn_if_not_weakly_connected(graph):
fast_nvcc_warn('execution graph is not (weakly) connected')
def print_dot_graph(*, commands, graph, filename):
def print_dot_graph(
*,
commands: List[List[str]],
graph: Graph,
filename: str,
) -> None:
"""
Print a DOT file displaying short versions of the commands in graph.
"""
def name(k):
def name(k: int) -> str:
return f'"{k} {os.path.basename(commands[k][0])}"'
with open(filename, 'w') as f:
print('digraph {', file=f)
@ -332,7 +354,24 @@ def print_dot_graph(*, commands, graph, filename):
print('}', file=f)
async def run_command(command, *, env, deps, gather_data, i, save):
class Result(TypedDict, total=False):
exit_code: int
stdout: bytes
stderr: bytes
time: float
files: Dict[str, int]
async def run_command(
command: str,
*,
env: Dict[str, str],
deps: Set[Awaitable[Result]],
gather_data: bool,
i: int,
save: Optional[str],
) -> Result:
"""
Run the command with the given env after waiting for deps.
"""
@ -350,8 +389,8 @@ async def run_command(command, *, env, deps, gather_data, i, save):
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
code = proc.returncode
results = {'exit_code': code, 'stdout': stdout, 'stderr': stderr}
code = cast(int, proc.returncode)
results: Result = {'exit_code': code, 'stdout': stdout, 'stderr': stderr}
if gather_data:
t2 = time.monotonic()
results['time'] = t2 - t1
@ -371,14 +410,21 @@ async def run_command(command, *, env, deps, gather_data, i, save):
return results
async def run_graph(*, env, commands, graph, gather_data=False, save=None):
async def run_graph(
*,
env: Dict[str, str],
commands: List[str],
graph: Graph,
gather_data: bool = False,
save: Optional[str] = None,
) -> List[Result]:
"""
Return outputs/errors (and optionally time/file info) from commands.
"""
tasks = []
tasks: List[Awaitable[Result]] = []
for i, (command, indices) in enumerate(zip(commands, graph)):
deps = {tasks[j] for j in indices}
tasks.append(asyncio.create_task(run_command(
tasks.append(asyncio.create_task(run_command( # type: ignore[attr-defined]
command,
env=env,
deps=deps,
@ -389,7 +435,7 @@ async def run_graph(*, env, commands, graph, gather_data=False, save=None):
return [await task for task in tasks]
def print_command_outputs(command_results):
def print_command_outputs(command_results: List[Result]) -> None:
"""
Print captured stdout and stderr from commands.
"""
@ -398,11 +444,16 @@ def print_command_outputs(command_results):
sys.stderr.write(result.get('stderr', b'').decode('ascii'))
def write_log_csv(command_parts, command_results, *, filename):
def write_log_csv(
command_parts: List[List[str]],
command_results: List[Result],
*,
filename: str,
) -> None:
"""
Write a CSV file of the times and /tmp file sizes from each command.
"""
tmp_files = []
tmp_files: List[str] = []
for result in command_results:
tmp_files.extend(result.get('files', {}).keys())
with open(filename, 'w', newline='') as csvfile:
@ -415,7 +466,7 @@ def write_log_csv(command_parts, command_results, *, filename):
writer.writerow({**row, **result.get('files', {})})
def exit_code(results):
def exit_code(results: List[Result]) -> int:
"""
Aggregate individual exit codes into a single code.
"""
@ -426,11 +477,18 @@ def exit_code(results):
return 0
def wrap_nvcc(args, config=default_config):
def wrap_nvcc(
args: List[str],
config: argparse.Namespace = default_config,
) -> int:
return subprocess.call([config.nvcc] + args)
def fast_nvcc(args, *, config=default_config):
def fast_nvcc(
args: List[str],
*,
config: argparse.Namespace = default_config,
) -> int:
"""
Emulate the result of calling the given nvcc binary with args.
@ -465,7 +523,7 @@ def fast_nvcc(args, *, config=default_config):
)
if config.sequential:
graph = straight_line_dependencies(commands)
results = asyncio.run(run_graph(
results = asyncio.run(run_graph( # type: ignore[attr-defined]
env=env,
commands=commands,
graph=graph,
@ -478,7 +536,7 @@ def fast_nvcc(args, *, config=default_config):
return exit_code([dryrun_data] + results)
def our_arg(arg):
def our_arg(arg: str) -> bool:
return arg != '--'

View File

@ -2,7 +2,7 @@
import sys
from flake8.main import git
from flake8.main import git # type: ignore[import]
if __name__ == '__main__':
sys.exit(

View File

@ -1,5 +1,6 @@
import gdb
import gdb # type: ignore[import]
import textwrap
from typing import Any
class DisableBreakpoints:
"""
@ -8,18 +9,18 @@ class DisableBreakpoints:
commands
"""
def __enter__(self):
def __enter__(self) -> None:
self.disabled_breakpoints = []
for b in gdb.breakpoints():
if b.enabled:
b.enabled = False
self.disabled_breakpoints.append(b)
def __exit__(self, etype, evalue, tb):
def __exit__(self, etype: Any, evalue: Any, tb: Any) -> None:
for b in self.disabled_breakpoints:
b.enabled = True
class TensorRepr(gdb.Command):
class TensorRepr(gdb.Command): # type: ignore[misc, no-any-unimported]
"""
Print a human readable representation of the given at::Tensor.
Usage: torch-tensor-repr EXP
@ -31,11 +32,11 @@ class TensorRepr(gdb.Command):
"""
__doc__ = textwrap.dedent(__doc__).strip()
def __init__(self):
def __init__(self) -> None:
gdb.Command.__init__(self, 'torch-tensor-repr',
gdb.COMMAND_USER, gdb.COMPLETE_EXPRESSION)
def invoke(self, args, from_tty):
def invoke(self, args: str, from_tty: bool) -> None:
args = gdb.string_to_argv(args)
if len(args) != 1:
print('Usage: torch-tensor-repr EXP')

View File

@ -2,7 +2,7 @@ import argparse
import os
import subprocess
from pathlib import Path
from setuptools import distutils
from setuptools import distutils # type: ignore[import]
from typing import Optional, Union
def get_sha(pytorch_root: Union[str, Path]) -> str:

View File

@ -103,7 +103,7 @@ def write_selected_mobile_ops_with_all_dtypes(
header_contents = "".join(body_parts)
out_file.write(header_contents.encode("utf-8"))
def main():
def main() -> None:
parser = argparse.ArgumentParser(
description="Generate selected_mobile_ops.h for selective build."
)

View File

@ -40,10 +40,10 @@ import contextlib
import subprocess
from ast import literal_eval
from argparse import ArgumentParser
from typing import Dict, Optional, Iterator
from typing import (Any, Callable, Dict, Generator, Iterable, Iterator, List,
Optional, Sequence, Set, Tuple, TypeVar, cast)
LOGGER = None
LOGGER: Optional[logging.Logger] = None
URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2"
DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss"
SHA1_RE = re.compile("([0-9a-fA-F]{40})")
@ -133,7 +133,7 @@ def logging_rotate() -> None:
@contextlib.contextmanager
def logging_manager(*, debug: bool = False) -> Iterator[None]:
def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, None]:
"""Setup logging. If a failure starts here we won't
be able to save the user in a reasonable way.
@ -179,7 +179,7 @@ def logging_manager(*, debug: bool = False) -> Iterator[None]:
sys.exit(1)
def check_in_repo():
def check_in_repo() -> Optional[str]:
"""Ensures that we are in the PyTorch repo."""
if not os.path.isfile("setup.py"):
return "Not in root-level PyTorch repo, no setup.py found"
@ -187,12 +187,13 @@ def check_in_repo():
s = f.read()
if "PyTorch" not in s:
return "Not in PyTorch repo, 'PyTorch' not found in setup.py"
return None
def check_branch(subcommand, branch):
def check_branch(subcommand: str, branch: Optional[str]) -> Optional[str]:
"""Checks that the branch name can be checked out."""
if subcommand != "checkout":
return
return None
# first make sure actual branch name was given
if branch is None:
return "Branch name to checkout must be supplied with '-b' option"
@ -203,36 +204,44 @@ def check_branch(subcommand, branch):
return "Need to have clean working tree to checkout!\n\n" + p.stdout
# next check that the branch name doesn't already exist
cmd = ["git", "show-ref", "--verify", "--quiet", "refs/heads/" + branch]
p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False)
p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False) # type: ignore[assignment]
if not p.returncode:
return f"Branch {branch!r} already exists"
return None
@contextlib.contextmanager
def timer(logger, prefix):
def timer(logger: logging.Logger, prefix: str) -> Iterator[None]:
"""Timed context manager"""
start_time = time.time()
yield
logger.info(f"{prefix} took {time.time() - start_time:.3f} [s]")
def timed(prefix):
F = TypeVar('F', bound=Callable[..., Any])
def timed(prefix: str) -> Callable[[F], F]:
"""Decorator for timing functions"""
def dec(f):
def dec(f: F) -> F:
@functools.wraps(f)
def wrapper(*args, **kwargs):
def wrapper(*args: Any, **kwargs: Any) -> Any:
global LOGGER
LOGGER.info(prefix)
with timer(LOGGER, prefix):
logger = cast(logging.Logger, LOGGER)
logger.info(prefix)
with timer(logger, prefix):
return f(*args, **kwargs)
return wrapper
return cast(F, wrapper)
return dec
def _make_channel_args(channels=("pytorch-nightly",), override_channels=False):
def _make_channel_args(
channels: Iterable[str] = ("pytorch-nightly",),
override_channels: bool = False,
) -> List[str]:
args = []
for channel in channels:
args.append("--channel")
@ -244,8 +253,11 @@ def _make_channel_args(channels=("pytorch-nightly",), override_channels=False):
@timed("Solving conda environment")
def conda_solve(
name=None, prefix=None, channels=("pytorch-nightly",), override_channels=False
):
name: Optional[str] = None,
prefix: Optional[str] = None,
channels: Iterable[str] = ("pytorch-nightly",),
override_channels: bool = False,
) -> Tuple[List[str], str, str, bool, List[str]]:
"""Performs the conda solve and splits the deps from the package."""
# compute what environment to use
if prefix is not None:
@ -299,7 +311,7 @@ def conda_solve(
@timed("Installing dependencies")
def deps_install(deps, existing_env, env_opts):
def deps_install(deps: List[str], existing_env: bool, env_opts: List[str]) -> None:
"""Install dependencies to deps environment"""
if not existing_env:
# first remove previous pytorch-deps env
@ -312,7 +324,7 @@ def deps_install(deps, existing_env, env_opts):
@timed("Installing pytorch nightly binaries")
def pytorch_install(url):
def pytorch_install(url: str) -> tempfile.TemporaryDirectory[str]:
""""Install pytorch into a temporary directory"""
pytdir = tempfile.TemporaryDirectory()
cmd = ["conda", "create", "--yes", "--no-deps", "--prefix", pytdir.name, url]
@ -320,7 +332,7 @@ def pytorch_install(url):
return pytdir
def _site_packages(dirname, platform):
def _site_packages(dirname: str, platform: str) -> str:
if platform.startswith("win"):
template = os.path.join(dirname, "Lib", "site-packages")
else:
@ -329,7 +341,7 @@ def _site_packages(dirname, platform):
return spdir
def _ensure_commit(git_sha1):
def _ensure_commit(git_sha1: str) -> None:
"""Make sure that we actually have the commit locally"""
cmd = ["git", "cat-file", "-e", git_sha1 + "^{commit}"]
p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False)
@ -341,7 +353,7 @@ def _ensure_commit(git_sha1):
p = subprocess.run(cmd, check=True)
def _nightly_version(spdir):
def _nightly_version(spdir: str) -> str:
# first get the git version from the installed module
version_fname = os.path.join(spdir, "torch", "version.py")
with open(version_fname) as f:
@ -371,7 +383,7 @@ def _nightly_version(spdir):
@timed("Checking out nightly PyTorch")
def checkout_nightly_version(branch, spdir):
def checkout_nightly_version(branch: str, spdir: str) -> None:
"""Get's the nightly version and then checks it out."""
nightly_version = _nightly_version(spdir)
cmd = ["git", "checkout", "-b", branch, nightly_version]
@ -379,40 +391,40 @@ def checkout_nightly_version(branch, spdir):
@timed("Pulling nightly PyTorch")
def pull_nightly_version(spdir):
def pull_nightly_version(spdir: str) -> None:
"""Fetches the nightly version and then merges it ."""
nightly_version = _nightly_version(spdir)
cmd = ["git", "merge", nightly_version]
p = subprocess.run(cmd, check=True)
def _get_listing_linux(source_dir):
def _get_listing_linux(source_dir: str) -> List[str]:
listing = glob.glob(os.path.join(source_dir, "*.so"))
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.so")))
return listing
def _get_listing_osx(source_dir):
def _get_listing_osx(source_dir: str) -> List[str]:
# oddly, these are .so files even on Mac
listing = glob.glob(os.path.join(source_dir, "*.so"))
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dylib")))
return listing
def _get_listing_win(source_dir):
def _get_listing_win(source_dir: str) -> List[str]:
listing = glob.glob(os.path.join(source_dir, "*.pyd"))
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.lib")))
listing.extend(glob.glob(os.path.join(source_dir, "lib", "*.dll")))
return listing
def _glob_pyis(d):
def _glob_pyis(d: str) -> Set[str]:
search = os.path.join(d, "**", "*.pyi")
pyis = {os.path.relpath(p, d) for p in glob.iglob(search)}
return pyis
def _find_missing_pyi(source_dir, target_dir):
def _find_missing_pyi(source_dir: str, target_dir: str) -> List[str]:
source_pyis = _glob_pyis(source_dir)
target_pyis = _glob_pyis(target_dir)
missing_pyis = [os.path.join(source_dir, p) for p in (source_pyis - target_pyis)]
@ -420,7 +432,7 @@ def _find_missing_pyi(source_dir, target_dir):
return missing_pyis
def _get_listing(source_dir, target_dir, platform):
def _get_listing(source_dir: str, target_dir: str, platform: str) -> List[str]:
if platform.startswith("linux"):
listing = _get_listing_linux(source_dir)
elif platform.startswith("osx"):
@ -437,7 +449,7 @@ def _get_listing(source_dir, target_dir, platform):
return listing
def _remove_existing(trg, is_dir):
def _remove_existing(trg: str, is_dir: bool) -> None:
if os.path.exists(trg):
if is_dir:
shutil.rmtree(trg)
@ -445,7 +457,13 @@ def _remove_existing(trg, is_dir):
os.remove(trg)
def _move_single(src, source_dir, target_dir, mover, verb):
def _move_single(
src: str,
source_dir: str,
target_dir: str,
mover: Callable[[str, str], None],
verb: str,
) -> None:
is_dir = os.path.isdir(src)
relpath = os.path.relpath(src, source_dir)
trg = os.path.join(target_dir, relpath)
@ -469,18 +487,18 @@ def _move_single(src, source_dir, target_dir, mover, verb):
mover(src, trg)
def _copy_files(listing, source_dir, target_dir):
def _copy_files(listing: List[str], source_dir: str, target_dir: str) -> None:
for src in listing:
_move_single(src, source_dir, target_dir, shutil.copy2, "Copying")
def _link_files(listing, source_dir, target_dir):
def _link_files(listing: List[str], source_dir: str, target_dir: str) -> None:
for src in listing:
_move_single(src, source_dir, target_dir, os.link, "Linking")
@timed("Moving nightly files into repo")
def move_nightly_files(spdir, platform):
def move_nightly_files(spdir: str, platform: str) -> None:
"""Moves PyTorch files from temporary installed location to repo."""
# get file listing
source_dir = os.path.join(spdir, "torch")
@ -496,7 +514,7 @@ def move_nightly_files(spdir, platform):
_copy_files(listing, source_dir, target_dir)
def _available_envs():
def _available_envs() -> Dict[str, str]:
cmd = ["conda", "env", "list"]
p = subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
lines = p.stdout.splitlines()
@ -513,7 +531,7 @@ def _available_envs():
@timed("Writing pytorch-nightly.pth")
def write_pth(env_opts, platform):
def write_pth(env_opts: List[str], platform: str) -> None:
"""Writes Python path file for this dir."""
env_type, env_dir = env_opts
if env_type == "--name":
@ -533,17 +551,16 @@ def write_pth(env_opts, platform):
def install(
subcommand="checkout",
branch=None,
name=None,
prefix=None,
channels=("pytorch-nightly",),
override_channels=False,
logger=None,
):
*,
logger: logging.Logger,
subcommand: str = "checkout",
branch: Optional[str] = None,
name: Optional[str] = None,
prefix: Optional[str] = None,
channels: Iterable[str] = ("pytorch-nightly",),
override_channels: bool = False,
) -> None:
"""Development install of PyTorch"""
global LOGGER
logger = logger or LOGGER
deps, pytorch, platform, existing_env, env_opts = conda_solve(
name=name, prefix=prefix, channels=channels, override_channels=override_channels
)
@ -552,7 +569,7 @@ def install(
pytdir = pytorch_install(pytorch)
spdir = _site_packages(pytdir.name, platform)
if subcommand == "checkout":
checkout_nightly_version(branch, spdir)
checkout_nightly_version(cast(str, branch), spdir)
elif subcommand == "pull":
pull_nightly_version(spdir)
else:
@ -566,7 +583,7 @@ def install(
)
def make_parser():
def make_parser() -> ArgumentParser:
p = ArgumentParser("nightly")
# subcommands
subcmd = p.add_subparsers(dest="subcmd", help="subcommand to execute")
@ -627,7 +644,7 @@ def make_parser():
return p
def main(args=None):
def main(args: Optional[Sequence[str]] = None) -> None:
"""Main entry point"""
global LOGGER
p = make_parser()

View File

@ -16,8 +16,8 @@ try:
except ImportError:
print("rich not found, for color output use 'pip install rich'")
def parse_junit_reports(path_to_reports: str) -> List[TestCase]:
def parse_file(path: str) -> List[TestCase]:
def parse_junit_reports(path_to_reports: str) -> List[TestCase]: # type: ignore[no-any-unimported]
def parse_file(path: str) -> List[TestCase]: # type: ignore[no-any-unimported]
try:
return convert_junit_to_testcases(JUnitXml.fromfile(path))
except Exception as err:
@ -37,7 +37,7 @@ def parse_junit_reports(path_to_reports: str) -> List[TestCase]:
return ret_xml
def convert_junit_to_testcases(xml: Union[JUnitXml, TestSuite]) -> List[TestCase]:
def convert_junit_to_testcases(xml: Union[JUnitXml, TestSuite]) -> List[TestCase]: # type: ignore[no-any-unimported]
testcases = []
for item in xml:
if isinstance(item, TestSuite):
@ -46,7 +46,7 @@ def convert_junit_to_testcases(xml: Union[JUnitXml, TestSuite]) -> List[TestCase
testcases.append(item)
return testcases
def render_tests(testcases: List[TestCase]) -> None:
def render_tests(testcases: List[TestCase]) -> None: # type: ignore[no-any-unimported]
num_passed = 0
num_skipped = 0
num_failed = 0

View File

@ -1,8 +1,9 @@
import os
import sys
from typing import Optional
def which(thefile):
def which(thefile: str) -> Optional[str]:
path = os.environ.get("PATH", os.defpath).split(os.pathsep)
for d in path:
fname = os.path.join(d, thefile)

View File

@ -8,14 +8,15 @@ import re
from subprocess import check_call, check_output, CalledProcessError
import sys
import sysconfig
from setuptools import distutils
from setuptools import distutils # type: ignore[import]
from typing import IO, Any, Dict, List, Optional, Union
from . import which
from .env import (BUILD_DIR, IS_64BIT, IS_DARWIN, IS_WINDOWS, check_negative_env_flag)
from .numpy_ import USE_NUMPY, NUMPY_INCLUDE_DIR
def _mkdir_p(d):
def _mkdir_p(d: str) -> None:
try:
os.makedirs(d)
except OSError:
@ -28,7 +29,11 @@ def _mkdir_p(d):
USE_NINJA = (not check_negative_env_flag('USE_NINJA') and
which('ninja') is not None)
def convert_cmake_value_to_python_value(cmake_value, cmake_type):
CMakeValue = Optional[Union[bool, str]]
def convert_cmake_value_to_python_value(cmake_value: str, cmake_type: str) -> CMakeValue:
r"""Convert a CMake value in a string form to a Python value.
Args:
@ -52,7 +57,7 @@ def convert_cmake_value_to_python_value(cmake_value, cmake_type):
else: # Directly return the cmake_value.
return cmake_value
def get_cmake_cache_variables_from_file(cmake_cache_file):
def get_cmake_cache_variables_from_file(cmake_cache_file: IO[str]) -> Dict[str, CMakeValue]:
r"""Gets values in CMakeCache.txt into a dictionary.
Args:
@ -93,12 +98,12 @@ def get_cmake_cache_variables_from_file(cmake_cache_file):
class CMake:
"Manages cmake."
def __init__(self, build_dir=BUILD_DIR):
def __init__(self, build_dir: str = BUILD_DIR) -> None:
self._cmake_command = CMake._get_cmake_command()
self.build_dir = build_dir
@property
def _cmake_cache_file(self):
def _cmake_cache_file(self) -> str:
r"""Returns the path to CMakeCache.txt.
Returns:
@ -107,7 +112,7 @@ class CMake:
return os.path.join(self.build_dir, 'CMakeCache.txt')
@staticmethod
def _get_cmake_command():
def _get_cmake_command() -> str:
"Returns cmake command."
cmake_command = 'cmake'
@ -124,7 +129,7 @@ class CMake:
raise RuntimeError('no cmake or cmake3 with version >= 3.5.0 found')
@staticmethod
def _get_version(cmd):
def _get_version(cmd: str) -> Any:
"Returns cmake version."
for line in check_output([cmd, '--version']).decode('utf-8').split('\n'):
@ -132,7 +137,7 @@ class CMake:
return distutils.version.LooseVersion(line.strip().split(' ')[2])
raise RuntimeError('no version found')
def run(self, args, env):
def run(self, args: List[str], env: Dict[str, str]) -> None:
"Executes cmake with arguments and an environment."
command = [self._cmake_command] + args
@ -146,13 +151,13 @@ class CMake:
sys.exit(1)
@staticmethod
def defines(args, **kwargs):
def defines(args: List[str], **kwargs: CMakeValue) -> None:
"Adds definitions to a cmake argument list."
for key, value in sorted(kwargs.items()):
if value is not None:
args.append('-D{}={}'.format(key, value))
def get_cmake_cache_variables(self):
def get_cmake_cache_variables(self) -> Dict[str, CMakeValue]:
r"""Gets values in CMakeCache.txt into a dictionary.
Returns:
dict: A ``dict`` containing the value of cached CMake variables.
@ -160,7 +165,15 @@ class CMake:
with open(self._cmake_cache_file) as f:
return get_cmake_cache_variables_from_file(f)
def generate(self, version, cmake_python_library, build_python, build_test, my_env, rerun):
def generate(
self,
version: Optional[str],
cmake_python_library: Optional[str],
build_python: bool,
build_test: bool,
my_env: Dict[str, str],
rerun: bool,
) -> None:
"Runs cmake to generate native build files."
if rerun and os.path.isfile(self._cmake_cache_file):
@ -215,7 +228,7 @@ class CMake:
_mkdir_p(self.build_dir)
# Store build options that are directly stored in environment variables
build_options = {
build_options: Dict[str, CMakeValue] = {
# The default value cannot be easily obtained in CMakeLists.txt. We set it here.
'CMAKE_PREFIX_PATH': sysconfig.get_path('purelib')
}
@ -335,7 +348,7 @@ class CMake:
args.append(base_dir)
self.run(args, env=my_env)
def build(self, my_env):
def build(self, my_env: Dict[str, str]) -> None:
"Runs cmake to build binaries."
from .env import build_type

View File

@ -3,6 +3,7 @@ import platform
import struct
import sys
from itertools import chain
from typing import Iterable, List, Optional, cast
IS_WINDOWS = (platform.system() == 'Windows')
@ -17,19 +18,19 @@ IS_64BIT = (struct.calcsize("P") == 8)
BUILD_DIR = 'build'
def check_env_flag(name, default=''):
def check_env_flag(name: str, default: str = '') -> bool:
return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y']
def check_negative_env_flag(name, default=''):
def check_negative_env_flag(name: str, default: str = '') -> bool:
return os.getenv(name, default).upper() in ['OFF', '0', 'NO', 'FALSE', 'N']
def gather_paths(env_vars):
def gather_paths(env_vars: Iterable[str]) -> List[str]:
return list(chain(*(os.getenv(v, '').split(os.pathsep) for v in env_vars)))
def lib_paths_from_base(base_path):
def lib_paths_from_base(base_path: str) -> List[str]:
return [os.path.join(base_path, s) for s in ['lib/x64', 'lib', 'lib64']]
@ -49,7 +50,7 @@ class BuildType(object):
"""
def __init__(self, cmake_build_type_env=None):
def __init__(self, cmake_build_type_env: Optional[str] = None) -> None:
if cmake_build_type_env is not None:
self.build_type_string = cmake_build_type_env
return
@ -63,19 +64,19 @@ class BuildType(object):
# Normally it is anti-pattern to determine build type from CMAKE_BUILD_TYPE because it is not used for
# multi-configuration build tools, such as Visual Studio and XCode. But since we always communicate with
# CMake using CMAKE_BUILD_TYPE from our Python scripts, this is OK here.
self.build_type_string = cmake_cache_vars['CMAKE_BUILD_TYPE']
self.build_type_string = cast(str, cmake_cache_vars['CMAKE_BUILD_TYPE'])
else:
self.build_type_string = os.environ.get('CMAKE_BUILD_TYPE', 'Release')
def is_debug(self):
def is_debug(self) -> bool:
"Checks Debug build."
return self.build_type_string == 'Debug'
def is_rel_with_deb_info(self):
def is_rel_with_deb_info(self) -> bool:
"Checks RelWithDebInfo build."
return self.build_type_string == 'RelWithDebInfo'
def is_release(self):
def is_release(self) -> bool:
"Checks Release build."
return self.build_type_string == 'Release'

View File

@ -4,9 +4,12 @@
import argparse
import os
from typing import Dict, Tuple, cast
Version = Tuple[int, int, int]
def parse_version(version: str) -> (int, int, int):
def parse_version(version: str) -> Version:
"""
Parses a version string into (major, minor, patch) version numbers.
@ -24,10 +27,10 @@ def parse_version(version: str) -> (int, int, int):
version_number_str = version[:i]
break
return tuple([int(n) for n in version_number_str.split(".")])
return cast(Version, tuple([int(n) for n in version_number_str.split(".")]))
def apply_replacements(replacements, text):
def apply_replacements(replacements: Dict[str, str], text: str) -> str:
"""
Applies the given replacements within the text.
@ -43,7 +46,7 @@ def apply_replacements(replacements, text):
return text
def main(args):
def main(args: argparse.Namespace) -> None:
with open(args.version_path) as f:
version = f.read().strip()
(major, minor, patch) = parse_version(version)

View File

@ -2,12 +2,13 @@ import argparse
import os
import sys
import yaml
from typing import Any, List, Optional, cast
try:
# use faster C loader if available
from yaml import CSafeLoader as YamlLoader
except ImportError:
from yaml import SafeLoader as YamlLoader
from yaml import SafeLoader as YamlLoader # type: ignore[misc]
source_files = {'.py', '.cpp', '.h'}
@ -16,7 +17,7 @@ NATIVE_FUNCTIONS_PATH = 'aten/src/ATen/native/native_functions.yaml'
# TODO: This is a little inaccurate, because it will also pick
# up setup_helper scripts which don't affect code generation
def all_generator_source():
def all_generator_source() -> List[str]:
r = []
for directory, _, filenames in os.walk('tools'):
for f in filenames:
@ -26,15 +27,15 @@ def all_generator_source():
return sorted(r)
def generate_code(ninja_global=None,
declarations_path=None,
nn_path=None,
native_functions_path=None,
install_dir=None,
subset=None,
disable_autograd=False,
force_schema_registration=False,
operator_selector=None):
def generate_code(ninja_global: Optional[str] = None,
declarations_path: Optional[str] = None,
nn_path: Optional[str] = None,
native_functions_path: Optional[str] = None,
install_dir: Optional[str] = None,
subset: Optional[str] = None,
disable_autograd: bool = False,
force_schema_registration: bool = False,
operator_selector: Any = None) -> None:
from tools.autograd.gen_autograd import gen_autograd, gen_autograd_python
from tools.autograd.gen_annotated_fn_args import gen_annotated
from tools.codegen.selective_build.selector import SelectiveBuilder
@ -86,7 +87,7 @@ def generate_code(ninja_global=None,
def get_selector_from_legacy_operator_selection_list(
selected_op_list_path: str,
):
) -> Any:
with open(selected_op_list_path, 'r') as f:
# strip out the overload part
# It's only for legacy config - do NOT copy this code!
@ -113,7 +114,10 @@ def get_selector_from_legacy_operator_selection_list(
return selector
def get_selector(selected_op_list_path, operators_yaml_path):
def get_selector(
selected_op_list_path: Optional[str],
operators_yaml_path: Optional[str],
) -> Any:
# cwrap depends on pyyaml, so we can't import it earlier
root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, root)
@ -129,10 +133,10 @@ def get_selector(selected_op_list_path, operators_yaml_path):
elif selected_op_list_path is not None:
return get_selector_from_legacy_operator_selection_list(selected_op_list_path)
else:
return SelectiveBuilder.from_yaml_path(operators_yaml_path)
return SelectiveBuilder.from_yaml_path(cast(str, operators_yaml_path))
def main():
def main() -> None:
parser = argparse.ArgumentParser(description='Autogenerate code')
parser.add_argument('--declarations-path')
parser.add_argument('--native-functions-path')

View File

@ -2,8 +2,11 @@
# for now, I have put it in one place but right now is copied out of cwrap
import copy
from typing import Any, Dict, Iterable, List, Union
def parse_arguments(args):
Arg = Dict[str, Any]
def parse_arguments(args: List[Union[str, Arg]]) -> List[Arg]:
new_args = []
for arg in args:
# Simple arg declaration of form "<type> <name>"
@ -20,7 +23,10 @@ def parse_arguments(args):
return new_args
def set_declaration_defaults(declaration):
Declaration = Dict[str, Any]
def set_declaration_defaults(declaration: Declaration) -> None:
if 'schema_string' not in declaration:
# This happens for legacy TH bindings like
# _thnn_conv_depthwise2d_backward
@ -70,19 +76,26 @@ def set_declaration_defaults(declaration):
# TODO(zach): added option to remove keyword handling for C++ which cannot
# support it.
Option = Dict[str, Any]
def filter_unique_options(options, allow_kwarg, type_to_signature, remove_self):
def exclude_arg(arg):
return arg['type'] == 'CONSTANT'
def exclude_arg_with_self_check(arg):
def filter_unique_options(
options: Iterable[Option],
allow_kwarg: bool,
type_to_signature: Dict[str, str],
remove_self: bool,
) -> List[Option]:
def exclude_arg(arg: Arg) -> bool:
return arg['type'] == 'CONSTANT' # type: ignore[no-any-return]
def exclude_arg_with_self_check(arg: Arg) -> bool:
return exclude_arg(arg) or (remove_self and arg['name'] == 'self')
def signature(option, kwarg_only_count):
if kwarg_only_count == 0:
def signature(option: Option, num_kwarg_only: int) -> str:
if num_kwarg_only == 0:
kwarg_only_count = None
else:
kwarg_only_count = -kwarg_only_count
kwarg_only_count = -num_kwarg_only
arg_signature = '#'.join(
type_to_signature.get(arg['type'], arg['type'])
for arg in option['arguments'][:kwarg_only_count]
@ -111,40 +124,40 @@ def filter_unique_options(options, allow_kwarg, type_to_signature, remove_self):
return unique
def sort_by_number_of_args(declaration, reverse=True):
def num_args(option):
def sort_by_number_of_args(declaration: Declaration, reverse: bool = True) -> None:
def num_args(option: Option) -> int:
return len(option['arguments'])
declaration['options'].sort(key=num_args, reverse=reverse)
class Function(object):
def __init__(self, name):
def __init__(self, name: str) -> None:
self.name = name
self.arguments = []
self.arguments: List['Argument'] = []
def add_argument(self, arg):
def add_argument(self, arg: 'Argument') -> None:
assert isinstance(arg, Argument)
self.arguments.append(arg)
def __repr__(self):
def __repr__(self) -> str:
return self.name + '(' + ', '.join(a.__repr__() for a in self.arguments) + ')'
class Argument(object):
def __init__(self, _type, name, is_optional):
def __init__(self, _type: str, name: str, is_optional: bool):
self.type = _type
self.name = name
self.is_optional = is_optional
def __repr__(self):
def __repr__(self) -> str:
return self.type + ' ' + self.name
def parse_header(path):
def parse_header(path: str) -> List[Function]:
with open(path, 'r') as f:
lines = f.read().split('\n')
lines: Iterable[Any] = f.read().split('\n')
# Remove empty lines and prebackend directives
lines = filter(lambda l: l and not l.startswith('#'), lines)

View File

@ -1,6 +1,11 @@
def import_module(name, path):
from importlib.abc import Loader
from types import ModuleType
from typing import cast
def import_module(name: str, path: str) -> ModuleType:
import importlib.util
spec = importlib.util.spec_from_file_location(name, path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
cast(Loader, spec.loader).exec_module(module)
return module

View File

@ -1,13 +1,19 @@
# -*- coding: utf-8 -*-
import unittest
from typing import Dict, List
from tools import print_test_stats
from tools.stats_utils.s3_stat_parser import (Commit, Report, ReportMetaMeta,
Status, Version1Case,
Version1Report, Version2Case,
Version2Report)
def fakehash(char):
def fakehash(char: str) -> str:
return char * 40
def dummy_meta_meta() -> print_test_stats.ReportMetaMeta:
def dummy_meta_meta() -> ReportMetaMeta:
return {
'build_pr': '',
'build_tag': '',
@ -18,7 +24,14 @@ def dummy_meta_meta() -> print_test_stats.ReportMetaMeta:
}
def makecase(name, seconds, *, errored=False, failed=False, skipped=False):
def makecase(
name: str,
seconds: float,
*,
errored: bool = False,
failed: bool = False,
skipped: bool = False,
) -> Version1Case:
return {
'name': name,
'seconds': seconds,
@ -28,7 +41,7 @@ def makecase(name, seconds, *, errored=False, failed=False, skipped=False):
}
def make_report_v1(tests) -> print_test_stats.Version1Report:
def make_report_v1(tests: Dict[str, List[Version1Case]]) -> Version1Report:
suites = {
suite_name: {
'total_seconds': sum(case['seconds'] for case in cases),
@ -37,20 +50,20 @@ def make_report_v1(tests) -> print_test_stats.Version1Report:
for suite_name, cases in tests.items()
}
return {
**dummy_meta_meta(),
**dummy_meta_meta(), # type: ignore[misc]
'total_seconds': sum(s['total_seconds'] for s in suites.values()),
'suites': suites,
}
def make_case_v2(seconds, status=None) -> print_test_stats.Version2Case:
def make_case_v2(seconds: float, status: Status = None) -> Version2Case:
return {
'seconds': seconds,
'status': status,
}
def make_report_v2(tests) -> print_test_stats.Version2Report:
def make_report_v2(tests: Dict[str, Dict[str, Dict[str, Version2Case]]]) -> Version2Report:
files = {}
for file_name, file_suites in tests.items():
suites = {
@ -65,7 +78,7 @@ def make_report_v2(tests) -> print_test_stats.Version2Report:
'total_seconds': sum(suite['total_seconds'] for suite in suites.values()),
}
return {
**dummy_meta_meta(),
**dummy_meta_meta(), # type: ignore[misc]
'format_version': 2,
'total_seconds': sum(s['total_seconds'] for s in files.values()),
'files': files,
@ -73,7 +86,7 @@ def make_report_v2(tests) -> print_test_stats.Version2Report:
maxDiff = None
class TestPrintTestStats(unittest.TestCase):
version1_report: print_test_stats.Version1Report = make_report_v1({
version1_report: Version1Report = make_report_v1({
# input ordering of the suites is ignored
'Grault': [
# not printed: status same and time similar
@ -112,7 +125,7 @@ class TestPrintTestStats(unittest.TestCase):
],
})
version2_report: print_test_stats.Version2Report = make_report_v2(
version2_report: Version2Report = make_report_v2(
{
'test_a': {
'Grault': {
@ -149,7 +162,7 @@ class TestPrintTestStats(unittest.TestCase):
}
})
def test_simplify(self):
def test_simplify(self) -> None:
self.assertEqual(
{
'': {
@ -222,10 +235,10 @@ class TestPrintTestStats(unittest.TestCase):
print_test_stats.simplify(self.version2_report),
)
def test_analysis(self):
def test_analysis(self) -> None:
head_report = self.version1_report
base_reports = {
base_reports: Dict[Commit, List[Report]] = {
# bbbb has no reports, so base is cccc instead
fakehash('b'): [],
fakehash('c'): [
@ -391,7 +404,7 @@ class TestPrintTestStats(unittest.TestCase):
print_test_stats.anomalies(analysis),
)
def test_graph(self):
def test_graph(self) -> None:
# HEAD is on master
self.assertEqual(
'''\
@ -534,7 +547,7 @@ Commit graph (base is most recent master ancestor with at least one S3 report):
)
)
def test_regression_info(self):
def test_regression_info(self) -> None:
self.assertEqual(
'''\
----- Historic stats comparison result ------
@ -588,7 +601,7 @@ Added (across 1 suite) 1 test, totaling + 3.00s
)
)
def test_regression_info_new_job(self):
def test_regression_info_new_job(self) -> None:
self.assertEqual(
'''\
----- Historic stats comparison result ------

View File

@ -3,6 +3,7 @@ import os
import inspect
import sys
import tempfile
from typing import Any, List, Optional, Tuple
# this arbitrary-looking assortment of functionality is provided here
# to have a central place for overrideable behavior. The motivating
@ -20,30 +21,33 @@ else:
else:
torch_parent = os.path.dirname(os.path.dirname(__file__))
def get_file_path(*path_components):
def get_file_path(*path_components: str) -> str:
return os.path.join(torch_parent, *path_components)
def get_file_path_2(*path_components):
def get_file_path_2(*path_components: str) -> str:
return os.path.join(*path_components)
def get_writable_path(path):
def get_writable_path(path: str) -> str:
if os.access(path, os.W_OK):
return path
return tempfile.mkdtemp(suffix=os.path.basename(path))
def prepare_multiprocessing_environment(path):
def prepare_multiprocessing_environment(path: str) -> None:
pass
def resolve_library_path(path):
def resolve_library_path(path: str) -> str:
return os.path.realpath(path)
def get_source_lines_and_file(obj, error_msg=None):
def get_source_lines_and_file(
obj: Any,
error_msg: Optional[str] = None,
) -> Tuple[List[str], int, Optional[str]]:
"""
Wrapper around inspect.getsourcelines and inspect.getsourcefile.