[BE] Apply ufmt to run_test and GitHub Python util scripts (#97588)

This has been bugging me for a while as I'm working on these Python scripts and they are not tracked by the ufmt linter. So I added these scripts to that linter.

```
[[linter]]
code = 'UFMT'
include_patterns = [
    '.github/**/*.py',
    'test/run_test.py',
```

This change should just work and not break anything, as the ufmt (black + usort) linter is very safe to use for standalone util scripts.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97588
Approved by: https://github.com/kit1980
This commit is contained in:
Huy Do
2023-03-26 04:52:55 +00:00
committed by PyTorch MergeBot
parent f09347a9f1
commit 4c0dce50fd
30 changed files with 1955 additions and 1070 deletions

View File

@ -1,13 +1,15 @@
#!/usr/bin/env python3
from subprocess import check_call
import shutil
import sys
from pathlib import Path
from subprocess import check_call
from tempfile import TemporaryDirectory
from typing import Optional
import sys
import shutil
SCRIPT_DIR = Path(__file__).parent
REPO_DIR = SCRIPT_DIR.parent.parent
def read_triton_pin() -> str:
with open(REPO_DIR / ".ci" / "docker" / "ci_commit_pins" / "triton.txt") as f:
return f.read().strip()
@ -19,7 +21,7 @@ def read_triton_version() -> str:
def check_and_replace(inp: str, src: str, dst: str) -> str:
""" Checks that `src` can be found in `input` and replaces it with `dst` """
"""Checks that `src` can be found in `input` and replaces it with `dst`"""
if src not in inp:
raise RuntimeError(f"Can't find ${src} in the input")
return inp.replace(src, dst)
@ -29,9 +31,11 @@ def patch_setup_py(path: Path, *, version: str, name: str = "triton") -> None:
with open(path) as f:
orig = f.read()
# Replace name
orig = check_and_replace(orig, "name=\"triton\",", f"name=\"{name}\",")
orig = check_and_replace(orig, 'name="triton",', f'name="{name}",')
# Replace version
orig = check_and_replace(orig, f"version=\"{read_triton_version()}\",", f"version=\"{version}\",")
orig = check_and_replace(
orig, f'version="{read_triton_version()}",', f'version="{version}",'
)
with open(path, "w") as f:
f.write(orig)
@ -40,12 +44,20 @@ def patch_init_py(path: Path, *, version: str) -> None:
with open(path) as f:
orig = f.read()
# Replace version
orig = check_and_replace(orig, f"__version__ = '{read_triton_version()}'", f"__version__ = \"{version}\"")
orig = check_and_replace(
orig, f"__version__ = '{read_triton_version()}'", f'__version__ = "{version}"'
)
with open(path, "w") as f:
f.write(orig)
def build_triton(*, version: str, commit_hash: str, build_conda: bool = False, py_version : Optional[str] = None) -> Path:
def build_triton(
*,
version: str,
commit_hash: str,
build_conda: bool = False,
py_version: Optional[str] = None,
) -> Path:
with TemporaryDirectory() as tmpdir:
triton_basedir = Path(tmpdir) / "triton"
triton_pythondir = triton_basedir / "python"
@ -53,26 +65,60 @@ def build_triton(*, version: str, commit_hash: str, build_conda: bool = False, p
check_call(["git", "checkout", commit_hash], cwd=triton_basedir)
if build_conda:
with open(triton_basedir / "meta.yaml", "w") as meta:
print(f"package:\n name: torchtriton\n version: {version}+{commit_hash[:10]}\n", file=meta)
print(
f"package:\n name: torchtriton\n version: {version}+{commit_hash[:10]}\n",
file=meta,
)
print("source:\n path: .\n", file=meta)
print("build:\n string: py{{py}}\n number: 1\n script: cd python; "
"python setup.py install --single-version-externally-managed --record=record.txt\n", file=meta)
print("requirements:\n host:\n - python\n - setuptools\n run:\n - python\n"
" - filelock\n - pytorch\n", file=meta)
print("about:\n home: https://github.com/openai/triton\n license: MIT\n summary:"
" 'A language and compiler for custom Deep Learning operation'", file=meta)
print(
"build:\n string: py{{py}}\n number: 1\n script: cd python; "
"python setup.py install --single-version-externally-managed --record=record.txt\n",
file=meta,
)
print(
"requirements:\n host:\n - python\n - setuptools\n run:\n - python\n"
" - filelock\n - pytorch\n",
file=meta,
)
print(
"about:\n home: https://github.com/openai/triton\n license: MIT\n summary:"
" 'A language and compiler for custom Deep Learning operation'",
file=meta,
)
patch_init_py(triton_pythondir / "triton" / "__init__.py", version=f"{version}+{commit_hash[:10]}")
patch_init_py(
triton_pythondir / "triton" / "__init__.py",
version=f"{version}+{commit_hash[:10]}",
)
if py_version is None:
py_version = f"{sys.version_info.major}.{sys.version_info.minor}"
check_call(["conda", "build", "--python", py_version,
"-c", "pytorch-nightly", "--output-folder", tmpdir, "."], cwd=triton_basedir)
check_call(
[
"conda",
"build",
"--python",
py_version,
"-c",
"pytorch-nightly",
"--output-folder",
tmpdir,
".",
],
cwd=triton_basedir,
)
conda_path = list(Path(tmpdir).glob("linux-64/torchtriton*.bz2"))[0]
shutil.copy(conda_path, Path.cwd())
return Path.cwd() / conda_path.name
patch_setup_py(triton_pythondir / "setup.py", name="pytorch-triton", version=f"{version}+{commit_hash[:10]}")
patch_init_py(triton_pythondir / "triton" / "__init__.py", version=f"{version}+{commit_hash[:10]}")
patch_setup_py(
triton_pythondir / "setup.py",
name="pytorch-triton",
version=f"{version}+{commit_hash[:10]}",
)
patch_init_py(
triton_pythondir / "triton" / "__init__.py",
version=f"{version}+{commit_hash[:10]}",
)
check_call([sys.executable, "setup.py", "bdist_wheel"], cwd=triton_pythondir)
whl_path = list((triton_pythondir / "dist").glob("*.whl"))[0]
shutil.copy(whl_path, Path.cwd())
@ -81,16 +127,19 @@ def build_triton(*, version: str, commit_hash: str, build_conda: bool = False, p
def main() -> None:
from argparse import ArgumentParser
parser = ArgumentParser("Build Triton binaries")
parser.add_argument("--build-conda", action="store_true")
parser.add_argument("--py-version", type=str)
parser.add_argument("--commit-hash", type=str, default=read_triton_pin())
parser.add_argument("--triton-version", type=str, default=read_triton_version())
args = parser.parse_args()
build_triton(commit_hash=args.commit_hash,
version=args.triton_version,
build_conda=args.build_conda,
py_version=args.py_version)
build_triton(
commit_hash=args.commit_hash,
version=args.triton_version,
build_conda=args.build_conda,
py_version=args.py_version,
)
if __name__ == "__main__":

View File

@ -3,21 +3,12 @@
from typing import Any
from gitutils import (
get_git_remote_name,
get_git_repo_dir,
GitRepo,
)
from github_utils import gh_delete_comment, gh_post_pr_comment
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
from label_utils import has_required_labels, is_label_err_comment, LABEL_ERR_MSG
from trymerge import GitHubPR
from github_utils import (
gh_delete_comment,
gh_post_pr_comment,
)
from label_utils import (
LABEL_ERR_MSG,
is_label_err_comment,
has_required_labels,
)
def delete_all_label_err_comments(pr: "GitHubPR") -> None:
for comment in pr.get_comments():
@ -33,6 +24,7 @@ def add_label_err_comment(pr: "GitHubPR") -> None:
def parse_args() -> Any:
from argparse import ArgumentParser
parser = ArgumentParser("Check PR labels")
parser.add_argument("pr_num", type=int)

View File

@ -1,11 +1,13 @@
#!/usr/bin/env python3
from pathlib import Path
from typing import Any, Dict, List, Set, cast
import yaml
import sys
from pathlib import Path
from typing import Any, cast, Dict, List, Set
import yaml
GITHUB_DIR = Path(__file__).parent.parent
def get_workflows_push_tags() -> Set[str]:
"Extract all known push tags from workflows"
rc: Set[str] = set()
@ -22,8 +24,10 @@ def get_workflows_push_tags() -> Set[str]:
def filter_ciflow_tags(tags: Set[str]) -> List[str]:
" Return sorted list of ciflow tags"
return sorted(tag[:-2] for tag in tags if tag.startswith("ciflow/") and tag.endswith("/*"))
"Return sorted list of ciflow tags"
return sorted(
tag[:-2] for tag in tags if tag.startswith("ciflow/") and tag.endswith("/*")
)
def read_probot_config() -> Dict[str, Any]:
@ -40,6 +44,7 @@ def update_probot_config(labels: Set[str]) -> None:
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser("Validate or update list of tags")
parser.add_argument("--validate-tags", action="store_true")
args = parser.parse_args()
@ -51,9 +56,15 @@ if __name__ == "__main__":
if config_tags != ciflow_tags:
print("Tags mismatch!")
if ciflow_tags.difference(config_tags):
print("Reference in workflows but not in config", ciflow_tags.difference(config_tags))
print(
"Reference in workflows but not in config",
ciflow_tags.difference(config_tags),
)
if config_tags.difference(ciflow_tags):
print("Reference in config, but not in workflows", config_tags.difference(ciflow_tags))
print(
"Reference in config, but not in workflows",
config_tags.difference(ciflow_tags),
)
print(f"Please run {__file__} to remediate the difference")
sys.exit(-1)
print("All tags are listed in pytorch-probot.yml")

View File

@ -1,8 +1,9 @@
import os
from typing import Any
from github_utils import gh_post_pr_comment
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
from trymerge_explainer import BOT_COMMANDS_WIKI
import os
def parse_args() -> Any:

View File

@ -6,6 +6,7 @@ from enum import Enum
from pathlib import Path
from typing import NamedTuple, Optional
# From: https://docs.github.com/en/rest/reference/checks
class GitHubAnnotationLevel(str, Enum):
NOTICE = "notice"
@ -24,7 +25,12 @@ class GitHubAnnotation(NamedTuple):
title: Optional[str]
raw_details: Optional[str]
PYTORCH_ROOT = Path(subprocess.check_output(['git', 'rev-parse', '--show-toplevel']).decode('ascii').strip())
PYTORCH_ROOT = Path(
subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
.decode("ascii")
.strip()
)
annotations = []
for line in sys.stdin:
@ -33,7 +39,6 @@ for line in sys.stdin:
path = lint_message.get("path")
line = lint_message.get("line")
code = lint_message["code"]
severity = lint_message["severity"]
name = lint_message["name"]
@ -48,16 +53,18 @@ for line in sys.stdin:
# normalize path relative to git root
path = Path(path).relative_to(PYTORCH_ROOT)
annotations.append(GitHubAnnotation(
path=str(path),
start_line=int(line),
end_line=int(line),
start_column=None,
end_column=None,
annotation_level=GitHubAnnotationLevel.FAILURE,
message=description,
title=f"({code}) {name}",
raw_details=None,
)._asdict())
annotations.append(
GitHubAnnotation(
path=str(path),
start_line=int(line),
end_line=int(line),
start_column=None,
end_column=None,
annotation_level=GitHubAnnotationLevel.FAILURE,
message=description,
title=f"({code}) {name}",
raw_details=None,
)._asdict()
)
print(json.dumps(annotations), flush=True)

View File

@ -2,15 +2,18 @@
import argparse
import sys
import yaml
from pathlib import Path
import yaml
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
WORKFLOWS = REPO_ROOT / ".github" / "workflows"
EXPECTED_GROUP = "${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}" \
EXPECTED_GROUP = (
"${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}"
"-${{ github.event_name == 'workflow_dispatch' }}"
)
def should_check(filename: Path) -> bool:

View File

@ -1,5 +1,5 @@
#!/usr/bin/env python3
'''
"""
Test ownership was introduced in https://github.com/pytorch/pytorch/issues/66232.
As a part of enforcing test ownership, we want to maintain a list of existing PyTorch labels
@ -8,17 +8,19 @@ pytorch/pytorch labels so that the file could be uploaded to S3.
This script assumes the correct env vars are set for AWS permissions.
'''
"""
import json
from typing import Any
import boto3 # type: ignore[import]
import json
from label_utils import gh_get_labels
from typing import Any
def parse_args() -> Any:
from argparse import ArgumentParser
parser = ArgumentParser("Export PR labels")
parser.add_argument("org", type=str)
parser.add_argument("repo", type=str)
@ -30,9 +32,9 @@ def main() -> None:
args = parse_args()
print(f"Exporting labels for {args.org}/{args.repo}")
labels_file_name = "pytorch_labels.json"
obj = boto3.resource('s3').Object('ossci-metrics', labels_file_name)
obj = boto3.resource("s3").Object("ossci-metrics", labels_file_name)
obj.put(Body=json.dumps(gh_get_labels(args.org, args.repo)).encode())
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -1,20 +1,23 @@
import sys
from typing import Any, Dict, List, NamedTuple, Tuple, cast
from gitutils import _check_output
import rockset # type: ignore[import]
import os
import re
import sys
from typing import Any, cast, Dict, List, NamedTuple, Tuple
import rockset # type: ignore[import]
from gitutils import _check_output
def eprint(msg: str) -> None:
print(msg, file=sys.stderr)
class WorkflowCheck(NamedTuple):
workflowName: str
name: str
jobName: str
conclusion: str
def get_latest_commits() -> List[str]:
latest_viable_commit = _check_output(
[
@ -39,42 +42,46 @@ def get_latest_commits() -> List[str]:
return commits
def query_commits(commits: List[str]) -> List[Dict[str, Any]]:
rs = rockset.RocksetClient(
host="api.usw2a1.rockset.com", api_key=os.environ["ROCKSET_API_KEY"]
)
params = [{
"name": "shas",
"type": "string",
"value": ",".join(commits)
}]
params = [{"name": "shas", "type": "string", "value": ",".join(commits)}]
res = rs.QueryLambdas.execute_query_lambda(
query_lambda='commit_jobs_batch_query',
version='8003fdfd18b64696',
workspace='commons',
parameters=params
query_lambda="commit_jobs_batch_query",
version="8003fdfd18b64696",
workspace="commons",
parameters=params,
)
return cast(List[Dict[str, Any]], res.results)
def print_commit_status(commit: str, results: Dict[str, Any]) -> None:
print(commit)
for check in results['results']:
if check['sha'] == commit:
for check in results["results"]:
if check["sha"] == commit:
print(f"\t{check['conclusion']:>10}: {check['name']}")
def get_commit_results(commit: str, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
def get_commit_results(
commit: str, results: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
workflow_checks = []
for check in results:
if check['sha'] == commit:
workflow_checks.append(WorkflowCheck(
workflowName=check['workflowName'],
name=check['name'],
jobName=check['jobName'],
conclusion=check['conclusion'],
)._asdict())
if check["sha"] == commit:
workflow_checks.append(
WorkflowCheck(
workflowName=check["workflowName"],
name=check["name"],
jobName=check["jobName"],
conclusion=check["conclusion"],
)._asdict()
)
return workflow_checks
def isGreen(commit: str, results: List[Dict[str, Any]]) -> Tuple[bool, str]:
workflow_checks = get_commit_results(commit, results)
@ -87,8 +94,8 @@ def isGreen(commit: str, results: List[Dict[str, Any]]) -> Tuple[bool, str]:
}
for check in workflow_checks:
workflowName = check['workflowName']
conclusion = check['conclusion']
workflowName = check["workflowName"]
conclusion = check["conclusion"]
for required_check in regex:
if re.match(required_check, workflowName, flags=re.IGNORECASE):
if conclusion not in ["success", "skipped"]:
@ -102,6 +109,7 @@ def isGreen(commit: str, results: List[Dict[str, Any]]) -> Tuple[bool, str]:
return (True, "")
def get_latest_green_commit(commits: List[str], results: List[Dict[str, Any]]) -> Any:
for commit in commits:
eprint(f"Checking {commit}")
@ -113,13 +121,14 @@ def get_latest_green_commit(commits: List[str], results: List[Dict[str, Any]]) -
eprint("RED: " + msg)
return None
def main() -> None:
def main() -> None:
commits = get_latest_commits()
results = query_commits(commits)
latest_viable_commit = get_latest_green_commit(commits, results)
print(latest_viable_commit)
if __name__ == "__main__":
main()

View File

@ -10,7 +10,7 @@ architectures:
* Latest ROCM
"""
from typing import Dict, List, Tuple, Optional
from typing import Dict, List, Optional, Tuple
CUDA_ARCHES = ["11.7", "11.8"]
@ -19,7 +19,8 @@ CUDA_ARCHES = ["11.7", "11.8"]
ROCM_ARCHES = ["5.3", "5.4.2"]
CPU_CXX11_ABI_ARCH = ['cpu-cxx11-abi']
CPU_CXX11_ABI_ARCH = ["cpu-cxx11-abi"]
def arch_type(arch_version: str) -> str:
if arch_version in CUDA_ARCHES:
@ -121,9 +122,12 @@ def generate_conda_matrix(os: str) -> List[Dict[str, str]]:
return ret
def generate_libtorch_matrix(os: str, abi_version: str,
arches: Optional[List[str]] = None,
libtorch_variants: Optional[List[str]] = None) -> List[Dict[str, str]]:
def generate_libtorch_matrix(
os: str,
abi_version: str,
arches: Optional[List[str]] = None,
libtorch_variants: Optional[List[str]] = None,
) -> List[Dict[str, str]]:
if arches is None:
arches = ["cpu"]
if os == "linux":
@ -163,7 +167,9 @@ def generate_libtorch_matrix(os: str, abi_version: str,
"devtoolset": abi_version if os != "windows" else "",
"container_image": LIBTORCH_CONTAINER_IMAGES[
(arch_version, abi_version)
] if os != "windows" else "",
]
if os != "windows"
else "",
"package_type": "libtorch",
"build_name": f"libtorch-{gpu_arch_type}{gpu_arch_version}-{libtorch_variant}-{abi_version}".replace(
".", "_"
@ -173,9 +179,11 @@ def generate_libtorch_matrix(os: str, abi_version: str,
return ret
def generate_wheels_matrix(os: str,
arches: Optional[List[str]] = None,
python_versions: Optional[List[str]] = None) -> List[Dict[str, str]]:
def generate_wheels_matrix(
os: str,
arches: Optional[List[str]] = None,
python_versions: Optional[List[str]] = None,
) -> List[Dict[str, str]]:
package_type = "wheel"
if os == "linux":
# NOTE: We only build manywheel packages for linux
@ -196,7 +204,11 @@ def generate_wheels_matrix(os: str,
for python_version in python_versions:
for arch_version in arches:
gpu_arch_type = arch_type(arch_version)
gpu_arch_version = "" if arch_version == "cpu" or arch_version == "cpu-cxx11-abi" else arch_version
gpu_arch_version = (
""
if arch_version == "cpu" or arch_version == "cpu-cxx11-abi"
else arch_version
)
# Skip rocm 3.11 binaries for now as the docker image are not correct
if python_version == "3.11" and gpu_arch_type == "rocm":
continue
@ -215,8 +227,7 @@ def generate_wheels_matrix(os: str,
"devtoolset": "",
"container_image": WHEEL_CONTAINER_IMAGES[arch_version],
"package_type": package_type,
"pytorch_extra_install_requirements":
"nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"pytorch_extra_install_requirements": "nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | " # noqa: B950
"nvidia-cuda-runtime-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cuda-cupti-cu11==11.7.101; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' and platform_machine == 'x86_64' | "
@ -227,9 +238,7 @@ def generate_wheels_matrix(os: str,
"nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | "
"nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64'",
"build_name":
f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}-with-pypi-cudnn"
.replace(
"build_name": f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}-with-pypi-cudnn".replace( # noqa: B950
".", "_"
),
}
@ -243,7 +252,9 @@ def generate_wheels_matrix(os: str,
"desired_cuda": translate_desired_cuda(
gpu_arch_type, gpu_arch_version
),
"devtoolset": "cxx11-abi" if arch_version == "cpu-cxx11-abi" else "",
"devtoolset": "cxx11-abi"
if arch_version == "cpu-cxx11-abi"
else "",
"container_image": WHEEL_CONTAINER_IMAGES[arch_version],
"package_type": package_type,
"build_name": f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}".replace(

View File

@ -1,17 +1,16 @@
#!/usr/bin/env python3
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Dict, Set, List, Literal, Iterable
import jinja2
import os
import sys
from typing_extensions import TypedDict # Python 3.11+
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Dict, Iterable, List, Literal, Set
import generate_binary_build_matrix # type: ignore[import]
import jinja2
from typing_extensions import TypedDict # Python 3.11+
Arch = Literal["windows", "linux", "macos"]
GITHUB_DIR = Path(__file__).resolve().parent.parent
@ -23,6 +22,7 @@ LABEL_CIFLOW_BINARIES_LIBTORCH = "ciflow/binaries_libtorch"
LABEL_CIFLOW_BINARIES_CONDA = "ciflow/binaries_conda"
LABEL_CIFLOW_BINARIES_WHEEL = "ciflow/binaries_wheel"
@dataclass
class CIFlowConfig:
# For use to enable workflows to run on pytorch/pytorch-canary
@ -36,10 +36,12 @@ class CIFlowConfig:
if LABEL_CIFLOW_PERIODIC not in self.labels:
self.labels.add(LABEL_CIFLOW_TRUNK)
class Config(TypedDict):
num_shards: int
runner: str
@dataclass
class BinaryBuildWorkflow:
os: str
@ -47,23 +49,28 @@ class BinaryBuildWorkflow:
package_type: str
# Optional fields
build_environment: str = ''
abi_version: str = ''
build_environment: str = ""
abi_version: str = ""
ciflow_config: CIFlowConfig = field(default_factory=CIFlowConfig)
is_scheduled: str = ''
branches: str = 'nightly'
is_scheduled: str = ""
branches: str = "nightly"
# Mainly for macos
cross_compile_arm64: bool = False
xcode_version: str = ''
xcode_version: str = ""
def __post_init__(self) -> None:
if self.abi_version:
self.build_environment = f"{self.os}-binary-{self.package_type}-{self.abi_version}"
self.build_environment = (
f"{self.os}-binary-{self.package_type}-{self.abi_version}"
)
else:
self.build_environment = f"{self.os}-binary-{self.package_type}"
def generate_workflow_file(self, workflow_template: jinja2.Template) -> None:
output_file_path = GITHUB_DIR / f"workflows/generated-{self.build_environment}-{self.branches}.yml"
output_file_path = (
GITHUB_DIR
/ f"workflows/generated-{self.build_environment}-{self.branches}.yml"
)
with open(output_file_path, "w") as output_file:
GENERATED = "generated" # Note that please keep the variable GENERATED otherwise phabricator will hide the whole file
output_file.writelines([f"# @{GENERATED} DO NOT EDIT MANUALLY\n"])
@ -77,17 +84,21 @@ class BinaryBuildWorkflow:
output_file.write("\n")
print(output_file_path)
class OperatingSystem:
LINUX = "linux"
WINDOWS = "windows"
MACOS = "macos"
MACOS_ARM64 = "macos-arm64"
LINUX_BINARY_BUILD_WORFKLOWS = [
BinaryBuildWorkflow(
os=OperatingSystem.LINUX,
package_type="manywheel",
build_configs=generate_binary_build_matrix.generate_wheels_matrix(OperatingSystem.LINUX),
build_configs=generate_binary_build_matrix.generate_wheels_matrix(
OperatingSystem.LINUX
),
ciflow_config=CIFlowConfig(
labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL},
isolated_workflow=True,
@ -96,7 +107,9 @@ LINUX_BINARY_BUILD_WORFKLOWS = [
BinaryBuildWorkflow(
os=OperatingSystem.LINUX,
package_type="conda",
build_configs=generate_binary_build_matrix.generate_conda_matrix(OperatingSystem.LINUX),
build_configs=generate_binary_build_matrix.generate_conda_matrix(
OperatingSystem.LINUX
),
ciflow_config=CIFlowConfig(
labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_CONDA},
isolated_workflow=True,
@ -133,18 +146,16 @@ LINUX_BINARY_SMOKE_WORKFLOWS = [
os=OperatingSystem.LINUX,
package_type="manywheel",
build_configs=generate_binary_build_matrix.generate_wheels_matrix(
OperatingSystem.LINUX,
arches=["11.8"],
python_versions=["3.8"]),
OperatingSystem.LINUX, arches=["11.8"], python_versions=["3.8"]
),
branches="master",
),
BinaryBuildWorkflow(
os=OperatingSystem.LINUX,
package_type="manywheel",
build_configs=generate_binary_build_matrix.generate_wheels_matrix(
OperatingSystem.LINUX,
arches=["11.7"],
python_versions=["3.8"]),
OperatingSystem.LINUX, arches=["11.7"], python_versions=["3.8"]
),
branches="master",
),
BinaryBuildWorkflow(
@ -152,7 +163,8 @@ LINUX_BINARY_SMOKE_WORKFLOWS = [
package_type="libtorch",
abi_version=generate_binary_build_matrix.CXX11_ABI,
build_configs=generate_binary_build_matrix.generate_libtorch_matrix(
OperatingSystem.LINUX, generate_binary_build_matrix.CXX11_ABI,
OperatingSystem.LINUX,
generate_binary_build_matrix.CXX11_ABI,
arches=["cpu"],
libtorch_variants=["shared-with-deps"],
),
@ -163,7 +175,8 @@ LINUX_BINARY_SMOKE_WORKFLOWS = [
package_type="libtorch",
abi_version=generate_binary_build_matrix.PRE_CXX11_ABI,
build_configs=generate_binary_build_matrix.generate_libtorch_matrix(
OperatingSystem.LINUX, generate_binary_build_matrix.PRE_CXX11_ABI,
OperatingSystem.LINUX,
generate_binary_build_matrix.PRE_CXX11_ABI,
arches=["cpu"],
libtorch_variants=["shared-with-deps"],
),
@ -175,7 +188,9 @@ WINDOWS_BINARY_BUILD_WORKFLOWS = [
BinaryBuildWorkflow(
os=OperatingSystem.WINDOWS,
package_type="wheel",
build_configs=generate_binary_build_matrix.generate_wheels_matrix(OperatingSystem.WINDOWS),
build_configs=generate_binary_build_matrix.generate_wheels_matrix(
OperatingSystem.WINDOWS
),
ciflow_config=CIFlowConfig(
labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL},
isolated_workflow=True,
@ -184,7 +199,9 @@ WINDOWS_BINARY_BUILD_WORKFLOWS = [
BinaryBuildWorkflow(
os=OperatingSystem.WINDOWS,
package_type="conda",
build_configs=generate_binary_build_matrix.generate_conda_matrix(OperatingSystem.WINDOWS),
build_configs=generate_binary_build_matrix.generate_conda_matrix(
OperatingSystem.WINDOWS
),
ciflow_config=CIFlowConfig(
labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_CONDA},
isolated_workflow=True,
@ -221,7 +238,8 @@ WINDOWS_BINARY_SMOKE_WORKFLOWS = [
package_type="libtorch",
abi_version=generate_binary_build_matrix.RELEASE,
build_configs=generate_binary_build_matrix.generate_libtorch_matrix(
OperatingSystem.WINDOWS, generate_binary_build_matrix.RELEASE,
OperatingSystem.WINDOWS,
generate_binary_build_matrix.RELEASE,
arches=["cpu"],
libtorch_variants=["shared-with-deps"],
),
@ -232,7 +250,8 @@ WINDOWS_BINARY_SMOKE_WORKFLOWS = [
package_type="libtorch",
abi_version=generate_binary_build_matrix.DEBUG,
build_configs=generate_binary_build_matrix.generate_libtorch_matrix(
OperatingSystem.WINDOWS, generate_binary_build_matrix.DEBUG,
OperatingSystem.WINDOWS,
generate_binary_build_matrix.DEBUG,
arches=["cpu"],
libtorch_variants=["shared-with-deps"],
),
@ -244,7 +263,9 @@ MACOS_BINARY_BUILD_WORKFLOWS = [
BinaryBuildWorkflow(
os=OperatingSystem.MACOS,
package_type="wheel",
build_configs=generate_binary_build_matrix.generate_wheels_matrix(OperatingSystem.MACOS),
build_configs=generate_binary_build_matrix.generate_wheels_matrix(
OperatingSystem.MACOS
),
ciflow_config=CIFlowConfig(
labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL},
isolated_workflow=True,
@ -253,7 +274,9 @@ MACOS_BINARY_BUILD_WORKFLOWS = [
BinaryBuildWorkflow(
os=OperatingSystem.MACOS,
package_type="conda",
build_configs=generate_binary_build_matrix.generate_conda_matrix(OperatingSystem.MACOS),
build_configs=generate_binary_build_matrix.generate_conda_matrix(
OperatingSystem.MACOS
),
ciflow_config=CIFlowConfig(
labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_CONDA},
isolated_workflow=True,
@ -274,7 +297,9 @@ MACOS_BINARY_BUILD_WORKFLOWS = [
BinaryBuildWorkflow(
os=OperatingSystem.MACOS_ARM64,
package_type="wheel",
build_configs=generate_binary_build_matrix.generate_wheels_matrix(OperatingSystem.MACOS_ARM64),
build_configs=generate_binary_build_matrix.generate_wheels_matrix(
OperatingSystem.MACOS_ARM64
),
cross_compile_arm64=True,
ciflow_config=CIFlowConfig(
labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL},
@ -285,7 +310,9 @@ MACOS_BINARY_BUILD_WORKFLOWS = [
os=OperatingSystem.MACOS_ARM64,
package_type="conda",
cross_compile_arm64=True,
build_configs=generate_binary_build_matrix.generate_conda_matrix(OperatingSystem.MACOS_ARM64),
build_configs=generate_binary_build_matrix.generate_conda_matrix(
OperatingSystem.MACOS_ARM64
),
ciflow_config=CIFlowConfig(
labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_CONDA},
isolated_workflow=True,
@ -293,6 +320,7 @@ MACOS_BINARY_BUILD_WORKFLOWS = [
),
]
def main() -> None:
jinja_env = jinja2.Environment(
variable_start_string="!{{",
@ -302,11 +330,26 @@ def main() -> None:
# not ported yet
template_and_workflows = [
(jinja_env.get_template("linux_binary_build_workflow.yml.j2"), LINUX_BINARY_BUILD_WORFKLOWS),
(jinja_env.get_template("linux_binary_build_workflow.yml.j2"), LINUX_BINARY_SMOKE_WORKFLOWS),
(jinja_env.get_template("windows_binary_build_workflow.yml.j2"), WINDOWS_BINARY_BUILD_WORKFLOWS),
(jinja_env.get_template("windows_binary_build_workflow.yml.j2"), WINDOWS_BINARY_SMOKE_WORKFLOWS),
(jinja_env.get_template("macos_binary_build_workflow.yml.j2"), MACOS_BINARY_BUILD_WORKFLOWS),
(
jinja_env.get_template("linux_binary_build_workflow.yml.j2"),
LINUX_BINARY_BUILD_WORFKLOWS,
),
(
jinja_env.get_template("linux_binary_build_workflow.yml.j2"),
LINUX_BINARY_SMOKE_WORKFLOWS,
),
(
jinja_env.get_template("windows_binary_build_workflow.yml.j2"),
WINDOWS_BINARY_BUILD_WORKFLOWS,
),
(
jinja_env.get_template("windows_binary_build_workflow.yml.j2"),
WINDOWS_BINARY_SMOKE_WORKFLOWS,
),
(
jinja_env.get_template("macos_binary_build_workflow.yml.j2"),
MACOS_BINARY_BUILD_WORKFLOWS,
),
]
# Delete the existing generated files first, this should align with .gitattributes file description.
existing_workflows = GITHUB_DIR.glob("workflows/generated-*")
@ -323,5 +366,6 @@ def main() -> None:
for workflow in workflows:
workflow.generate_workflow_file(workflow_template=template)
if __name__ == "__main__":
main()

View File

@ -2,8 +2,8 @@
import argparse
import os
import subprocess
import re
import subprocess
from datetime import datetime
from distutils.util import strtobool
@ -13,21 +13,27 @@ LEADING_V_PATTERN = re.compile("^v")
TRAILING_RC_PATTERN = re.compile("-rc[0-9]*$")
LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile("a0$")
class NoGitTagException(Exception):
pass
def get_pytorch_root() -> Path:
return Path(subprocess.check_output(
['git', 'rev-parse', '--show-toplevel']
).decode('ascii').strip())
return Path(
subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
.decode("ascii")
.strip()
)
def get_tag() -> str:
root = get_pytorch_root()
try:
dirty_tag = subprocess.check_output(
['git', 'describe', '--tags', '--exact'],
cwd=root
).decode('ascii').strip()
dirty_tag = (
subprocess.check_output(["git", "describe", "--tags", "--exact"], cwd=root)
.decode("ascii")
.strip()
)
except subprocess.CalledProcessError:
return ""
# Strip leading v that we typically do when we tag branches
@ -41,13 +47,15 @@ def get_tag() -> str:
return ""
return tag
def get_base_version() -> str:
root = get_pytorch_root()
dirty_version = open(root / 'version.txt', 'r').read().strip()
dirty_version = open(root / "version.txt", "r").read().strip()
# Strips trailing a0 from version.txt, not too sure why it's there in the
# first place
return re.sub(LEGACY_BASE_VERSION_SUFFIX_PATTERN, "", dirty_version)
class PytorchVersion:
def __init__(
self,
@ -74,10 +82,11 @@ class PytorchVersion:
return f"{get_tag()}{self.get_post_build_suffix()}"
def get_nightly_version(self) -> str:
date_str = datetime.today().strftime('%Y%m%d')
date_str = datetime.today().strftime("%Y%m%d")
build_suffix = self.get_post_build_suffix()
return f"{get_base_version()}.dev{date_str}{build_suffix}"
def main() -> None:
parser = argparse.ArgumentParser(
description="Generate pytorch version for binary builds"
@ -86,30 +95,29 @@ def main() -> None:
"--no-build-suffix",
action="store_true",
help="Whether or not to add a build suffix typically (+cpu)",
default=strtobool(os.environ.get("NO_BUILD_SUFFIX", "False"))
default=strtobool(os.environ.get("NO_BUILD_SUFFIX", "False")),
)
parser.add_argument(
"--gpu-arch-type",
type=str,
help="GPU arch you are building for, typically (cpu, cuda, rocm)",
default=os.environ.get("GPU_ARCH_TYPE", "cpu")
default=os.environ.get("GPU_ARCH_TYPE", "cpu"),
)
parser.add_argument(
"--gpu-arch-version",
type=str,
help="GPU arch version, typically (10.2, 4.0), leave blank for CPU",
default=os.environ.get("GPU_ARCH_VERSION", "")
default=os.environ.get("GPU_ARCH_VERSION", ""),
)
args = parser.parse_args()
version_obj = PytorchVersion(
args.gpu_arch_type,
args.gpu_arch_version,
args.no_build_suffix
args.gpu_arch_type, args.gpu_arch_version, args.no_build_suffix
)
try:
print(version_obj.get_release_version())
except NoGitTagException:
print(version_obj.get_nightly_version())
if __name__ == "__main__":
main()

View File

@ -11,9 +11,10 @@ import time
import urllib
import urllib.parse
from typing import Any, Callable, Dict, List, Tuple, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.request import Request, urlopen
def parse_json_and_links(conn: Any) -> Tuple[Any, Dict[str, Dict[str, str]]]:
links = {}
# Extract links which GH uses for pagination
@ -26,18 +27,26 @@ def parse_json_and_links(conn: Any) -> Tuple[Any, Dict[str, Dict[str, str]]]:
continue
url = urllib.parse.unquote(url.strip("<> "))
qparams = urllib.parse.parse_qs(params_.strip(), separator=";")
params = {k: v[0].strip('"') for k, v in qparams.items() if type(v) is list and len(v) > 0}
params = {
k: v[0].strip('"')
for k, v in qparams.items()
if type(v) is list and len(v) > 0
}
params["url"] = url
if "rel" in params:
links[params["rel"]] = params
return json.load(conn), links
def fetch_url(url: str, *,
headers: Optional[Dict[str, str]] = None,
reader: Callable[[Any], Any] = lambda x: x.read(),
retries: Optional[int] = 3,
backoff_timeout: float = .5) -> Any:
def fetch_url(
url: str,
*,
headers: Optional[Dict[str, str]] = None,
reader: Callable[[Any], Any] = lambda x: x.read(),
retries: Optional[int] = 3,
backoff_timeout: float = 0.5,
) -> Any:
if headers is None:
headers = {}
try:
@ -46,14 +55,21 @@ def fetch_url(url: str, *,
except urllib.error.HTTPError as err:
if isinstance(retries, (int, float)) and retries > 0:
time.sleep(backoff_timeout)
return fetch_url(url, headers=headers, reader=reader, retries=retries - 1, backoff_timeout=backoff_timeout)
return fetch_url(
url,
headers=headers,
reader=reader,
retries=retries - 1,
backoff_timeout=backoff_timeout,
)
exception_message = (
"Is github alright?",
f"Recieved status code '{err.code}' when attempting to retrieve {url}:\n",
f"{err.reason}\n\nheaders={err.headers}"
f"{err.reason}\n\nheaders={err.headers}",
)
raise RuntimeError(exception_message) from err
def parse_args() -> Any:
parser = argparse.ArgumentParser()
parser.add_argument(
@ -72,7 +88,9 @@ def fetch_jobs(url: str, headers: Dict[str, str]) -> List[Dict[str, str]]:
jobs = response["jobs"]
assert type(jobs) is list
while "next" in links.keys():
response, links = fetch_url(links["next"]["url"], headers=headers, reader=parse_json_and_links)
response, links = fetch_url(
links["next"]["url"], headers=headers, reader=parse_json_and_links
)
jobs.extend(response["jobs"])
return jobs
@ -92,6 +110,7 @@ def fetch_jobs(url: str, headers: Dict[str, str]) -> List[Dict[str, str]]:
# looking for RUNNER_NAME will uniquely identify the job we're currently
# running.
def find_job_id(args: Any) -> str:
# From https://docs.github.com/en/actions/learn-github-actions/environment-variables
PYTORCH_REPO = os.environ.get("GITHUB_REPOSITORY", "pytorch/pytorch")
@ -115,6 +134,7 @@ def find_job_id(args: Any) -> str:
raise RuntimeError(f"Can't find job id for runner {args.runner_name}")
def main() -> None:
args = parse_args()
try:
@ -123,5 +143,6 @@ def main() -> None:
print(repr(e), file=sys.stderr)
print(f"workflow-{args.workflow_run_id}")
if __name__ == "__main__":
main()

View File

@ -21,56 +21,69 @@ class GitHubComment:
def gh_fetch_url(
url: str, *,
url: str,
*,
headers: Optional[Dict[str, str]] = None,
data: Optional[Dict[str, Any]] = None,
method: Optional[str] = None,
reader: Callable[[Any], Any] = lambda x: x.read()
reader: Callable[[Any], Any] = lambda x: x.read(),
) -> Any:
if headers is None:
headers = {}
token = os.environ.get("GITHUB_TOKEN")
if token is not None and url.startswith('https://api.github.com/'):
headers['Authorization'] = f'token {token}'
if token is not None and url.startswith("https://api.github.com/"):
headers["Authorization"] = f"token {token}"
data_ = json.dumps(data).encode() if data is not None else None
try:
with urlopen(Request(url, headers=headers, data=data_, method=method)) as conn:
return reader(conn)
except HTTPError as err:
if err.code == 403 and all(key in err.headers for key in ['X-RateLimit-Limit', 'X-RateLimit-Used']):
print(f"""Rate limit exceeded:
if err.code == 403 and all(
key in err.headers for key in ["X-RateLimit-Limit", "X-RateLimit-Used"]
):
print(
f"""Rate limit exceeded:
Used: {err.headers['X-RateLimit-Used']}
Limit: {err.headers['X-RateLimit-Limit']}
Remaining: {err.headers['X-RateLimit-Remaining']}
Resets at: {err.headers['x-RateLimit-Reset']}""")
Resets at: {err.headers['x-RateLimit-Reset']}"""
)
raise
def gh_fetch_json(
url: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None
data: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
headers = {'Accept': 'application/vnd.github.v3+json'}
headers = {"Accept": "application/vnd.github.v3+json"}
if params is not None and len(params) > 0:
url += '?' + '&'.join(f"{name}={quote(str(val))}" for name, val in params.items())
return cast(List[Dict[str, Any]], gh_fetch_url(url, headers=headers, data=data, reader=json.load))
url += "?" + "&".join(
f"{name}={quote(str(val))}" for name, val in params.items()
)
return cast(
List[Dict[str, Any]],
gh_fetch_url(url, headers=headers, data=data, reader=json.load),
)
def _gh_fetch_json_any(
url: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None
data: Optional[Dict[str, Any]] = None,
) -> Any:
headers = {'Accept': 'application/vnd.github.v3+json'}
headers = {"Accept": "application/vnd.github.v3+json"}
if params is not None and len(params) > 0:
url += '?' + '&'.join(f"{name}={quote(str(val))}" for name, val in params.items())
url += "?" + "&".join(
f"{name}={quote(str(val))}" for name, val in params.items()
)
return gh_fetch_url(url, headers=headers, data=data, reader=json.load)
def gh_fetch_json_list(
url: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None
data: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
return cast(List[Dict[str, Any]], _gh_fetch_json_any(url, params, data))
@ -78,24 +91,38 @@ def gh_fetch_json_list(
def gh_fetch_json_dict(
url: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None
) -> Dict[str, Any] :
data: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
return cast(Dict[str, Any], _gh_fetch_json_any(url, params, data))
def _gh_post_comment(url: str, comment: str, dry_run: bool = False) -> List[Dict[str, Any]]:
def _gh_post_comment(
url: str, comment: str, dry_run: bool = False
) -> List[Dict[str, Any]]:
if dry_run:
print(comment)
return []
return gh_fetch_json_list(url, data={"body": comment})
def gh_post_pr_comment(org: str, repo: str, pr_num: int, comment: str, dry_run: bool = False) -> List[Dict[str, Any]]:
return _gh_post_comment(f'https://api.github.com/repos/{org}/{repo}/issues/{pr_num}/comments', comment, dry_run)
def gh_post_pr_comment(
org: str, repo: str, pr_num: int, comment: str, dry_run: bool = False
) -> List[Dict[str, Any]]:
return _gh_post_comment(
f"https://api.github.com/repos/{org}/{repo}/issues/{pr_num}/comments",
comment,
dry_run,
)
def gh_post_commit_comment(org: str, repo: str, sha: str, comment: str, dry_run: bool = False) -> List[Dict[str, Any]]:
return _gh_post_comment(f'https://api.github.com/repos/{org}/{repo}/commits/{sha}/comments', comment, dry_run)
def gh_post_commit_comment(
org: str, repo: str, sha: str, comment: str, dry_run: bool = False
) -> List[Dict[str, Any]]:
return _gh_post_comment(
f"https://api.github.com/repos/{org}/{repo}/commits/{sha}/comments",
comment,
dry_run,
)
def gh_delete_comment(org: str, repo: str, comment_id: int) -> None:

View File

@ -5,7 +5,7 @@ import re
import tempfile
from collections import defaultdict
from datetime import datetime
from typing import cast, Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union
RE_GITHUB_URL_MATCH = re.compile("^https://.*@?github.com/(.+)/(.+)$")
@ -17,6 +17,7 @@ def get_git_remote_name() -> str:
def get_git_repo_dir() -> str:
from pathlib import Path
return os.getenv("GIT_REPO_DIR", str(Path(__file__).resolve().parent.parent.parent))
@ -25,13 +26,14 @@ def fuzzy_list_to_dict(items: List[Tuple[str, str]]) -> Dict[str, List[str]]:
Converts list to dict preserving elements with duplicate keys
"""
rc: Dict[str, List[str]] = defaultdict(lambda: [])
for (key, val) in items:
for key, val in items:
rc[key].append(val)
return dict(rc)
def _check_output(items: List[str], encoding: str = "utf-8") -> str:
from subprocess import check_output, CalledProcessError, STDOUT
from subprocess import CalledProcessError, check_output, STDOUT
try:
return check_output(items, stderr=STDOUT).decode(encoding)
except CalledProcessError as e:
@ -53,13 +55,15 @@ class GitCommit:
author_date: datetime
commit_date: Optional[datetime]
def __init__(self,
commit_hash: str,
author: str,
author_date: datetime,
title: str,
body: str,
commit_date: Optional[datetime] = None) -> None:
def __init__(
self,
commit_hash: str,
author: str,
author_date: datetime,
title: str,
body: str,
commit_date: Optional[datetime] = None,
) -> None:
self.commit_hash = commit_hash
self.author = author
self.author_date = author_date
@ -100,13 +104,14 @@ def parse_fuller_format(lines: Union[str, List[str]]) -> GitCommit:
assert lines[3].startswith("Commit: ")
assert lines[4].startswith("CommitDate: ")
assert len(lines[5]) == 0
return GitCommit(commit_hash=lines[0].split()[1].strip(),
author=lines[1].split(":", 1)[1].strip(),
author_date=datetime.fromtimestamp(int(lines[2].split(":", 1)[1].strip())),
commit_date=datetime.fromtimestamp(int(lines[4].split(":", 1)[1].strip())),
title=lines[6].strip(),
body="\n".join(lines[7:]),
)
return GitCommit(
commit_hash=lines[0].split()[1].strip(),
author=lines[1].split(":", 1)[1].strip(),
author_date=datetime.fromtimestamp(int(lines[2].split(":", 1)[1].strip())),
commit_date=datetime.fromtimestamp(int(lines[4].split(":", 1)[1].strip())),
title=lines[6].strip(),
body="\n".join(lines[7:]),
)
class GitRepo:
@ -139,16 +144,16 @@ class GitRepo:
self._run_git("fetch", self.remote, f"{ref}:{branch}")
def show_ref(self, name: str) -> str:
refs = self._run_git('show-ref', '-s', name).strip().split('\n')
refs = self._run_git("show-ref", "-s", name).strip().split("\n")
if not all(refs[i] == refs[0] for i in range(1, len(refs))):
raise RuntimeError(f"referce {name} is ambigous")
return refs[0]
def rev_parse(self, name: str) -> str:
return self._run_git('rev-parse', '--verify', name).strip()
return self._run_git("rev-parse", "--verify", name).strip()
def get_merge_base(self, from_ref: str, to_ref: str) -> str:
return self._run_git('merge-base', from_ref, to_ref).strip()
return self._run_git("merge-base", from_ref, to_ref).strip()
def patch_id(self, ref: Union[str, List[str]]) -> List[Tuple[str, str]]:
is_list = isinstance(ref, list)
@ -156,25 +161,31 @@ class GitRepo:
if len(ref) == 0:
return []
ref = " ".join(ref)
rc = _check_output(['sh', '-c', f'git -C {self.repo_dir} show {ref}|git patch-id --stable']).strip()
rc = _check_output(
["sh", "-c", f"git -C {self.repo_dir} show {ref}|git patch-id --stable"]
).strip()
return [cast(Tuple[str, str], x.split(" ", 1)) for x in rc.split("\n")]
def commits_resolving_gh_pr(self, pr_num: int) -> List[str]:
owner, name = self.gh_owner_and_name()
msg = f"Pull Request resolved: https://github.com/{owner}/{name}/pull/{pr_num}"
rc = self._run_git('log', '--format=%H', '--grep', msg).strip()
rc = self._run_git("log", "--format=%H", "--grep", msg).strip()
return rc.split("\n") if len(rc) > 0 else []
def get_commit(self, ref: str) -> GitCommit:
return parse_fuller_format(self._run_git('show', '--format=fuller', '--date=unix', '--shortstat', ref))
return parse_fuller_format(
self._run_git("show", "--format=fuller", "--date=unix", "--shortstat", ref)
)
def cherry_pick(self, ref: str) -> None:
self._run_git('cherry-pick', '-x', ref)
self._run_git("cherry-pick", "-x", ref)
def revert(self, ref: str) -> None:
self._run_git("revert", "--no-edit", ref)
def compute_branch_diffs(self, from_branch: str, to_branch: str) -> Tuple[List[str], List[str]]:
def compute_branch_diffs(
self, from_branch: str, to_branch: str
) -> Tuple[List[str], List[str]]:
"""
Returns list of commmits that are missing in each other branch since their merge base
Might be slow if merge base is between two branches is pretty far off
@ -182,8 +193,8 @@ class GitRepo:
from_ref = self.rev_parse(from_branch)
to_ref = self.rev_parse(to_branch)
merge_base = self.get_merge_base(from_ref, to_ref)
from_commits = self.revlist(f'{merge_base}..{from_ref}')
to_commits = self.revlist(f'{merge_base}..{to_ref}')
from_commits = self.revlist(f"{merge_base}..{from_ref}")
to_commits = self.revlist(f"{merge_base}..{to_ref}")
from_ids = fuzzy_list_to_dict(self.patch_id(from_commits))
to_ids = fuzzy_list_to_dict(self.patch_id(to_commits))
for patch_id in set(from_ids).intersection(set(to_ids)):
@ -199,14 +210,19 @@ class GitRepo:
# HACK: Same commit were merged, reverted and landed again
# which creates a tracking problem
if (
"pytorch/pytorch" not in self.remote_url() or
frc.commit_hash not in {"0a6a1b27a464ba5be5f587cce2ee12ab8c504dbf",
"6d0f4a1d545a8f161df459e8d4ccafd4b9017dbe",
"edf909e58f06150f7be41da2f98a3b9de3167bca",
"a58c6aea5a0c9f8759a4154e46f544c8b03b8db1",
"7106d216c29ca16a3504aa2bedad948ebcf4abc2"}
"pytorch/pytorch" not in self.remote_url()
or frc.commit_hash
not in {
"0a6a1b27a464ba5be5f587cce2ee12ab8c504dbf",
"6d0f4a1d545a8f161df459e8d4ccafd4b9017dbe",
"edf909e58f06150f7be41da2f98a3b9de3167bca",
"a58c6aea5a0c9f8759a4154e46f544c8b03b8db1",
"7106d216c29ca16a3504aa2bedad948ebcf4abc2",
}
):
raise RuntimeError(f"Unexpected differences between {frc} and {toc}")
raise RuntimeError(
f"Unexpected differences between {frc} and {toc}"
)
from_commits.remove(frc.commit_hash)
to_commits.remove(toc.commit_hash)
continue
@ -217,11 +233,13 @@ class GitRepo:
# Another HACK: Patch-id is not stable for commits with binary files or for big changes across commits
# I.e. cherry-picking those from one branch into another will change patchid
if "pytorch/pytorch" in self.remote_url():
for excluded_commit in {"8e09e20c1dafcdbdb45c2d1574da68a32e54a3a5",
"5f37e5c2a39c3acb776756a17730b865f0953432",
"b5222584e6d6990c6585981a936defd1af14c0ba",
"84d9a2e42d5ed30ec3b8b4140c38dd83abbce88d",
"f211ec90a6cdc8a2a5795478b5b5c8d7d7896f7e"}:
for excluded_commit in {
"8e09e20c1dafcdbdb45c2d1574da68a32e54a3a5",
"5f37e5c2a39c3acb776756a17730b865f0953432",
"b5222584e6d6990c6585981a936defd1af14c0ba",
"84d9a2e42d5ed30ec3b8b4140c38dd83abbce88d",
"f211ec90a6cdc8a2a5795478b5b5c8d7d7896f7e",
}:
if excluded_commit in from_commits:
from_commits.remove(excluded_commit)
@ -281,7 +299,14 @@ class GitRepo:
def clone_repo(username: str, password: str, org: str, project: str) -> GitRepo:
path = tempfile.mkdtemp()
_check_output(['git', 'clone', f'https://{username}:{password}@github.com/{org}/{project}', path]).strip()
_check_output(
[
"git",
"clone",
f"https://{username}:{password}@github.com/{org}/{project}",
path,
]
).strip()
return GitRepo(path=path)
@ -337,17 +362,21 @@ def patterns_to_regex(allowed_patterns: List[str]) -> Any:
rc += ")"
return re.compile(rc)
def _shasum(value: str) -> str:
import hashlib
m = hashlib.sha256()
m.update(value.encode("utf-8"))
return m.hexdigest()
def are_ghstack_branches_in_sync(repo: GitRepo, head_ref: str) -> bool:
""" Checks that diff between base and head is the same as diff between orig and its parent """
orig_ref = re.sub(r'/head$', '/orig', head_ref)
base_ref = re.sub(r'/head$', '/base', head_ref)
"""Checks that diff between base and head is the same as diff between orig and its parent"""
orig_ref = re.sub(r"/head$", "/orig", head_ref)
base_ref = re.sub(r"/head$", "/base", head_ref)
orig_diff_sha = _shasum(repo.diff(f"{repo.remote}/{orig_ref}"))
head_diff_sha = _shasum(repo.diff(f"{repo.remote}/{base_ref}", f"{repo.remote}/{head_ref}"))
head_diff_sha = _shasum(
repo.diff(f"{repo.remote}/{base_ref}", f"{repo.remote}/{head_ref}")
)
return orig_diff_sha == head_diff_sha

View File

@ -3,13 +3,10 @@
import json
from functools import lru_cache
from typing import List, Any, Tuple, TYPE_CHECKING, Union
from urllib.request import urlopen, Request
from typing import Any, List, Tuple, TYPE_CHECKING, Union
from urllib.request import Request, urlopen
from github_utils import (
GitHubComment,
gh_fetch_json,
)
from github_utils import gh_fetch_json, GitHubComment
# TODO: this is a temp workaround to avoid circular dependencies,
# and should be removed once GitHubPR is refactored out of trymerge script.
@ -31,14 +28,15 @@ For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.
"""
# Modified from https://github.com/pytorch/pytorch/blob/b00206d4737d1f1e7a442c9f8a1cadccd272a386/torch/hub.py#L129
def _read_url(url: Request) -> Tuple[Any, Any]:
with urlopen(url) as r:
return r.headers, r.read().decode(r.headers.get_content_charset('utf-8'))
return r.headers, r.read().decode(r.headers.get_content_charset("utf-8"))
def request_for_labels(url: str) -> Tuple[Any, Any]:
headers = {'Accept': 'application/vnd.github.v3+json'}
headers = {"Accept": "application/vnd.github.v3+json"}
return _read_url(Request(url, headers=headers))
@ -50,10 +48,12 @@ def update_labels(labels: List[str], info: str) -> None:
def get_last_page_num_from_header(header: Any) -> int:
# Link info looks like: <https://api.github.com/repositories/65600975/labels?per_page=100&page=2>;
# rel="next", <https://api.github.com/repositories/65600975/labels?per_page=100&page=3>; rel="last"
link_info = header['link']
link_info = header["link"]
prefix = "&page="
suffix = ">;"
return int(link_info[link_info.rindex(prefix) + len(prefix):link_info.rindex(suffix)])
return int(
link_info[link_info.rindex(prefix) + len(prefix) : link_info.rindex(suffix)]
)
@lru_cache()
@ -64,7 +64,9 @@ def gh_get_labels(org: str, repo: str) -> List[str]:
update_labels(labels, info)
last_page = get_last_page_num_from_header(header)
assert last_page > 0, "Error reading header info to determine total number of pages of labels"
assert (
last_page > 0
), "Error reading header info to determine total number of pages of labels"
for page_number in range(2, last_page + 1): # skip page 1
_, info = request_for_labels(prefix + f"&page={page_number}")
update_labels(labels, info)
@ -72,26 +74,37 @@ def gh_get_labels(org: str, repo: str) -> List[str]:
return labels
def gh_add_labels(org: str, repo: str, pr_num: int, labels: Union[str, List[str]]) -> None:
def gh_add_labels(
org: str, repo: str, pr_num: int, labels: Union[str, List[str]]
) -> None:
gh_fetch_json(
f'https://api.github.com/repos/{org}/{repo}/issues/{pr_num}/labels',
f"https://api.github.com/repos/{org}/{repo}/issues/{pr_num}/labels",
data={"labels": labels},
)
def get_release_notes_labels(org: str, repo: str) -> List[str]:
return [label for label in gh_get_labels(org, repo) if label.lstrip().startswith("release notes:")]
return [
label
for label in gh_get_labels(org, repo)
if label.lstrip().startswith("release notes:")
]
def has_required_labels(pr: "GitHubPR") -> bool:
pr_labels = pr.get_labels()
# Check if PR is not user facing
is_not_user_facing_pr = any(label.strip() == "topic: not user facing" for label in pr_labels)
return (
is_not_user_facing_pr or
any(label.strip() in get_release_notes_labels(pr.org, pr.project) for label in pr_labels)
is_not_user_facing_pr = any(
label.strip() == "topic: not user facing" for label in pr_labels
)
return is_not_user_facing_pr or any(
label.strip() in get_release_notes_labels(pr.org, pr.project)
for label in pr_labels
)
def is_label_err_comment(comment: GitHubComment) -> bool:
return comment.body_text.lstrip(" #").startswith(LABEL_ERR_MSG_TITLE) and comment.author_login in BOT_AUTHORS
return (
comment.body_text.lstrip(" #").startswith(LABEL_ERR_MSG_TITLE)
and comment.author_login in BOT_AUTHORS
)

View File

@ -1,5 +1,5 @@
#!/usr/bin/env python3
'''
"""
Verify that it is possible to round-trip native_functions.yaml via ruamel under some
configuration. Keeping native_functions.yaml consistent in this way allows us to
run codemods on the file using ruamel without introducing line noise. Note that we don't
@ -12,24 +12,27 @@ you may find that you want to use some format that is not what ruamel prefers.
it is OK to modify this script (instead of reformatting native_functions.yaml)--the point
is simply to make sure that there is *some* configuration of ruamel that can round trip
the YAML, not to be prescriptive about it.
'''
"""
import ruamel.yaml # type: ignore[import]
import difflib
import sys
from pathlib import Path
from io import StringIO
from pathlib import Path
import ruamel.yaml # type: ignore[import]
def fn(base: str) -> str:
return str(base / Path("aten/src/ATen/native/native_functions.yaml"))
with open(Path(__file__).parent.parent.parent / fn('.'), "r") as f:
with open(Path(__file__).parent.parent.parent / fn("."), "r") as f:
contents = f.read()
yaml = ruamel.yaml.YAML() # type: ignore[attr-defined]
yaml.preserve_quotes = True # type: ignore[assignment]
yaml.width = 1000 # type: ignore[assignment]
yaml.boolean_representation = ['False', 'True'] # type: ignore[attr-defined]
yaml.boolean_representation = ["False", "True"] # type: ignore[attr-defined]
r = yaml.load(contents)
# Cuz ruamel's author intentionally didn't include conversion to string
@ -40,12 +43,19 @@ new_contents = string_stream.getvalue()
string_stream.close()
if contents != new_contents:
print("""\
print(
"""\
## LINT FAILURE: native_functions.yaml ##
native_functions.yaml failed lint; please apply the diff below to fix lint.
If you think this is in error, please see .github/scripts/lint_native_functions.py
""", file=sys.stderr)
sys.stdout.writelines(difflib.unified_diff(contents.splitlines(True), new_contents.splitlines(True), fn('a'), fn('b')))
""",
file=sys.stderr,
)
sys.stdout.writelines(
difflib.unified_diff(
contents.splitlines(True), new_contents.splitlines(True), fn("a"), fn("b")
)
)
sys.exit(1)

View File

@ -14,7 +14,7 @@ def set_output(name: str, val: str) -> None:
def main() -> None:
ref = os.environ["GITHUB_REF"]
m = re.match(r'^refs/(\w+)/(.*)$', ref)
m = re.match(r"^refs/(\w+)/(.*)$", ref)
if m:
category, stripped = m.groups()
if category == "heads":

View File

@ -9,19 +9,21 @@ Testing environment:
- Python 3.8
- CUDA 11.3
"""
import argparse
# Known issues:
# 1. Does not reuse the build artifact in other CI workflows
# 2. CI jobs are serialized because there is only one worker
import os
import boto3 # type: ignore[import]
import git # type: ignore[import]
import pathlib
import argparse
import subprocess
from pathlib import Path
from typing import List, Tuple
import boto3 # type: ignore[import]
import git # type: ignore[import]
TORCHBENCH_CONFIG_NAME = "config.yaml"
TORCHBENCH_USERBENCHMARK_CONFIG_NAME = "ub-config.yaml"
MAGIC_PREFIX = "RUN_TORCHBENCH:"
@ -37,15 +39,18 @@ S3_BUCKET = "ossci-metrics"
S3_PREFIX = "torchbench-pr-test"
S3_URL_BASE = f"https://{S3_BUCKET}.s3.amazonaws.com/"
class S3Client:
def __init__(self, bucket: str = S3_BUCKET, prefix: str = S3_PREFIX):
self.s3 = boto3.client('s3')
self.resource = boto3.resource('s3')
self.s3 = boto3.client("s3")
self.resource = boto3.resource("s3")
self.bucket = bucket
self.prefix = prefix
def upload_file(self, file_path: Path, filekey_prefix: str) -> None:
assert file_path.is_file(), f"Specified file path {file_path} does not exist or not file."
assert (
file_path.is_file()
), f"Specified file path {file_path} does not exist or not file."
file_name = file_path.name
s3_key = f"{self.prefix}/{filekey_prefix}/{file_name}"
print(f"Uploading file {file_name} to S3 with key: {s3_key}")
@ -53,6 +58,7 @@ class S3Client:
# output the result URL
print(f"Uploaded the result file {file_name} to {S3_URL_BASE}{s3_key}")
def gen_abtest_config(control: str, treatment: str, models: List[str]) -> str:
d = {}
d["control"] = control
@ -65,18 +71,23 @@ def gen_abtest_config(control: str, treatment: str, models: List[str]) -> str:
config = config + "\n"
return config
def setup_gha_env(name: str, val: str) -> None:
fname = os.environ["GITHUB_ENV"]
content = f"{name}={val}\n"
with open(fname, "a") as fo:
fo.write(content)
def find_current_branch(repo_path: str) -> str:
repo = git.Repo(repo_path)
name: str = repo.active_branch.name
return name
def deploy_torchbench_config(output_dir: str, config: str, config_name: str = TORCHBENCH_CONFIG_NAME) -> None:
def deploy_torchbench_config(
output_dir: str, config: str, config_name: str = TORCHBENCH_CONFIG_NAME
) -> None:
# Create test dir if needed
pathlib.Path(output_dir).mkdir(exist_ok=True)
# TorchBench config file name
@ -84,20 +95,37 @@ def deploy_torchbench_config(output_dir: str, config: str, config_name: str = TO
with open(config_path, "w") as fp:
fp.write(config)
def get_valid_models(torchbench_path: str) -> List[str]:
benchmark_path = os.path.join(torchbench_path, "torchbenchmark", "models")
valid_models = [model for model in os.listdir(benchmark_path) if os.path.isdir(os.path.join(benchmark_path, model))]
valid_models = [
model
for model in os.listdir(benchmark_path)
if os.path.isdir(os.path.join(benchmark_path, model))
]
return valid_models
def get_valid_userbenchmarks(torchbench_path: str) -> List[str]:
def is_valid_ub_dir(ub_path: str) -> bool:
return os.path.isdir(ub_path) and os.path.exists(os.path.join(ub_path, "__init__.py"))
return os.path.isdir(ub_path) and os.path.exists(
os.path.join(ub_path, "__init__.py")
)
ub_path = os.path.join(os.path.abspath(torchbench_path), "userbenchmark")
ubs = list(filter(is_valid_ub_dir, [os.path.join(ub_path, ubdir) for ubdir in os.listdir(ub_path)]))
ubs = list(
filter(
is_valid_ub_dir,
[os.path.join(ub_path, ubdir) for ubdir in os.listdir(ub_path)],
)
)
valid_ubs = list(map(lambda x: os.path.basename(x), ubs))
return valid_ubs
def extract_models_from_pr(torchbench_path: str, prbody_file: str) -> Tuple[List[str], List[str]]:
def extract_models_from_pr(
torchbench_path: str, prbody_file: str
) -> Tuple[List[str], List[str]]:
model_list = []
userbenchmark_list = []
pr_list = []
@ -106,7 +134,9 @@ def extract_models_from_pr(torchbench_path: str, prbody_file: str) -> Tuple[List
magic_lines = list(filter(lambda x: x.startswith(MAGIC_PREFIX), lines))
if magic_lines:
# Only the first magic line will be recognized.
pr_list = list(map(lambda x: x.strip(), magic_lines[0][len(MAGIC_PREFIX):].split(",")))
pr_list = list(
map(lambda x: x.strip(), magic_lines[0][len(MAGIC_PREFIX) :].split(","))
)
valid_models = get_valid_models(torchbench_path)
valid_ubs = get_valid_userbenchmarks(torchbench_path)
for pr_bm in pr_list:
@ -115,53 +145,87 @@ def extract_models_from_pr(torchbench_path: str, prbody_file: str) -> Tuple[List
elif pr_bm in valid_ubs:
userbenchmark_list.append(pr_bm)
else:
print(f"The model or benchmark {pr_bm} you specified does not exist in TorchBench suite. Please double check.")
print(
f"The model or benchmark {pr_bm} you specified does not exist in TorchBench suite. Please double check."
)
exit(-1)
# Shortcut: if pr_list is ["ALL"], run all the model tests
if "ALL" in model_list:
model_list = ["ALL"]
return model_list, userbenchmark_list
def find_torchbench_branch(prbody_file: str) -> str:
branch_name: str = ""
with open(prbody_file, "r") as pf:
lines = map(lambda x: x.strip(), pf.read().splitlines())
magic_lines = list(filter(lambda x: x.startswith(MAGIC_TORCHBENCH_PREFIX), lines))
magic_lines = list(
filter(lambda x: x.startswith(MAGIC_TORCHBENCH_PREFIX), lines)
)
if magic_lines:
# Only the first magic line will be recognized.
branch_name = magic_lines[0][len(MAGIC_TORCHBENCH_PREFIX):].strip()
branch_name = magic_lines[0][len(MAGIC_TORCHBENCH_PREFIX) :].strip()
# If not specified, use main as the default branch
if not branch_name:
branch_name = "main"
return branch_name
def run_torchbench(pytorch_path: str, torchbench_path: str, output_dir: str) -> None:
# Copy system environment so that we will not override
env = dict(os.environ)
command = ["python", "bisection.py", "--work-dir", output_dir,
"--pytorch-src", pytorch_path, "--torchbench-src", torchbench_path,
"--config", os.path.join(output_dir, TORCHBENCH_CONFIG_NAME),
"--output", os.path.join(output_dir, "result.txt")]
command = [
"python",
"bisection.py",
"--work-dir",
output_dir,
"--pytorch-src",
pytorch_path,
"--torchbench-src",
torchbench_path,
"--config",
os.path.join(output_dir, TORCHBENCH_CONFIG_NAME),
"--output",
os.path.join(output_dir, "result.txt"),
]
print(f"Running torchbench command: {command}")
subprocess.check_call(command, cwd=torchbench_path, env=env)
def run_userbenchmarks(pytorch_path: str, torchbench_path: str, base_sha: str, head_sha: str,
userbenchmark: str, output_dir: str) -> None:
def run_userbenchmarks(
pytorch_path: str,
torchbench_path: str,
base_sha: str,
head_sha: str,
userbenchmark: str,
output_dir: str,
) -> None:
# Copy system environment so that we will not override
env = dict(os.environ)
command = ["python", "./.github/scripts/abtest.py",
"--pytorch-repo", pytorch_path,
"--base", base_sha,
"--head", head_sha,
"--userbenchmark", userbenchmark,
"--output-dir", output_dir]
command = [
"python",
"./.github/scripts/abtest.py",
"--pytorch-repo",
pytorch_path,
"--base",
base_sha,
"--head",
head_sha,
"--userbenchmark",
userbenchmark,
"--output-dir",
output_dir,
]
print(f"Running torchbench userbenchmark command: {command}")
subprocess.check_call(command, cwd=torchbench_path, env=env)
def process_upload_s3(result_dir: str) -> None:
# validate result directory
result_dir_path = Path(result_dir)
assert result_dir_path.exists(), f"Specified result directory {result_dir} doesn't exist."
assert (
result_dir_path.exists()
), f"Specified result directory {result_dir} doesn't exist."
# upload all files to S3 bucket oss-ci-metrics
files = [x for x in result_dir_path.iterdir() if x.is_file()]
# upload file to S3 bucket
@ -170,54 +234,92 @@ def process_upload_s3(result_dir: str) -> None:
for f in files:
s3_client.upload_file(f, filekey_prefix)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Run TorchBench tests based on PR')
parser.add_argument('--pr-body', help="The file that contains body of a Pull Request")
subparsers = parser.add_subparsers(dest='command')
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run TorchBench tests based on PR")
parser.add_argument(
"--pr-body", help="The file that contains body of a Pull Request"
)
subparsers = parser.add_subparsers(dest="command")
# parser for setup the torchbench branch name env
branch_parser = subparsers.add_parser("set-torchbench-branch")
# parser to run the torchbench branch
run_parser = subparsers.add_parser("run")
run_parser.add_argument('--pr-num', required=True, type=str, help="The Pull Request number")
run_parser.add_argument('--pr-base-sha', required=True, type=str, help="The Pull Request base hash")
run_parser.add_argument('--pr-head-sha', required=True, type=str, help="The Pull Request head hash")
run_parser.add_argument('--pytorch-path', required=True, type=str, help="Path to pytorch repository")
run_parser.add_argument('--torchbench-path', required=True, type=str, help="Path to TorchBench repository")
run_parser.add_argument(
"--pr-num", required=True, type=str, help="The Pull Request number"
)
run_parser.add_argument(
"--pr-base-sha", required=True, type=str, help="The Pull Request base hash"
)
run_parser.add_argument(
"--pr-head-sha", required=True, type=str, help="The Pull Request head hash"
)
run_parser.add_argument(
"--pytorch-path", required=True, type=str, help="Path to pytorch repository"
)
run_parser.add_argument(
"--torchbench-path",
required=True,
type=str,
help="Path to TorchBench repository",
)
# parser to upload results to S3
upload_parser = subparsers.add_parser("upload-s3")
upload_parser.add_argument('--result-dir', required=True, type=str, help="Path to benchmark output")
upload_parser.add_argument(
"--result-dir", required=True, type=str, help="Path to benchmark output"
)
args = parser.parse_args()
if args.command == 'set-torchbench-branch':
if args.command == "set-torchbench-branch":
branch_name = find_torchbench_branch(args.pr_body)
# env name: "TORCHBENCH_BRANCH"
setup_gha_env(MAGIC_TORCHBENCH_PREFIX[:-1], branch_name)
elif args.command == 'run':
output_dir: str = os.path.join(os.environ["HOME"], ".torchbench", "bisection", f"pr{args.pr_num}")
elif args.command == "run":
output_dir: str = os.path.join(
os.environ["HOME"], ".torchbench", "bisection", f"pr{args.pr_num}"
)
# Assert the current branch in args.torchbench_path is the same as the one specified in pr body
branch_name = find_torchbench_branch(args.pr_body)
current_branch = find_current_branch(args.torchbench_path)
assert branch_name == current_branch, f"Torchbench repo {args.torchbench_path} is on branch {current_branch}, \
assert (
branch_name == current_branch
), f"Torchbench repo {args.torchbench_path} is on branch {current_branch}, \
but user specified to run on branch {branch_name}."
print(f"Ready to run TorchBench with benchmark. Result will be saved in the directory: {output_dir}.")
print(
f"Ready to run TorchBench with benchmark. Result will be saved in the directory: {output_dir}."
)
# Identify the specified models and userbenchmarks
models, userbenchmarks = extract_models_from_pr(args.torchbench_path, args.pr_body)
models, userbenchmarks = extract_models_from_pr(
args.torchbench_path, args.pr_body
)
if models:
torchbench_config = gen_abtest_config(args.pr_base_sha, args.pr_head_sha, models)
torchbench_config = gen_abtest_config(
args.pr_base_sha, args.pr_head_sha, models
)
deploy_torchbench_config(output_dir, torchbench_config)
run_torchbench(pytorch_path=args.pytorch_path, torchbench_path=args.torchbench_path, output_dir=output_dir)
run_torchbench(
pytorch_path=args.pytorch_path,
torchbench_path=args.torchbench_path,
output_dir=output_dir,
)
if userbenchmarks:
assert len(userbenchmarks) == 1, \
"We don't support running multiple userbenchmarks in single workflow yet." \
assert len(userbenchmarks) == 1, (
"We don't support running multiple userbenchmarks in single workflow yet."
"If you need, please submit a feature request."
run_userbenchmarks(pytorch_path=args.pytorch_path, torchbench_path=args.torchbench_path,
base_sha=args.pr_base_sha, head_sha=args.pr_head_sha,
userbenchmark=userbenchmarks[0], output_dir=output_dir)
)
run_userbenchmarks(
pytorch_path=args.pytorch_path,
torchbench_path=args.torchbench_path,
base_sha=args.pr_base_sha,
head_sha=args.pr_head_sha,
userbenchmark=userbenchmarks[0],
output_dir=output_dir,
)
if not models and not userbenchmarks:
print("Can't parse valid models or userbenchmarks from the pr body. Quit.")
exit(-1)
elif args.command == 'upload-s3':
elif args.command == "upload-s3":
process_upload_s3(args.result_dir)
else:
print(f"The command {args.command} is not supported.")

View File

@ -1,30 +1,35 @@
"""test_check_labels.py"""
from typing import Any, List
from unittest import TestCase, mock, main
from unittest import main, mock, TestCase
from check_labels import (
main as check_labels_main,
add_label_err_comment,
delete_all_label_err_comments,
main as check_labels_main,
)
from github_utils import GitHubComment
from label_utils import BOT_AUTHORS, LABEL_ERR_MSG_TITLE
from test_trymerge import mocked_gh_graphql, mock_gh_get_info
from test_trymerge import mock_gh_get_info, mocked_gh_graphql
from trymerge import GitHubPR
def mock_parse_args() -> object:
class Object(object):
def __init__(self) -> None:
self.pr_num = 76123
return Object()
def mock_add_label_err_comment(pr: "GitHubPR") -> None:
pass
def mock_delete_all_label_err_comments(pr: "GitHubPR") -> None:
pass
def mock_get_comments() -> List[GitHubComment]:
return [
# Case 1 - a non label err comment
@ -49,9 +54,9 @@ def mock_get_comments() -> List[GitHubComment]:
class TestCheckLabels(TestCase):
@mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql)
@mock.patch('trymerge.GitHubPR.get_comments', return_value=[mock_get_comments()[0]])
@mock.patch('check_labels.gh_post_pr_comment')
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch("trymerge.GitHubPR.get_comments", return_value=[mock_get_comments()[0]])
@mock.patch("check_labels.gh_post_pr_comment")
def test_correctly_add_label_err_comment(
self, mock_gh_post_pr_comment: Any, mock_get_comments: Any, mock_gh_grphql: Any
) -> None:
@ -60,9 +65,9 @@ class TestCheckLabels(TestCase):
add_label_err_comment(pr)
mock_gh_post_pr_comment.assert_called_once()
@mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql)
@mock.patch('trymerge.GitHubPR.get_comments', return_value=[mock_get_comments()[1]])
@mock.patch('check_labels.gh_post_pr_comment')
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch("trymerge.GitHubPR.get_comments", return_value=[mock_get_comments()[1]])
@mock.patch("check_labels.gh_post_pr_comment")
def test_not_add_label_err_comment(
self, mock_gh_post_pr_comment: Any, mock_get_comments: Any, mock_gh_grphql: Any
) -> None:
@ -71,9 +76,9 @@ class TestCheckLabels(TestCase):
add_label_err_comment(pr)
mock_gh_post_pr_comment.assert_not_called()
@mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql)
@mock.patch('trymerge.GitHubPR.get_comments', return_value=mock_get_comments())
@mock.patch('check_labels.gh_delete_comment')
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch("trymerge.GitHubPR.get_comments", return_value=mock_get_comments())
@mock.patch("check_labels.gh_delete_comment")
def test_correctly_delete_all_label_err_comments(
self, mock_gh_delete_comment: Any, mock_get_comments: Any, mock_gh_grphql: Any
) -> None:
@ -82,11 +87,16 @@ class TestCheckLabels(TestCase):
delete_all_label_err_comments(pr)
mock_gh_delete_comment.assert_called_once_with("pytorch", "pytorch", 2)
@mock.patch('trymerge.gh_get_pr_info', return_value=mock_gh_get_info())
@mock.patch('check_labels.parse_args', return_value=mock_parse_args())
@mock.patch('check_labels.has_required_labels', return_value=False)
@mock.patch('check_labels.delete_all_label_err_comments', side_effect=mock_delete_all_label_err_comments)
@mock.patch('check_labels.add_label_err_comment', side_effect=mock_add_label_err_comment)
@mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
@mock.patch("check_labels.parse_args", return_value=mock_parse_args())
@mock.patch("check_labels.has_required_labels", return_value=False)
@mock.patch(
"check_labels.delete_all_label_err_comments",
side_effect=mock_delete_all_label_err_comments,
)
@mock.patch(
"check_labels.add_label_err_comment", side_effect=mock_add_label_err_comment
)
def test_ci_comments_and_exit0_without_required_labels(
self,
mock_add_label_err_comment: Any,
@ -101,11 +111,16 @@ class TestCheckLabels(TestCase):
mock_add_label_err_comment.assert_called_once()
mock_delete_all_label_err_comments.assert_not_called()
@mock.patch('trymerge.gh_get_pr_info', return_value=mock_gh_get_info())
@mock.patch('check_labels.parse_args', return_value=mock_parse_args())
@mock.patch('check_labels.has_required_labels', return_value=True)
@mock.patch('check_labels.delete_all_label_err_comments', side_effect=mock_delete_all_label_err_comments)
@mock.patch('check_labels.add_label_err_comment', side_effect=mock_add_label_err_comment)
@mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
@mock.patch("check_labels.parse_args", return_value=mock_parse_args())
@mock.patch("check_labels.has_required_labels", return_value=True)
@mock.patch(
"check_labels.delete_all_label_err_comments",
side_effect=mock_delete_all_label_err_comments,
)
@mock.patch(
"check_labels.add_label_err_comment", side_effect=mock_add_label_err_comment
)
def test_ci_exit0_with_required_labels(
self,
mock_add_label_err_comment: Any,
@ -120,5 +135,6 @@ class TestCheckLabels(TestCase):
mock_add_label_err_comment.assert_not_called()
mock_delete_all_label_err_comments.assert_called_once()
if __name__ == "__main__":
main()

View File

@ -1,5 +1,6 @@
from unittest import TestCase, main, mock
from typing import Any, List, Dict
from typing import Any, Dict, List
from unittest import main, mock, TestCase
from fetch_latest_green_commit import isGreen, WorkflowCheck
workflowNames = [
@ -15,46 +16,72 @@ workflowNames = [
"pr-labels",
"Close stale pull requests",
"Update S3 HTML indices for download.pytorch.org",
"Create Release"
"Create Release",
]
def set_workflow_job_status(workflow: List[Dict[str, Any]], name: str, status: str) -> List[Dict[str, Any]]:
def set_workflow_job_status(
workflow: List[Dict[str, Any]], name: str, status: str
) -> List[Dict[str, Any]]:
for check in workflow:
if check['workflowName'] == name:
check['conclusion'] = status
if check["workflowName"] == name:
check["conclusion"] = status
return workflow
class TestChecks:
def make_test_checks(self) -> List[Dict[str, Any]]:
workflow_checks = []
for i in range(len(workflowNames)):
workflow_checks.append(WorkflowCheck(
workflowName=workflowNames[i],
name="test/job",
jobName="job",
conclusion="success",
)._asdict())
workflow_checks.append(
WorkflowCheck(
workflowName=workflowNames[i],
name="test/job",
jobName="job",
conclusion="success",
)._asdict()
)
return workflow_checks
class TestPrintCommits(TestCase):
@mock.patch('fetch_latest_green_commit.get_commit_results', return_value=TestChecks().make_test_checks())
@mock.patch(
"fetch_latest_green_commit.get_commit_results",
return_value=TestChecks().make_test_checks(),
)
def test_all_successful(self, mock_get_commit_results: Any) -> None:
"Test with workflows are successful"
workflow_checks = mock_get_commit_results()
self.assertTrue(isGreen("sha", workflow_checks)[0])
@mock.patch('fetch_latest_green_commit.get_commit_results', return_value=TestChecks().make_test_checks())
@mock.patch(
"fetch_latest_green_commit.get_commit_results",
return_value=TestChecks().make_test_checks(),
)
def test_necessary_successful(self, mock_get_commit_results: Any) -> None:
"Test with necessary workflows are successful"
workflow_checks = mock_get_commit_results()
workflow_checks = set_workflow_job_status(workflow_checks, workflowNames[8], "failed")
workflow_checks = set_workflow_job_status(workflow_checks, workflowNames[9], "failed")
workflow_checks = set_workflow_job_status(workflow_checks, workflowNames[10], "failed")
workflow_checks = set_workflow_job_status(workflow_checks, workflowNames[11], "failed")
workflow_checks = set_workflow_job_status(workflow_checks, workflowNames[12], "failed")
workflow_checks = set_workflow_job_status(
workflow_checks, workflowNames[8], "failed"
)
workflow_checks = set_workflow_job_status(
workflow_checks, workflowNames[9], "failed"
)
workflow_checks = set_workflow_job_status(
workflow_checks, workflowNames[10], "failed"
)
workflow_checks = set_workflow_job_status(
workflow_checks, workflowNames[11], "failed"
)
workflow_checks = set_workflow_job_status(
workflow_checks, workflowNames[12], "failed"
)
self.assertTrue(isGreen("sha", workflow_checks)[0])
@mock.patch('fetch_latest_green_commit.get_commit_results', return_value=TestChecks().make_test_checks())
@mock.patch(
"fetch_latest_green_commit.get_commit_results",
return_value=TestChecks().make_test_checks(),
)
def test_necessary_skipped(self, mock_get_commit_results: Any) -> None:
"Test with necessary job (ex: pull) skipped"
workflow_checks = mock_get_commit_results()
@ -62,15 +89,25 @@ class TestPrintCommits(TestCase):
result = isGreen("sha", workflow_checks)
self.assertTrue(result[0])
@mock.patch('fetch_latest_green_commit.get_commit_results', return_value=TestChecks().make_test_checks())
@mock.patch(
"fetch_latest_green_commit.get_commit_results",
return_value=TestChecks().make_test_checks(),
)
def test_skippable_skipped(self, mock_get_commit_results: Any) -> None:
"Test with skippable jobs (periodic and docker-release-builds skipped"
workflow_checks = mock_get_commit_results()
workflow_checks = set_workflow_job_status(workflow_checks, "periodic", "skipped")
workflow_checks = set_workflow_job_status(workflow_checks, "docker-release-builds", "skipped")
workflow_checks = set_workflow_job_status(
workflow_checks, "periodic", "skipped"
)
workflow_checks = set_workflow_job_status(
workflow_checks, "docker-release-builds", "skipped"
)
self.assertTrue(isGreen("sha", workflow_checks))
@mock.patch('fetch_latest_green_commit.get_commit_results', return_value=TestChecks().make_test_checks())
@mock.patch(
"fetch_latest_green_commit.get_commit_results",
return_value=TestChecks().make_test_checks(),
)
def test_necessary_failed(self, mock_get_commit_results: Any) -> None:
"Test with necessary job (ex: Lint) failed"
workflow_checks = mock_get_commit_results()
@ -79,22 +116,33 @@ class TestPrintCommits(TestCase):
self.assertFalse(result[0])
self.assertEqual(result[1], "Lint checks were not successful")
@mock.patch('fetch_latest_green_commit.get_commit_results', return_value=TestChecks().make_test_checks())
@mock.patch(
"fetch_latest_green_commit.get_commit_results",
return_value=TestChecks().make_test_checks(),
)
def test_skippable_failed(self, mock_get_commit_results: Any) -> None:
"Test with failing skippable jobs (ex: docker-release-builds) should pass"
workflow_checks = mock_get_commit_results()
workflow_checks = set_workflow_job_status(workflow_checks, "periodic", "skipped")
workflow_checks = set_workflow_job_status(workflow_checks, "docker-release-builds", "failed")
workflow_checks = set_workflow_job_status(
workflow_checks, "periodic", "skipped"
)
workflow_checks = set_workflow_job_status(
workflow_checks, "docker-release-builds", "failed"
)
result = isGreen("sha", workflow_checks)
self.assertTrue(result[0])
@mock.patch('fetch_latest_green_commit.get_commit_results', return_value={})
@mock.patch("fetch_latest_green_commit.get_commit_results", return_value={})
def test_no_workflows(self, mock_get_commit_results: Any) -> None:
"Test with missing workflows"
workflow_checks = mock_get_commit_results()
result = isGreen("sha", workflow_checks)
self.assertFalse(result[0])
self.assertEqual(result[1], "missing required workflows: pull, trunk, lint, linux-binary, windows-binary")
self.assertEqual(
result[1],
"missing required workflows: pull, trunk, lint, linux-binary, windows-binary",
)
if __name__ == "__main__":
main()

View File

@ -1,7 +1,14 @@
#!/usr/bin/env python3
from gitutils import PeekableIterator, patterns_to_regex, GitRepo, are_ghstack_branches_in_sync, _shasum
from unittest import TestCase, main, SkipTest
from pathlib import Path
from unittest import main, SkipTest, TestCase
from gitutils import (
_shasum,
are_ghstack_branches_in_sync,
GitRepo,
patterns_to_regex,
PeekableIterator,
)
BASE_DIR = Path(__file__).parent
@ -15,6 +22,7 @@ class TestPeekableIterator(TestCase):
def test_is_iterable(self) -> None:
from collections.abc import Iterator
iter_ = PeekableIterator("")
self.assertTrue(isinstance(iter_, Iterator))
@ -35,7 +43,8 @@ class TestPattern(TestCase):
patterns_re = patterns_to_regex(allowed_patterns)
fnames = [
"aten/src/ATen/native/LinearAlgebra.cpp",
"aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp"]
"aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp",
]
for filename in fnames:
self.assertTrue(patterns_re.match(filename))
@ -44,11 +53,13 @@ class TestGitRepo(TestCase):
def setUp(self) -> None:
repo_dir = BASE_DIR.parent.parent.absolute()
if not (repo_dir / ".git").is_dir():
raise SkipTest("Can't find git directory, make sure to run this test on real repo checkout")
raise SkipTest(
"Can't find git directory, make sure to run this test on real repo checkout"
)
self.repo = GitRepo(str(repo_dir))
def _skip_if_ref_does_not_exist(self, ref: str) -> None:
""" Skip test if ref is missing as stale branches are deleted with time """
"""Skip test if ref is missing as stale branches are deleted with time"""
try:
self.repo.show_ref(ref)
except RuntimeError as e:
@ -69,5 +80,6 @@ class TestGitRepo(TestCase):
self._skip_if_ref_does_not_exist(head_ref)
self.assertFalse(are_ghstack_branches_in_sync(self.repo, head_ref))
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -1,29 +1,37 @@
from typing import Any
from unittest import TestCase, mock, main
from unittest import main, mock, TestCase
from label_utils import (
get_last_page_num_from_header,
gh_get_labels,
has_required_labels,
)
from trymerge import GitHubPR
from test_trymerge import mocked_gh_graphql
from trymerge import GitHubPR
release_notes_labels = [
"release notes: nn",
]
class TestLabelUtils(TestCase):
MOCK_HEADER_LINKS_TO_PAGE_NUMS = {
1: {"link": "<https://api.github.com/dummy/labels?per_page=10&page=1>; rel='last'"},
1: {
"link": "<https://api.github.com/dummy/labels?per_page=10&page=1>; rel='last'"
},
2: {"link": "<https://api.github.com/dummy/labels?per_page=1&page=2>;"},
3: {"link": "<https://api.github.com/dummy/labels?per_page=1&page=2&page=3>;"},
}
def test_get_last_page_num_from_header(self) -> None:
for expected_page_num, mock_header in self.MOCK_HEADER_LINKS_TO_PAGE_NUMS.items():
self.assertEqual(get_last_page_num_from_header(mock_header), expected_page_num)
for (
expected_page_num,
mock_header,
) in self.MOCK_HEADER_LINKS_TO_PAGE_NUMS.items():
self.assertEqual(
get_last_page_num_from_header(mock_header), expected_page_num
)
MOCK_LABEL_INFO = '[{"name": "foo"}]'
@ -49,23 +57,35 @@ class TestLabelUtils(TestCase):
gh_get_labels("foo", "bar")
self.assertIn("number of pages of labels", str(err.exception))
@mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql)
@mock.patch('label_utils.get_release_notes_labels', return_value=release_notes_labels)
def test_pr_with_missing_labels(self, mocked_rn_labels: Any, mocked_gql: Any) -> None:
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch(
"label_utils.get_release_notes_labels", return_value=release_notes_labels
)
def test_pr_with_missing_labels(
self, mocked_rn_labels: Any, mocked_gql: Any
) -> None:
"Test PR with no 'release notes:' label or 'topic: not user facing' label"
pr = GitHubPR("pytorch", "pytorch", 82169)
self.assertFalse(has_required_labels(pr))
@mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql)
@mock.patch('label_utils.get_release_notes_labels', return_value=release_notes_labels)
def test_pr_with_release_notes_label(self, mocked_rn_labels: Any, mocked_gql: Any) -> None:
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch(
"label_utils.get_release_notes_labels", return_value=release_notes_labels
)
def test_pr_with_release_notes_label(
self, mocked_rn_labels: Any, mocked_gql: Any
) -> None:
"Test PR with 'release notes: nn' label"
pr = GitHubPR("pytorch", "pytorch", 71759)
self.assertTrue(has_required_labels(pr))
@mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql)
@mock.patch('label_utils.get_release_notes_labels', return_value=release_notes_labels)
def test_pr_with_not_user_facing_label(self, mocked_rn_labels: Any, mocked_gql: Any) -> None:
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch(
"label_utils.get_release_notes_labels", return_value=release_notes_labels
)
def test_pr_with_not_user_facing_label(
self, mocked_rn_labels: Any, mocked_gql: Any
) -> None:
"Test PR with 'topic: not user facing' label"
pr = GitHubPR("pytorch", "pytorch", 75095)
self.assertTrue(has_required_labels(pr))

View File

@ -10,30 +10,32 @@
import json
import os
from hashlib import sha256
from trymerge import (
find_matching_merge_rule,
gh_graphql,
gh_get_team_members,
read_merge_rules,
validate_revert,
GitHubPR,
MergeRule,
MandatoryChecksMissingError,
PostCommentError,
FlakyRule,
categorize_checks,
get_rockset_results,
main as trymerge_main,
get_classifications,
)
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
from typing import Any, Dict, List, Optional
from unittest import TestCase, main, mock
from unittest import main, mock, TestCase
from urllib.error import HTTPError
if 'GIT_REMOTE_URL' not in os.environ:
os.environ['GIT_REMOTE_URL'] = "https://github.com/pytorch/pytorch"
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
from trymerge import (
categorize_checks,
find_matching_merge_rule,
FlakyRule,
get_classifications,
get_rockset_results,
gh_get_team_members,
gh_graphql,
GitHubPR,
main as trymerge_main,
MandatoryChecksMissingError,
MergeRule,
PostCommentError,
read_merge_rules,
validate_revert,
)
if "GIT_REMOTE_URL" not in os.environ:
os.environ["GIT_REMOTE_URL"] = "https://github.com/pytorch/pytorch"
def mock_query(
fallback_function: Any,
@ -67,8 +69,13 @@ def mock_query(
err_msg = f"If you are seeing this message during workflow run, please make sure to update {file_name}"
err_msg += f" locally, by deleting it and running {os.path.basename(__file__)} with "
err_msg += " GitHub Personal Access Token passed via GITHUB_TOKEN environment variable"
err_msg += " the rockset api key passed via ROCKSET_API_KEY environment variable"
if os.getenv("GITHUB_TOKEN") is None or os.getenv("ROCKSET_API_KEY") is None:
err_msg += (
" the rockset api key passed via ROCKSET_API_KEY environment variable"
)
if (
os.getenv("GITHUB_TOKEN") is None
or os.getenv("ROCKSET_API_KEY") is None
):
err_msg = (
"Failed to update cached GraphQL queries as GITHUB_TOKEN or ROCKSET_API_KEY is not defined."
+ err_msg
@ -89,8 +96,10 @@ def mocked_gh_graphql(query: str, **kwargs: Any) -> Any:
def gh_graphql_wrapper(query: str, kwargs: Any) -> Any:
return gh_graphql(query, **kwargs)
return mock_query(gh_graphql_wrapper, "gql_mocks.json", key_function, query, kwargs)
def mocked_rockset_results(head_sha: str, merge_base: str, num_retries: int = 3) -> Any:
return mock_query(
get_rockset_results,
@ -100,6 +109,7 @@ def mocked_rockset_results(head_sha: str, merge_base: str, num_retries: int = 3)
merge_base,
)
def mock_parse_args(revert: bool = False, force: bool = False) -> Any:
class Object(object):
def __init__(self) -> None:
@ -108,69 +118,87 @@ def mock_parse_args(revert: bool = False, force: bool = False) -> Any:
self.pr_num = 76123
self.dry_run = True
self.comment_id = 0
self.reason = 'this is for testing'
self.reason = "this is for testing"
self.ignore_current = False
return Object()
def mock_revert(repo: GitRepo, pr: GitHubPR, *,
dry_run: bool = False,
comment_id: Optional[int] = None,
reason: Optional[str] = None) -> None:
def mock_revert(
repo: GitRepo,
pr: GitHubPR,
*,
dry_run: bool = False,
comment_id: Optional[int] = None,
reason: Optional[str] = None,
) -> None:
pass
def mock_merge(pr_num: int, repo: GitRepo,
dry_run: bool = False,
skip_mandatory_checks: bool = False,
comment_id: Optional[int] = None,
timeout_minutes: int = 400,
stale_pr_days: int = 3,
ignore_current: bool = False) -> None:
def mock_merge(
pr_num: int,
repo: GitRepo,
dry_run: bool = False,
skip_mandatory_checks: bool = False,
comment_id: Optional[int] = None,
timeout_minutes: int = 400,
stale_pr_days: int = 3,
ignore_current: bool = False,
) -> None:
pass
def mock_gh_get_info() -> Any:
return {"closed": False,
"isCrossRepository": False,
"files": {"nodes": [], "pageInfo": {"hasNextPage": False}}, "changedFiles": 0}
return {
"closed": False,
"isCrossRepository": False,
"files": {"nodes": [], "pageInfo": {"hasNextPage": False}},
"changedFiles": 0,
}
def mocked_read_merge_rules_NE(repo: Any, org: str, project: str) -> List[MergeRule]:
return [
MergeRule(name="mock with nonexistent check",
patterns=["*"],
approved_by=[],
mandatory_checks_name=["Lint",
"Facebook CLA Check",
"nonexistent"],
),
MergeRule(
name="mock with nonexistent check",
patterns=["*"],
approved_by=[],
mandatory_checks_name=["Lint", "Facebook CLA Check", "nonexistent"],
),
]
def mocked_read_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule]:
return [
MergeRule(name="super",
patterns=["*"],
approved_by=["pytorch/metamates"],
mandatory_checks_name=["Lint",
"Facebook CLA Check",
"pull / linux-xenial-cuda11.3-py3.7-gcc7 / build",
],
),
MergeRule(
name="super",
patterns=["*"],
approved_by=["pytorch/metamates"],
mandatory_checks_name=[
"Lint",
"Facebook CLA Check",
"pull / linux-xenial-cuda11.3-py3.7-gcc7 / build",
],
),
]
def mocked_read_merge_rules_raise(repo: Any, org: str, project: str) -> List[MergeRule]:
raise RuntimeError("testing")
def empty_flaky_rules() -> List[FlakyRule]:
return []
def empty_rockset_results(head_sha: str, merge_base: str) -> List[Dict[str, Any]]:
return []
def dummy_merge_base() -> str:
return "dummy"
class DummyGitRepo(GitRepo):
def __init__(self) -> None:
super().__init__(get_git_repo_dir(), get_git_remote_name())
@ -185,7 +213,7 @@ class DummyGitRepo(GitRepo):
@mock.patch("trymerge.read_flaky_rules", side_effect=empty_flaky_rules)
@mock.patch("trymerge.get_rockset_results", side_effect=empty_rockset_results)
@mock.patch("trymerge.GitHubPR.get_merge_base", side_effect=dummy_merge_base)
@mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql)
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
class TestTryMerge(TestCase):
def test_merge_rules_valid(self, *args: Any) -> None:
"Test that merge_rules.yaml can be parsed"
@ -193,21 +221,23 @@ class TestTryMerge(TestCase):
merge_rules = read_merge_rules(repo, "pytorch", "pytorch")
self.assertGreater(len(merge_rules), 1)
@mock.patch('trymerge.read_merge_rules', side_effect=mocked_read_merge_rules)
@mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules)
def test_match_rules(self, *args: Any) -> None:
"Tests that PR passes merge rules"
pr = GitHubPR("pytorch", "pytorch", 77700)
repo = DummyGitRepo()
self.assertTrue(find_matching_merge_rule(pr, repo) is not None)
@mock.patch('trymerge.read_merge_rules', side_effect=mocked_read_merge_rules_raise)
@mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules_raise)
def test_read_merge_rules_fails(self, *args: Any) -> None:
"Tests that PR fails to read the merge rules"
pr = GitHubPR("pytorch", "pytorch", 77700)
repo = DummyGitRepo()
self.assertRaisesRegex(RuntimeError, "testing", lambda: find_matching_merge_rule(pr, repo))
self.assertRaisesRegex(
RuntimeError, "testing", lambda: find_matching_merge_rule(pr, repo)
)
@mock.patch('trymerge.read_merge_rules', side_effect=mocked_read_merge_rules)
@mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules)
def test_lint_fails(self, *args: Any) -> None:
"Tests that PR fails mandatory lint check"
pr = GitHubPR("pytorch", "pytorch", 90791)
@ -223,8 +253,8 @@ class TestTryMerge(TestCase):
self.assertTrue("You've committed this PR" in comment.body_text)
def test_get_author_null(self, *args: Any) -> None:
""" Tests that PR author can be computed
If reply contains NULL
"""Tests that PR author can be computed
If reply contains NULL
"""
pr = GitHubPR("pytorch", "pytorch", 71759)
author = pr.get_author()
@ -239,8 +269,7 @@ class TestTryMerge(TestCase):
self.assertTrue(author is not None)
def test_last_pushed_at(self, *args: Any) -> None:
""" Tests that last_pushed_at will return None on merge commits.
"""
"""Tests that last_pushed_at will return None on merge commits."""
pr = GitHubPR("pytorch", "pytorch", 71759)
self.assertIsNotNone(pr.last_pushed_at())
@ -295,27 +324,26 @@ class TestTryMerge(TestCase):
self.assertEqual(len(non_existing_team), 0)
def test_get_author_many_commits(self, *args: Any) -> None:
""" Tests that authors for all commits can be fetched
"""
"""Tests that authors for all commits can be fetched"""
pr = GitHubPR("pytorch", "pytorch", 76118)
authors = pr.get_authors()
self.assertGreater(pr.get_commit_count(), 100)
self.assertGreater(len(authors), 50)
self.assertTrue("@" in pr.get_author())
@mock.patch('trymerge.read_merge_rules', side_effect=mocked_read_merge_rules_NE)
@mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules_NE)
def test_pending_status_check(self, *args: Any) -> None:
""" Tests that PR with nonexistent/pending status checks fails with the right reason.
"""
"""Tests that PR with nonexistent/pending status checks fails with the right reason."""
pr = GitHubPR("pytorch", "pytorch", 76118)
repo = DummyGitRepo()
self.assertRaisesRegex(MandatoryChecksMissingError,
".*are pending/not yet run.*",
lambda: find_matching_merge_rule(pr, repo))
self.assertRaisesRegex(
MandatoryChecksMissingError,
".*are pending/not yet run.*",
lambda: find_matching_merge_rule(pr, repo),
)
def test_get_author_many_reviews(self, *args: Any) -> None:
""" Tests that all reviews can be fetched
"""
"""Tests that all reviews can be fetched"""
pr = GitHubPR("pytorch", "pytorch", 76123)
approved_by = pr.get_approved_by()
self.assertGreater(len(approved_by), 0)
@ -323,56 +351,62 @@ class TestTryMerge(TestCase):
self.assertGreater(len(pr._reviews), 100)
def test_get_checkruns_many_runs(self, *args: Any) -> None:
""" Tests that all checkruns can be fetched
"""
"""Tests that all checkruns can be fetched"""
pr = GitHubPR("pytorch", "pytorch", 77700)
conclusions = pr.get_checkrun_conclusions()
self.assertEqual(len(conclusions), 79)
self.assertTrue("pull / linux-docs / build-docs (cpp)" in conclusions.keys())
def test_cancelled_gets_ignored(self, *args: Any) -> None:
""" Tests that cancelled workflow does not override existing successfull status
"""
"""Tests that cancelled workflow does not override existing successfull status"""
pr = GitHubPR("pytorch", "pytorch", 82169)
conclusions = pr.get_checkrun_conclusions()
lint_checks = [name for name in conclusions.keys() if "Lint" in name]
self.assertTrue(len(lint_checks) > 0)
self.assertTrue(all([conclusions[name].status == "SUCCESS" for name in lint_checks]))
self.assertTrue(
all([conclusions[name].status == "SUCCESS" for name in lint_checks])
)
@mock.patch('trymerge.gh_get_pr_info', return_value=mock_gh_get_info())
@mock.patch('trymerge.parse_args', return_value=mock_parse_args(True, False))
@mock.patch('trymerge.try_revert', side_effect=mock_revert)
@mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
@mock.patch("trymerge.parse_args", return_value=mock_parse_args(True, False))
@mock.patch("trymerge.try_revert", side_effect=mock_revert)
def test_main_revert(self, mock_revert: Any, *args: Any) -> None:
trymerge_main()
mock_revert.assert_called_once()
@mock.patch('trymerge.gh_get_pr_info', return_value=mock_gh_get_info())
@mock.patch('trymerge.parse_args', return_value=mock_parse_args(False, True))
@mock.patch('trymerge.merge', side_effect=mock_merge)
def test_main_force(self, mock_merge: Any, mock_parse_args: Any, *args: Any) -> None:
@mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
@mock.patch("trymerge.parse_args", return_value=mock_parse_args(False, True))
@mock.patch("trymerge.merge", side_effect=mock_merge)
def test_main_force(
self, mock_merge: Any, mock_parse_args: Any, *args: Any
) -> None:
trymerge_main()
mock_merge.assert_called_once_with(mock.ANY,
mock.ANY,
dry_run=mock.ANY,
skip_mandatory_checks=True,
comment_id=mock.ANY,
ignore_current=False)
mock_merge.assert_called_once_with(
mock.ANY,
mock.ANY,
dry_run=mock.ANY,
skip_mandatory_checks=True,
comment_id=mock.ANY,
ignore_current=False,
)
@mock.patch('trymerge.gh_get_pr_info', return_value=mock_gh_get_info())
@mock.patch('trymerge.parse_args', return_value=mock_parse_args(False, False))
@mock.patch('trymerge.merge', side_effect=mock_merge)
@mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
@mock.patch("trymerge.parse_args", return_value=mock_parse_args(False, False))
@mock.patch("trymerge.merge", side_effect=mock_merge)
def test_main_merge(self, mock_merge: Any, *args: Any) -> None:
trymerge_main()
mock_merge.assert_called_once_with(mock.ANY,
mock.ANY,
dry_run=mock.ANY,
skip_mandatory_checks=False,
comment_id=mock.ANY,
ignore_current=False)
mock_merge.assert_called_once_with(
mock.ANY,
mock.ANY,
dry_run=mock.ANY,
skip_mandatory_checks=False,
comment_id=mock.ANY,
ignore_current=False,
)
@mock.patch('trymerge.read_merge_rules', side_effect=mocked_read_merge_rules)
@mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules)
def test_revert_rules(self, *args: Any) -> None:
""" Tests that reverts from collaborators are allowed """
"""Tests that reverts from collaborators are allowed"""
pr = GitHubPR("pytorch", "pytorch", 79694)
repo = DummyGitRepo()
self.assertIsNotNone(validate_revert(repo, pr, comment_id=1189459845))
@ -403,7 +437,11 @@ class TestTryMerge(TestCase):
return pr.get_body()
repo = GitRepoCoDev()
self.assertRaisesRegex(PostCommentError, "landed via phabricator", lambda: validate_revert(repo, pr, comment_id=1372496233))
self.assertRaisesRegex(
PostCommentError,
"landed via phabricator",
lambda: validate_revert(repo, pr, comment_id=1372496233),
)
def test_pr_changed_submodule_detection(self, *args: Any) -> None:
# Updates submodule during dev-cycle but reverts it later
@ -426,10 +464,14 @@ class TestTryMerge(TestCase):
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
class TestBypassFailures(TestCase):
def test_get_classifications(self, *args: Any) -> None:
flaky_rules = [FlakyRule("distributed", ["##[error]The operation was canceled."])]
flaky_rules = [
FlakyRule("distributed", ["##[error]The operation was canceled."])
]
pr = GitHubPR("pytorch", "pytorch", 92863)
checks = pr.get_checkrun_conclusions()
checks = get_classifications(checks, pr.last_commit()['oid'], pr.get_merge_base(), flaky_rules, [])
checks = get_classifications(
checks, pr.last_commit()["oid"], pr.get_merge_base(), flaky_rules, []
)
self.assertTrue(
checks[
"pull / linux-bionic-py3_7-clang8-xla / test (xla, 1, 1, linux.4xlarge)"
@ -442,10 +484,14 @@ class TestBypassFailures(TestCase):
].classification
== "FLAKY"
)
pending, failed = categorize_checks(checks, list(checks.keys()), ok_failed_checks_threshold=2)
pending, failed = categorize_checks(
checks, list(checks.keys()), ok_failed_checks_threshold=2
)
self.assertTrue(len(pending) == 0)
self.assertTrue(len(failed) == 0)
pending, failed = categorize_checks(checks, list(checks.keys()), ok_failed_checks_threshold=1)
pending, failed = categorize_checks(
checks, list(checks.keys()), ok_failed_checks_threshold=1
)
self.assertTrue(len(pending) == 0)
self.assertTrue(len(failed) == 2)
@ -454,32 +500,52 @@ class TestBypassFailures(TestCase):
# ignore current checks takes precedence over classifications for flaky
# or broken trunk
flaky_rules = [FlakyRule("distributed", ["##[error]The operation was canceled."])]
flaky = "pull / linux-focal-py3.7-gcc7 / test (distributed, 1, 2, linux.2xlarge)"
broken_trunk = "pull / linux-bionic-py3_7-clang8-xla / test (xla, 1, 1, linux.4xlarge)"
flaky_rules = [
FlakyRule("distributed", ["##[error]The operation was canceled."])
]
flaky = (
"pull / linux-focal-py3.7-gcc7 / test (distributed, 1, 2, linux.2xlarge)"
)
broken_trunk = (
"pull / linux-bionic-py3_7-clang8-xla / test (xla, 1, 1, linux.4xlarge)"
)
pr = GitHubPR("pytorch", "pytorch", 92863)
checks = pr.get_checkrun_conclusions()
# No broken trunk or flaky rules
checks = get_classifications(checks, pr.last_commit()['oid'], None, [], [flaky])
checks = get_classifications(checks, pr.last_commit()["oid"], None, [], [flaky])
self.assertTrue(checks[flaky].classification == "IGNORE_CURRENT_CHECK")
self.assertTrue(checks[broken_trunk].classification is None)
_, failed = categorize_checks(checks, list(checks.keys()), ok_failed_checks_threshold=0)
_, failed = categorize_checks(
checks, list(checks.keys()), ok_failed_checks_threshold=0
)
self.assertTrue(len(failed) == 1)
# No flaky rules
checks = get_classifications(checks, pr.last_commit()['oid'], pr.get_merge_base(), [], [flaky])
checks = get_classifications(
checks, pr.last_commit()["oid"], pr.get_merge_base(), [], [flaky]
)
self.assertTrue(checks[flaky].classification == "IGNORE_CURRENT_CHECK")
self.assertTrue(checks[broken_trunk].classification == "BROKEN_TRUNK")
_, failed = categorize_checks(checks, list(checks.keys()), ok_failed_checks_threshold=1)
_, failed = categorize_checks(
checks, list(checks.keys()), ok_failed_checks_threshold=1
)
self.assertTrue(len(failed) == 0)
# No broken_trunk
checks = get_classifications(checks, pr.last_commit()['oid'], pr.get_merge_base(), flaky_rules, [broken_trunk])
checks = get_classifications(
checks,
pr.last_commit()["oid"],
pr.get_merge_base(),
flaky_rules,
[broken_trunk],
)
self.assertTrue(checks[flaky].classification == "FLAKY")
self.assertTrue(checks[broken_trunk].classification == "IGNORE_CURRENT_CHECK")
_, failed = categorize_checks(checks, list(checks.keys()), ok_failed_checks_threshold=1)
_, failed = categorize_checks(
checks, list(checks.keys()), ok_failed_checks_threshold=1
)
self.assertTrue(len(failed) == 0)

View File

@ -1,75 +1,129 @@
from unittest import TestCase, mock, main
from typing import Any
from unittest import main, mock, TestCase
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
from test_trymerge import mocked_gh_graphql
from trymerge import GitHubPR
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
from typing import Any
from tryrebase import rebase_onto, rebase_ghstack_onto
from tryrebase import rebase_ghstack_onto, rebase_onto
def mocked_rev_parse(branch: str) -> str:
return branch
class TestRebase(TestCase):
@mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql)
@mock.patch('gitutils.GitRepo._run_git')
@mock.patch('gitutils.GitRepo.rev_parse', side_effect=mocked_rev_parse)
@mock.patch('tryrebase.gh_post_comment')
def test_rebase(self, mocked_post_comment: Any, mocked_rp: Any, mocked_run_git: Any, mocked_gql: Any) -> None:
class TestRebase(TestCase):
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch("gitutils.GitRepo._run_git")
@mock.patch("gitutils.GitRepo.rev_parse", side_effect=mocked_rev_parse)
@mock.patch("tryrebase.gh_post_comment")
def test_rebase(
self,
mocked_post_comment: Any,
mocked_rp: Any,
mocked_run_git: Any,
mocked_gql: Any,
) -> None:
"Tests rebase successfully"
pr = GitHubPR("pytorch", "pytorch", 31093)
repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
rebase_onto(pr, repo, 'master')
calls = [mock.call('fetch', 'origin', 'pull/31093/head:pull/31093/head'),
mock.call('rebase', 'refs/remotes/origin/master', 'pull/31093/head'),
mock.call('push', '-f', 'https://github.com/mingxiaoh/pytorch.git', 'pull/31093/head:master')]
rebase_onto(pr, repo, "master")
calls = [
mock.call("fetch", "origin", "pull/31093/head:pull/31093/head"),
mock.call("rebase", "refs/remotes/origin/master", "pull/31093/head"),
mock.call(
"push",
"-f",
"https://github.com/mingxiaoh/pytorch.git",
"pull/31093/head:master",
),
]
mocked_run_git.assert_has_calls(calls)
self.assertTrue(
"Successfully rebased `master` onto `refs/remotes/origin/master`" in mocked_post_comment.call_args[0][3])
"Successfully rebased `master` onto `refs/remotes/origin/master`"
in mocked_post_comment.call_args[0][3]
)
@mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql)
@mock.patch('gitutils.GitRepo._run_git')
@mock.patch('gitutils.GitRepo.rev_parse', side_effect=mocked_rev_parse)
@mock.patch('tryrebase.gh_post_comment')
def test_rebase_to_stable(self, mocked_post_comment: Any, mocked_rp: Any, mocked_run_git: Any, mocked_gql: Any) -> None:
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch("gitutils.GitRepo._run_git")
@mock.patch("gitutils.GitRepo.rev_parse", side_effect=mocked_rev_parse)
@mock.patch("tryrebase.gh_post_comment")
def test_rebase_to_stable(
self,
mocked_post_comment: Any,
mocked_rp: Any,
mocked_run_git: Any,
mocked_gql: Any,
) -> None:
"Tests rebase to viable/strict successfully"
pr = GitHubPR("pytorch", "pytorch", 31093)
repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
rebase_onto(pr, repo, 'viable/strict', False)
calls = [mock.call('fetch', 'origin', 'pull/31093/head:pull/31093/head'),
mock.call('rebase', 'refs/remotes/origin/viable/strict', 'pull/31093/head'),
mock.call('push', '-f', 'https://github.com/mingxiaoh/pytorch.git', 'pull/31093/head:master')]
rebase_onto(pr, repo, "viable/strict", False)
calls = [
mock.call("fetch", "origin", "pull/31093/head:pull/31093/head"),
mock.call("rebase", "refs/remotes/origin/viable/strict", "pull/31093/head"),
mock.call(
"push",
"-f",
"https://github.com/mingxiaoh/pytorch.git",
"pull/31093/head:master",
),
]
mocked_run_git.assert_has_calls(calls)
self.assertTrue(
"Successfully rebased `master` onto `refs/remotes/origin/viable/strict`" in mocked_post_comment.call_args[0][3])
"Successfully rebased `master` onto `refs/remotes/origin/viable/strict`"
in mocked_post_comment.call_args[0][3]
)
@mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql)
@mock.patch('gitutils.GitRepo._run_git', return_value="Everything up-to-date")
@mock.patch('gitutils.GitRepo.rev_parse', side_effect=mocked_rev_parse)
@mock.patch('tryrebase.gh_post_comment')
def test_no_need_to_rebase(self, mocked_post_comment: Any, mocked_rp: Any, mocked_run_git: Any, mocked_gql: Any) -> None:
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch("gitutils.GitRepo._run_git", return_value="Everything up-to-date")
@mock.patch("gitutils.GitRepo.rev_parse", side_effect=mocked_rev_parse)
@mock.patch("tryrebase.gh_post_comment")
def test_no_need_to_rebase(
self,
mocked_post_comment: Any,
mocked_rp: Any,
mocked_run_git: Any,
mocked_gql: Any,
) -> None:
"Tests branch already up to date"
pr = GitHubPR("pytorch", "pytorch", 31093)
repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
rebase_onto(pr, repo, 'master')
calls = [mock.call('fetch', 'origin', 'pull/31093/head:pull/31093/head'),
mock.call('rebase', 'refs/remotes/origin/master', 'pull/31093/head'),
mock.call('push', '-f', 'https://github.com/mingxiaoh/pytorch.git', 'pull/31093/head:master')]
rebase_onto(pr, repo, "master")
calls = [
mock.call("fetch", "origin", "pull/31093/head:pull/31093/head"),
mock.call("rebase", "refs/remotes/origin/master", "pull/31093/head"),
mock.call(
"push",
"-f",
"https://github.com/mingxiaoh/pytorch.git",
"pull/31093/head:master",
),
]
mocked_run_git.assert_has_calls(calls)
self.assertTrue(
"Tried to rebase and push PR #31093, but it was already up to date" in mocked_post_comment.call_args[0][3])
"Tried to rebase and push PR #31093, but it was already up to date"
in mocked_post_comment.call_args[0][3]
)
@mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql)
@mock.patch('gitutils.GitRepo._run_git')
@mock.patch('gitutils.GitRepo.rev_parse', side_effect=lambda branch: "same sha")
@mock.patch('tryrebase.gh_post_comment')
def test_same_sha(self, mocked_post_comment: Any, mocked_rp: Any, mocked_run_git: Any, mocked_gql: Any) -> None:
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
@mock.patch("gitutils.GitRepo._run_git")
@mock.patch("gitutils.GitRepo.rev_parse", side_effect=lambda branch: "same sha")
@mock.patch("tryrebase.gh_post_comment")
def test_same_sha(
self,
mocked_post_comment: Any,
mocked_rp: Any,
mocked_run_git: Any,
mocked_gql: Any,
) -> None:
"Tests rebase results in same sha"
pr = GitHubPR("pytorch", "pytorch", 31093)
repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
with self.assertRaisesRegex(Exception, 'same sha as the target branch'):
rebase_onto(pr, repo, 'master')
with self.assertRaisesRegex(Exception, 'same sha as the target branch'):
rebase_ghstack_onto(pr, repo, 'master')
with self.assertRaisesRegex(Exception, "same sha as the target branch"):
rebase_onto(pr, repo, "master")
with self.assertRaisesRegex(Exception, "same sha as the target branch"):
rebase_ghstack_onto(pr, repo, "master")
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
import os
import re
from typing import List, Pattern, Optional, Tuple
from typing import List, Optional, Pattern, Tuple
BOT_COMMANDS_WIKI = "https://github.com/pytorch/pytorch/wiki/Bot-commands"
@ -10,9 +10,7 @@ CIFLOW_TRUNK_LABEL = re.compile(r"^ciflow/trunk")
OFFICE_HOURS_LINK = "https://github.com/pytorch/pytorch/wiki/Dev-Infra-Office-Hours"
CONTACT_US = f"Questions? Feedback? Please reach out to the [PyTorch DevX Team]({OFFICE_HOURS_LINK})"
ALTERNATIVES = (
f"Learn more about merging in the [wiki]({BOT_COMMANDS_WIKI})."
)
ALTERNATIVES = f"Learn more about merging in the [wiki]({BOT_COMMANDS_WIKI})."
def has_label(labels: List[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
@ -46,31 +44,35 @@ class TryMergeExplainer(object):
self.project = project
self.ignore_current = ignore_current
def _get_flag_msg(self, ignore_current_checks: Optional[List[Tuple[str, Optional[str]]]] = None) -> str:
def _get_flag_msg(
self, ignore_current_checks: Optional[List[Tuple[str, Optional[str]]]] = None
) -> str:
if self.force:
return "Your change will be merged immediately since you used the force (-f) flag, " + \
"**bypassing any CI checks** (ETA: 1-5 minutes)."
return (
"Your change will be merged immediately since you used the force (-f) flag, "
+ "**bypassing any CI checks** (ETA: 1-5 minutes)."
)
elif self.ignore_current and ignore_current_checks is not None:
msg = f"Your change will be merged while ignoring the following {len(ignore_current_checks)} checks: "
msg += ', '.join(f"[{x[0]}]({x[1]})" for x in ignore_current_checks)
msg += ", ".join(f"[{x[0]}]({x[1]})" for x in ignore_current_checks)
return msg
else:
return "Your change will be merged once all checks pass (ETA 0-4 Hours)."
def get_merge_message(
self,
ignore_current_checks: Optional[List[Tuple[str, Optional[str]]]] = None
self, ignore_current_checks: Optional[List[Tuple[str, Optional[str]]]] = None
) -> str:
title = "### Merge started"
main_message = self._get_flag_msg(ignore_current_checks)
advanced_debugging = "\n".join((
"<details><summary>Advanced Debugging</summary>",
"Check the merge workflow status ",
f"<a href=\"{os.getenv('GH_RUN_URL')}\">here</a>",
"</details>"
))
advanced_debugging = "\n".join(
(
"<details><summary>Advanced Debugging</summary>",
"Check the merge workflow status ",
f"<a href=\"{os.getenv('GH_RUN_URL')}\">here</a>",
"</details>",
)
)
msg = title + "\n"
msg += main_message + "\n\n"

View File

@ -1,21 +1,24 @@
#!/usr/bin/env python3
import os
import re
import subprocess
import sys
import re
from typing import Any
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
from github_utils import gh_post_pr_comment as gh_post_comment
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
from trymerge import GitHubPR
SAME_SHA_ERROR = (
"\n```\nAborting rebase because rebasing the branch resulted in the same sha as the target branch.\n" +
"This usually happens because the PR has already been merged. Please rebase locally and push.\n```"
"\n```\nAborting rebase because rebasing the branch resulted in the same sha as the target branch.\n"
+ "This usually happens because the PR has already been merged. Please rebase locally and push.\n```"
)
def parse_args() -> Any:
from argparse import ArgumentParser
parser = ArgumentParser("Rebase PR into branch")
parser.add_argument("--dry-run", action="store_true")
parser.add_argument("--branch", type=str)
@ -23,7 +26,9 @@ def parse_args() -> Any:
return parser.parse_args()
def rebase_onto(pr: GitHubPR, repo: GitRepo, onto_branch: str, dry_run: bool = False) -> None:
def rebase_onto(
pr: GitHubPR, repo: GitRepo, onto_branch: str, dry_run: bool = False
) -> None:
branch = f"pull/{pr.pr_num}/head"
onto_branch = f"refs/remotes/origin/{onto_branch}"
remote_url = f"https://github.com/{pr.info['headRepository']['nameWithOwner']}.git"
@ -40,17 +45,34 @@ def rebase_onto(pr: GitHubPR, repo: GitRepo, onto_branch: str, dry_run: bool = F
else:
push_result = repo._run_git("push", "-f", remote_url, refspec)
if "Everything up-to-date" in push_result:
gh_post_comment(pr.org, pr.project, pr.pr_num,
f"Tried to rebase and push PR #{pr.pr_num}, but it was already up to date", dry_run=dry_run)
gh_post_comment(
pr.org,
pr.project,
pr.pr_num,
f"Tried to rebase and push PR #{pr.pr_num}, but it was already up to date",
dry_run=dry_run,
)
else:
gh_post_comment(pr.org, pr.project, pr.pr_num,
f"Successfully rebased `{pr.head_ref()}` onto `{onto_branch}`, please pull locally " +
f"before adding more changes (for example, via `git checkout {pr.head_ref()} && " +
"git pull --rebase`)", dry_run=dry_run)
gh_post_comment(
pr.org,
pr.project,
pr.pr_num,
f"Successfully rebased `{pr.head_ref()}` onto `{onto_branch}`, please pull locally "
+ f"before adding more changes (for example, via `git checkout {pr.head_ref()} && "
+ "git pull --rebase`)",
dry_run=dry_run,
)
def rebase_ghstack_onto(pr: GitHubPR, repo: GitRepo, onto_branch: str, dry_run: bool = False) -> None:
if subprocess.run([sys.executable, "-m", "ghstack", "--help"], capture_output=True).returncode != 0:
def rebase_ghstack_onto(
pr: GitHubPR, repo: GitRepo, onto_branch: str, dry_run: bool = False
) -> None:
if (
subprocess.run(
[sys.executable, "-m", "ghstack", "--help"], capture_output=True
).returncode
!= 0
):
subprocess.run([sys.executable, "-m", "pip", "install", "ghstack"])
orig_ref = f"{re.sub(r'/head$', '/orig', pr.head_ref())}"
onto_branch = f"refs/remotes/origin/{onto_branch}"
@ -68,11 +90,13 @@ def rebase_ghstack_onto(pr: GitHubPR, repo: GitRepo, onto_branch: str, dry_run:
repo._run_git("config", "--global", "user.name", name)
os.environ["OAUTH_TOKEN"] = os.environ["GITHUB_TOKEN"]
with open('.ghstackrc', 'w+') as f:
f.write('[ghstack]\n' +
"github_url=github.com\n" +
"github_username=pytorchmergebot\n" +
"remote_name=origin")
with open(".ghstackrc", "w+") as f:
f.write(
"[ghstack]\n"
+ "github_url=github.com\n"
+ "github_username=pytorchmergebot\n"
+ "remote_name=origin"
)
if dry_run:
print("Don't know how to dry-run ghstack")
@ -102,19 +126,37 @@ def rebase_ghstack_onto(pr: GitHubPR, repo: GitRepo, onto_branch: str, dry_run:
if "Updated" in line:
pr_num = int(line.split("/")[-1])
if pr_num != pr.pr_num:
gh_post_comment(pr.org, pr.project, pr_num,
f"Rebased `{orig_ref}` onto `{onto_branch}` because #{pr.pr_num} was rebased, "
"please pull locally before adding more changes (for example, via `ghstack " +
f"checkout https://github.com/{org}/{project}/pull/{pr_num}`)", dry_run=dry_run)
gh_post_comment(
pr.org,
pr.project,
pr_num,
f"Rebased `{orig_ref}` onto `{onto_branch}` because #{pr.pr_num} was rebased, "
"please pull locally before adding more changes (for example, via `ghstack "
+ f"checkout https://github.com/{org}/{project}/pull/{pr_num}`)",
dry_run=dry_run,
)
else:
gh_post_comment(pr.org, pr.project, pr_num,
f"Successfully rebased `{orig_ref}` onto `{onto_branch}`, please pull locally " +
"before adding more changes (for example, via `ghstack " +
f"checkout https://github.com/{org}/{project}/pull/{pr.pr_num}`)", dry_run=dry_run)
gh_post_comment(
pr.org,
pr.project,
pr_num,
f"Successfully rebased `{orig_ref}` onto `{onto_branch}`, please pull locally "
+ "before adding more changes (for example, via `ghstack "
+ f"checkout https://github.com/{org}/{project}/pull/{pr.pr_num}`)",
dry_run=dry_run,
)
if f"Skipped https://github.com/{org}/{project}/pull/{pr.pr_num}" in push_result:
gh_post_comment(pr.org, pr.project, pr.pr_num,
f"Tried to rebase and push PR #{pr.pr_num}, but it was already up to date", dry_run=dry_run)
if (
f"Skipped https://github.com/{org}/{project}/pull/{pr.pr_num}"
in push_result
):
gh_post_comment(
pr.org,
pr.project,
pr.pr_num,
f"Tried to rebase and push PR #{pr.pr_num}, but it was already up to date",
dry_run=dry_run,
)
def main() -> None:
@ -130,7 +172,13 @@ def main() -> None:
gh_post_comment(org, project, args.pr_num, msg, dry_run=args.dry_run)
if pr.is_closed():
gh_post_comment(org, project, args.pr_num, f"PR #{args.pr_num} is closed, won't rebase", dry_run=args.dry_run)
gh_post_comment(
org,
project,
args.pr_num,
f"PR #{args.pr_num} is closed, won't rebase",
dry_run=args.dry_run,
)
return
try:

View File

@ -1,9 +1,10 @@
import json
import os
import subprocess
import requests
from typing import Any, Dict
from argparse import ArgumentParser
from typing import Any, Dict
import requests
MERGEBOT_TOKEN = os.environ["MERGEBOT_TOKEN"]
PYTORCHBOT_TOKEN = os.environ["PYTORCHBOT_TOKEN"]

View File

@ -826,6 +826,8 @@ init_command = [
[[linter]]
code = 'UFMT'
include_patterns = [
'.github/**/*.py',
'test/run_test.py',
'test/onnx/**/*.py',
'test/test_dynamo_cudagraphs.py',
'tools/**/*.py',

View File

@ -2,9 +2,9 @@
import argparse
import copy
from datetime import datetime
from distutils.version import LooseVersion
import functools
import glob
import json
import os
import pathlib
import shutil
@ -12,23 +12,23 @@ import signal
import subprocess
import sys
import tempfile
import json
import glob
from typing import Dict, Optional, List, cast, Any
from datetime import datetime
from distutils.version import LooseVersion
from typing import Any, cast, Dict, List, Optional
import torch
from torch.utils import cpp_extension
from torch.testing._internal.common_utils import (
IS_CI,
FILE_SCHEMA,
TEST_WITH_ROCM,
shell,
set_cwd,
parser as common_parser,
is_slow_gradcheck_env,
)
import torch.distributed as dist
from torch.multiprocessing import current_process, get_context
from torch.testing._internal.common_utils import (
FILE_SCHEMA,
IS_CI,
is_slow_gradcheck_env,
parser as common_parser,
set_cwd,
shell,
TEST_WITH_ROCM,
)
from torch.utils import cpp_extension
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
@ -37,11 +37,12 @@ try:
sys.path.append(str(REPO_ROOT))
from tools.stats.export_test_times import TEST_TIMES_FILE
from tools.testing.test_selections import (
calculate_shards,
get_reordered_tests,
get_test_case_configs,
calculate_shards,
NUM_PROCS
NUM_PROCS,
)
HAVE_TEST_SELECTION_TOOLS = True
except ImportError:
HAVE_TEST_SELECTION_TOOLS = False
@ -65,9 +66,9 @@ def maybe_set_hip_visible_devies():
# Special handling of ROCm GHA runners for parallel (file granularity) tests.
if torch.version.hip:
p = current_process()
if p.name != 'MainProcess':
if p.name != "MainProcess":
# this is a Process from a parallel Pool, not the MainProcess
os.environ['HIP_VISIBLE_DEVICES'] = str(p._identity[0] % NUM_PROCS)
os.environ["HIP_VISIBLE_DEVICES"] = str(p._identity[0] % NUM_PROCS)
def strtobool(s):
@ -77,13 +78,15 @@ def strtobool(s):
def discover_tests(
base_dir: Optional[pathlib.Path] = None,
blocklisted_patterns: Optional[List[str]] = None,
blocklisted_tests: Optional[List[str]] = None,
extra_tests: Optional[List[str]] = None) -> List[str]:
base_dir: Optional[pathlib.Path] = None,
blocklisted_patterns: Optional[List[str]] = None,
blocklisted_tests: Optional[List[str]] = None,
extra_tests: Optional[List[str]] = None,
) -> List[str]:
"""
Searches for all python files starting with test_ excluding one specified by patterns
"""
def skip_test_p(name: str) -> bool:
rc = False
if blocklisted_patterns is not None:
@ -91,13 +94,16 @@ def discover_tests(
if blocklisted_tests is not None:
rc |= name in blocklisted_tests
return rc
cwd = pathlib.Path(__file__).resolve().parent if base_dir is None else base_dir
# This supports symlinks, so we can link domain library tests to PyTorch test directory
all_py_files = [pathlib.Path(p) for p in glob.glob(f"{cwd}/**/test_*.py", recursive=True)]
all_py_files = [
pathlib.Path(p) for p in glob.glob(f"{cwd}/**/test_*.py", recursive=True)
]
rc = [str(fname.relative_to(cwd))[:-3] for fname in all_py_files]
# Invert slashes on Windows
if sys.platform == "win32":
rc = [name.replace('\\', '/') for name in rc]
rc = [name.replace("\\", "/") for name in rc]
rc = [test for test in rc if not skip_test_p(test)]
if extra_tests is not None:
rc += extra_tests
@ -106,31 +112,31 @@ def discover_tests(
TESTS = discover_tests(
blocklisted_patterns=[
'ao',
'bottleneck_test',
'custom_backend',
'custom_operator',
'fx', # executed by test_fx.py
'jit', # executed by test_jit.py
'mobile',
'onnx',
'package', # executed by test_package.py
'quantization', # executed by test_quantization.py
'autograd', # executed by test_autograd.py
"ao",
"bottleneck_test",
"custom_backend",
"custom_operator",
"fx", # executed by test_fx.py
"jit", # executed by test_jit.py
"mobile",
"onnx",
"package", # executed by test_package.py
"quantization", # executed by test_quantization.py
"autograd", # executed by test_autograd.py
],
blocklisted_tests=[
'test_bundled_images',
'test_cpp_extensions_aot',
'test_determination',
'test_jit_fuser',
'test_jit_simple',
'test_jit_string',
'test_kernel_launch_checks',
'test_nnapi',
'test_segment_reductions',
'test_static_runtime',
'test_throughput_benchmark',
'test_typing',
"test_bundled_images",
"test_cpp_extensions_aot",
"test_determination",
"test_jit_fuser",
"test_jit_simple",
"test_jit_string",
"test_kernel_launch_checks",
"test_nnapi",
"test_segment_reductions",
"test_static_runtime",
"test_throughput_benchmark",
"test_typing",
"distributed/bin/test_script",
"distributed/elastic/multiprocessing/bin/test_script",
"distributed/launcher/bin/test_script",
@ -138,8 +144,8 @@ TESTS = discover_tests(
"distributed/launcher/bin/test_script_is_torchelastic_launched",
"distributed/launcher/bin/test_script_local_rank",
"distributed/test_c10d_spawn",
'distributions/test_transforms',
'distributions/test_utils',
"distributions/test_transforms",
"distributions/test_utils",
],
extra_tests=[
"test_cpp_extensions_aot_ninja",
@ -153,12 +159,12 @@ TESTS = discover_tests(
"distributed/elastic/utils/util_test",
"distributed/elastic/utils/distributed_test",
"distributed/elastic/multiprocessing/api_test",
]
],
)
# The doctests are a special case that don't correspond to a file that discover
# tests can enable.
TESTS = TESTS + ['doctests']
TESTS = TESTS + ["doctests"]
FSDP_TEST = [test for test in TESTS if test.startswith("distributed/fsdp")]
@ -243,34 +249,34 @@ RUN_PARALLEL_BLOCKLIST = [
] + FSDP_TEST
CI_SERIAL_LIST = [
'test_nn',
'test_fake_tensor',
'test_cpp_api_parity',
'test_reductions',
'test_cuda',
'test_jit_cuda_fuser', # OOM on test_issue_1785, also profiling?
'test_indexing',
'test_fx_backends',
'test_linalg',
'test_cpp_extensions_jit',
'test_torch',
'test_tensor_creation_ops',
'test_sparse_csr',
'test_dispatch',
'test_spectral_ops', # Cause CUDA illegal memory access https://github.com/pytorch/pytorch/issues/88916
'nn/test_pooling',
'nn/test_convolution', # Doesn't respect set_per_process_memory_fraction, results in OOM for other tests in slow gradcheck
'distributions/test_distributions',
'test_autograd', # slow gradcheck runs a test that checks the cuda memory allocator
'test_prims', # slow gradcheck runs a test that checks the cuda memory allocator
'test_modules', # failed test due to mismatched elements
'functorch/test_vmap', # OOM
'test_fx', # gets SIGKILL
'test_dataloader', # frequently hangs for ROCm
'test_serialization', # test_serialization_2gb_file allocates a tensor of 2GB, and could cause OOM
'_nvfuser/test_torchscript', # OOM on test_issue_1785
'test_schema_check', # Cause CUDA illegal memory access https://github.com/pytorch/pytorch/issues/95749
'functorch/test_memory_efficient_fusion', # Cause CUDA OOM on ROCm
"test_nn",
"test_fake_tensor",
"test_cpp_api_parity",
"test_reductions",
"test_cuda",
"test_jit_cuda_fuser", # OOM on test_issue_1785, also profiling?
"test_indexing",
"test_fx_backends",
"test_linalg",
"test_cpp_extensions_jit",
"test_torch",
"test_tensor_creation_ops",
"test_sparse_csr",
"test_dispatch",
"test_spectral_ops", # Cause CUDA illegal memory access https://github.com/pytorch/pytorch/issues/88916
"nn/test_pooling",
"nn/test_convolution", # Doesn't respect set_per_process_memory_fraction, results in OOM for other tests in slow gradcheck
"distributions/test_distributions",
"test_autograd", # slow gradcheck runs a test that checks the cuda memory allocator
"test_prims", # slow gradcheck runs a test that checks the cuda memory allocator
"test_modules", # failed test due to mismatched elements
"functorch/test_vmap", # OOM
"test_fx", # gets SIGKILL
"test_dataloader", # frequently hangs for ROCm
"test_serialization", # test_serialization_2gb_file allocates a tensor of 2GB, and could cause OOM
"_nvfuser/test_torchscript", # OOM on test_issue_1785
"test_schema_check", # Cause CUDA illegal memory access https://github.com/pytorch/pytorch/issues/95749
"functorch/test_memory_efficient_fusion", # Cause CUDA OOM on ROCm
]
# A subset of our TEST list that validates PyTorch's ops, modules, and autograd function as expected
@ -282,7 +288,7 @@ CORE_TEST_LIST = [
"test_ops_gradients",
"test_ops_fwd_gradients",
"test_ops_jit",
"test_torch"
"test_torch",
]
# A list of distributed tests that run on multiple backends, i.e. gloo, nccl. These backends are spread out
@ -377,6 +383,7 @@ TESTS_NOT_USING_GRADCHECK = [
"test_quantization",
]
def print_to_stderr(message):
print(message, file=sys.stderr)
@ -421,8 +428,10 @@ def run_test(
unittest_args.extend(ci_args)
if test_module in PYTEST_SKIP_RETRIES:
if not options.pytest:
raise RuntimeError("A test running without pytest cannot skip retries using "
"the PYTEST_SKIP_RETRIES set.")
raise RuntimeError(
"A test running without pytest cannot skip retries using "
"the PYTEST_SKIP_RETRIES set."
)
unittest_args = [arg for arg in unittest_args if "--reruns" not in arg]
# Extra arguments are not supported with pytest
@ -433,9 +442,11 @@ def run_test(
argv = [test_module + ".py"] + unittest_args
os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)
log_fd, log_path = tempfile.mkstemp(dir=REPO_ROOT / "test" / "test-reports",
prefix="{}_".format(test_module.replace("\\", "-").replace("/", "-")),
suffix=".log")
log_fd, log_path = tempfile.mkstemp(
dir=REPO_ROOT / "test" / "test-reports",
prefix="{}_".format(test_module.replace("\\", "-").replace("/", "-")),
suffix=".log",
)
os.close(log_fd)
command = (launcher_cmd or []) + executable + argv
print_to_stderr("Executing {} ... [{}]".format(command, datetime.now()))
@ -452,11 +463,15 @@ def test_cuda_primary_ctx(test_module, test_directory, options):
)
run_test_with_subprocess = functools.partial(run_test, extra_unittest_args=["--subprocess"])
run_test_with_subprocess = functools.partial(
run_test, extra_unittest_args=["--subprocess"]
)
def get_run_test_with_subprocess_fn():
return lambda test_module, test_directory, options: run_test_with_subprocess(test_module, test_directory, options)
return lambda test_module, test_directory, options: run_test_with_subprocess(
test_module, test_directory, options
)
def _test_cpp_extensions_aot(test_directory, options, use_ninja):
@ -493,7 +508,7 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja):
python_path = os.environ.get("PYTHONPATH", "")
from shutil import copyfile
os.environ['USE_NINJA'] = shell_env['USE_NINJA']
os.environ["USE_NINJA"] = shell_env["USE_NINJA"]
test_module = "test_cpp_extensions_aot" + ("_ninja" if use_ninja else "_no_ninja")
copyfile(
test_directory + "/test_cpp_extensions_aot.py",
@ -515,7 +530,7 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja):
os.environ["PYTHONPATH"] = python_path
if os.path.exists(test_directory + "/" + test_module + ".py"):
os.remove(test_directory + "/" + test_module + ".py")
os.environ.pop('USE_NINJA')
os.environ.pop("USE_NINJA")
def test_cpp_extensions_aot_ninja(test_module, test_directory, options):
@ -539,9 +554,15 @@ def test_distributed(test_module, test_directory, options):
else:
which_shard = num_shards = 1
# Round-robin all backends to different shards
backend_to_shard = {backend: i % num_shards + 1
for i, backend in enumerate(DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS[test_module])}
print_to_stderr(f"Map different backends to different shards for {test_module}: {backend_to_shard}")
backend_to_shard = {
backend: i % num_shards + 1
for i, backend in enumerate(
DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS[test_module]
)
}
print_to_stderr(
f"Map different backends to different shards for {test_module}: {backend_to_shard}"
)
config = DISTRIBUTED_TESTS_CONFIG
for backend, env_vars in config.items():
@ -551,7 +572,9 @@ def test_distributed(test_module, test_directory, options):
continue
# Default to the first shard if seeing an unrecognized backend
if which_shard != backend_to_shard.get(backend, 1):
print_to_stderr(f"Shard {which_shard}: {backend} should be run in {backend_to_shard.get(backend, 1)}")
print_to_stderr(
f"Shard {which_shard}: {backend} should be run in {backend_to_shard.get(backend, 1)}"
)
continue
for with_init_file in {True, False}:
if sys.platform == "win32" and not with_init_file:
@ -611,7 +634,12 @@ def test_distributed(test_module, test_directory, options):
test_module, test_directory, options, launcher_cmd=mpiexec
)
else:
return_code = run_test(test_module, test_directory, options, extra_unittest_args=["--subprocess"])
return_code = run_test(
test_module,
test_directory,
options,
extra_unittest_args=["--subprocess"],
)
if return_code != 0:
return return_code
finally:
@ -626,8 +654,10 @@ def run_doctests(test_module, test_directory, options):
Assumes the incoming test module is called doctest, and simply executes the
xdoctest runner on the torch library itself.
"""
import xdoctest
import pathlib
import xdoctest
pkgpath = pathlib.Path(torch.__file__).parent
exclude_module_list = []
@ -638,42 +668,47 @@ def run_doctests(test_module, test_directory, options):
# 'cuda': 'auto',
# 'cuda1': 'auto',
# 'qengine': 'auto',
'lapack': 0,
'cuda': 0,
'cuda1': 0,
'qengine': 0,
'autograd_profiler': 0,
'cpp_ext': 0,
'monitor': 0,
"lapack": 0,
"cuda": 0,
"cuda1": 0,
"qengine": 0,
"autograd_profiler": 0,
"cpp_ext": 0,
"monitor": 0,
"onnx": "auto",
}
# Resolve "auto" based on a test to determine if the feature is available.
if enabled['cuda'] == 'auto' and torch.cuda.is_available():
enabled['cuda'] = True
if enabled["cuda"] == "auto" and torch.cuda.is_available():
enabled["cuda"] = True
if enabled['cuda1'] == 'auto' and torch.cuda.is_available() and torch.cuda.device_count() > 1:
enabled['cuda1'] = True
if (
enabled["cuda1"] == "auto"
and torch.cuda.is_available()
and torch.cuda.device_count() > 1
):
enabled["cuda1"] = True
if enabled['lapack'] == 'auto' and torch._C.has_lapack:
enabled['lapack'] = True
if enabled["lapack"] == "auto" and torch._C.has_lapack:
enabled["lapack"] = True
if enabled['qengine'] == 'auto':
if enabled["qengine"] == "auto":
try:
# Is there a better check if quantization is enabled?
import torch.ao.nn.quantized as nnq # NOQA
torch.backends.quantized.engine = 'qnnpack'
torch.backends.quantized.engine = 'fbgemm'
torch.backends.quantized.engine = "qnnpack"
torch.backends.quantized.engine = "fbgemm"
except (ImportError, RuntimeError):
...
else:
enabled['qengine'] = True
enabled["qengine"] = True
if enabled["onnx"] == "auto":
try:
import onnx # NOQA
import onnxscript # NOQA
import onnxruntime # NOQA
import onnxscript # NOQA
except ImportError:
exclude_module_list.append("torch.onnx._internal.fx.*")
enabled["onnx"] = False
@ -681,69 +716,79 @@ def run_doctests(test_module, test_directory, options):
enabled["onnx"] = True
# Set doctest environment variables
if enabled['cuda']:
os.environ['TORCH_DOCTEST_CUDA'] = '1'
if enabled["cuda"]:
os.environ["TORCH_DOCTEST_CUDA"] = "1"
if enabled['cuda1']:
os.environ['TORCH_DOCTEST_CUDA1'] = '1'
if enabled["cuda1"]:
os.environ["TORCH_DOCTEST_CUDA1"] = "1"
if enabled['lapack']:
os.environ['TORCH_DOCTEST_LAPACK'] = '1'
if enabled["lapack"]:
os.environ["TORCH_DOCTEST_LAPACK"] = "1"
if enabled['qengine']:
os.environ['TORCH_DOCTEST_QENGINE'] = '1'
if enabled["qengine"]:
os.environ["TORCH_DOCTEST_QENGINE"] = "1"
if enabled['autograd_profiler']:
os.environ['TORCH_DOCTEST_AUTOGRAD_PROFILER'] = '1'
if enabled["autograd_profiler"]:
os.environ["TORCH_DOCTEST_AUTOGRAD_PROFILER"] = "1"
if enabled['cpp_ext']:
os.environ['TORCH_DOCTEST_CPP_EXT'] = '1'
if enabled["cpp_ext"]:
os.environ["TORCH_DOCTEST_CPP_EXT"] = "1"
if enabled['monitor']:
os.environ['TORCH_DOCTEST_MONITOR'] = '1'
if enabled["monitor"]:
os.environ["TORCH_DOCTEST_MONITOR"] = "1"
if enabled["onnx"]:
os.environ['TORCH_DOCTEST_ONNX'] = '1'
os.environ["TORCH_DOCTEST_ONNX"] = "1"
if 0:
# TODO: could try to enable some of these
os.environ['TORCH_DOCTEST_QUANTIZED_DYNAMIC'] = '1'
os.environ['TORCH_DOCTEST_ANOMOLY'] = '1'
os.environ['TORCH_DOCTEST_AUTOGRAD'] = '1'
os.environ['TORCH_DOCTEST_HUB'] = '1'
os.environ['TORCH_DOCTEST_DATALOADER'] = '1'
os.environ['TORCH_DOCTEST_FUTURES'] = '1'
os.environ["TORCH_DOCTEST_QUANTIZED_DYNAMIC"] = "1"
os.environ["TORCH_DOCTEST_ANOMOLY"] = "1"
os.environ["TORCH_DOCTEST_AUTOGRAD"] = "1"
os.environ["TORCH_DOCTEST_HUB"] = "1"
os.environ["TORCH_DOCTEST_DATALOADER"] = "1"
os.environ["TORCH_DOCTEST_FUTURES"] = "1"
pkgpath = os.path.dirname(torch.__file__)
xdoctest_config = {
'global_exec': r'\n'.join([
'from torch import nn',
'import torch.nn.functional as F',
'import torch',
]),
'analysis': 'static', # set to "auto" to test doctests in compiled modules
'style': 'google',
'options': '+IGNORE_WHITESPACE',
"global_exec": r"\n".join(
[
"from torch import nn",
"import torch.nn.functional as F",
"import torch",
]
),
"analysis": "static", # set to "auto" to test doctests in compiled modules
"style": "google",
"options": "+IGNORE_WHITESPACE",
}
xdoctest_verbose = max(1, options.verbose)
run_summary = xdoctest.runner.doctest_module(
os.fspath(pkgpath), config=xdoctest_config, verbose=xdoctest_verbose,
command=options.xdoctest_command, argv=[],
exclude=exclude_module_list)
result = 1 if run_summary.get('n_failed', 0) else 0
os.fspath(pkgpath),
config=xdoctest_config,
verbose=xdoctest_verbose,
command=options.xdoctest_command,
argv=[],
exclude=exclude_module_list,
)
result = 1 if run_summary.get("n_failed", 0) else 0
return result
def print_log_file(test: str, file_path: str, failed: bool) -> None:
num_lines = sum(1 for _ in open(file_path, 'rb'))
num_lines = sum(1 for _ in open(file_path, "rb"))
n = 100
with open(file_path, "r") as f:
print_to_stderr("")
if failed:
if n < num_lines:
print_to_stderr(f"Expand the folded group to see the beginning of the log file of {test}")
print_to_stderr(f"##[group]PRINTING BEGINNING OF LOG FILE of {test} ({file_path})")
print_to_stderr(
f"Expand the folded group to see the beginning of the log file of {test}"
)
print_to_stderr(
f"##[group]PRINTING BEGINNING OF LOG FILE of {test} ({file_path})"
)
for _ in range(num_lines - n):
print_to_stderr(next(f).rstrip())
print_to_stderr("##[endgroup]")
@ -776,11 +821,13 @@ def get_pytest_args(options):
"--use-pytest",
"-vv",
"-rfEX",
"-p", "no:xdist",
"-p",
"no:xdist",
]
pytest_args.extend(rerun_options)
return pytest_args
def run_test_ops(test_module, test_directory, options):
default_unittest_args = get_pytest_args(options)
@ -789,11 +836,13 @@ def run_test_ops(test_module, test_directory, options):
pool = get_context("spawn").Pool(NUM_PROCS)
for i in range(NUM_PROCS):
extra_unittest_args = default_unittest_args.copy()
extra_unittest_args.extend([
f"--shard-id={i}",
f"--num-shards={NUM_PROCS}",
"-k=not _linalg_cholesky_",
])
extra_unittest_args.extend(
[
f"--shard-id={i}",
f"--num-shards={NUM_PROCS}",
"-k=not _linalg_cholesky_",
]
)
return_code = pool.apply_async(
run_test,
@ -813,9 +862,11 @@ def run_test_ops(test_module, test_directory, options):
return return_code.get()
extra_unittest_args = default_unittest_args.copy()
extra_unittest_args.extend([
"-k=_linalg_cholesky_",
])
extra_unittest_args.extend(
[
"-k=_linalg_cholesky_",
]
)
return_code = run_test(
test_module,
@ -859,9 +910,7 @@ CUSTOM_HANDLERS = {
}
PYTEST_SKIP_RETRIES = {
'test_public_bindings'
}
PYTEST_SKIP_RETRIES = {"test_public_bindings"}
def parse_test_module(test):
@ -881,7 +930,7 @@ def parse_args():
description="Run the PyTorch unit test suite",
epilog="where TESTS is any of: {}".format(", ".join(TESTS)),
formatter_class=argparse.RawTextHelpFormatter,
parents=[common_parser]
parents=[common_parser],
)
parser.add_argument(
"-v",
@ -905,22 +954,20 @@ def parse_args():
"If this flag is present, we will only run functorch tests. "
"If this flag is not present, we will run all tests "
"(including functorch tests)."
)
),
)
parser.add_argument(
"--mps",
"--mps",
action="store_true",
help=(
"If this flag is present, we will only run test_mps and test_metal"
)
help=("If this flag is present, we will only run test_mps and test_metal"),
)
parser.add_argument(
"-core",
"--core",
action="store_true",
help="Only run core tests, or tests that validate PyTorch's ops, modules,"
"and autograd. They are defined by CORE_TEST_LIST."
"and autograd. They are defined by CORE_TEST_LIST.",
)
parser.add_argument(
"-pt",
@ -1028,12 +1075,13 @@ def parse_args():
)
parser.add_argument(
"--xdoctest-command",
default='all',
default="all",
help=(
"Control the specific doctest action. "
"Use 'list' to simply parse doctests and check syntax. "
"Use 'all' to execute all doctests or specify a specific "
"doctest to run")
"doctest to run"
),
)
group = parser.add_mutually_exclusive_group()
@ -1088,11 +1136,15 @@ def find_test_index(test, selected_tests, find_last_index=False):
return found_idx
def exclude_tests(exclude_list, selected_tests, exclude_message=None, exact_match=False):
def exclude_tests(
exclude_list, selected_tests, exclude_message=None, exact_match=False
):
for exclude_test in exclude_list:
tests_copy = selected_tests[:]
for test in tests_copy:
if (not exact_match and test.startswith(exclude_test)) or test == exclude_test:
if (
not exact_match and test.startswith(exclude_test)
) or test == exclude_test:
if exclude_message is not None:
print_to_stderr("Excluding {} {}".format(test, exclude_message))
selected_tests.remove(test)
@ -1101,19 +1153,19 @@ def exclude_tests(exclude_list, selected_tests, exclude_message=None, exact_matc
def must_serial(file: str) -> bool:
return (
os.getenv("PYTORCH_TEST_RUN_EVERYTHING_IN_SERIAL", "0") == "1" or
"distributed" in os.getenv("TEST_CONFIG", "") or
"dynamo" in os.getenv("TEST_CONFIG", "") or
"distributed" in file or
file in CUSTOM_HANDLERS or
file in RUN_PARALLEL_BLOCKLIST or
file in CI_SERIAL_LIST or
file in JIT_EXECUTOR_TESTS
os.getenv("PYTORCH_TEST_RUN_EVERYTHING_IN_SERIAL", "0") == "1"
or "distributed" in os.getenv("TEST_CONFIG", "")
or "dynamo" in os.getenv("TEST_CONFIG", "")
or "distributed" in file
or file in CUSTOM_HANDLERS
or file in RUN_PARALLEL_BLOCKLIST
or file in CI_SERIAL_LIST
or file in JIT_EXECUTOR_TESTS
)
def can_run_in_pytest(test):
return os.getenv('PYTORCH_TEST_DO_NOT_USE_PYTEST', '0') == '0'
return os.getenv("PYTORCH_TEST_DO_NOT_USE_PYTEST", "0") == "0"
def get_selected_tests(options):
@ -1127,9 +1179,13 @@ def get_selected_tests(options):
if options.distributed_tests:
selected_tests = list(
filter(lambda test_name: (test_name in DISTRIBUTED_TESTS and
test_name not in DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS),
selected_tests)
filter(
lambda test_name: (
test_name in DISTRIBUTED_TESTS
and test_name not in DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS
),
selected_tests,
)
)
# Filter to only run core tests when --core option is specified
@ -1143,10 +1199,10 @@ def get_selected_tests(options):
selected_tests = [tname for tname in selected_tests if tname in FUNCTORCH_TESTS]
if options.mps:
selected_tests = ['test_mps', 'test_metal']
selected_tests = ["test_mps", "test_metal"]
else:
# Exclude all mps tests otherwise
options.exclude.extend(['test_mps', 'test_metal'])
options.exclude.extend(["test_mps", "test_metal"])
# process reordering
if options.bring_to_front:
@ -1221,25 +1277,39 @@ def get_selected_tests(options):
else:
print("Found test time stats from artifacts")
test_file_times_config = test_file_times[test_config]
shards = calculate_shards(num_shards, selected_tests, test_file_times_config,
must_serial=must_serial)
shards = calculate_shards(
num_shards,
selected_tests,
test_file_times_config,
must_serial=must_serial,
)
_, tests_from_shard = shards[which_shard - 1]
selected_tests = tests_from_shard
# skip all distributed tests if distributed package is not available.
if not dist.is_available():
selected_tests = exclude_tests(DISTRIBUTED_TESTS, selected_tests,
"PyTorch is built without distributed support.")
selected_tests = exclude_tests(
DISTRIBUTED_TESTS,
selected_tests,
"PyTorch is built without distributed support.",
)
# skip tests that require LAPACK when it's not available
if not torch._C.has_lapack:
selected_tests = exclude_tests(TESTS_REQUIRING_LAPACK, selected_tests,
"PyTorch is built without LAPACK support.")
selected_tests = exclude_tests(
TESTS_REQUIRING_LAPACK,
selected_tests,
"PyTorch is built without LAPACK support.",
)
if is_slow_gradcheck_env():
selected_tests = exclude_tests(TESTS_NOT_USING_GRADCHECK, selected_tests,
"Running in slow gradcheck mode, skipping tests "
"that don't use gradcheck.", exact_match=True)
selected_tests = exclude_tests(
TESTS_NOT_USING_GRADCHECK,
selected_tests,
"Running in slow gradcheck mode, skipping tests "
"that don't use gradcheck.",
exact_match=True,
)
if options.distributed_tests:
# Run distributed tests with multiple backends across all shards, one per backend
@ -1303,12 +1373,24 @@ def main():
# parallel = in parallel with other files
# serial = this file on it's own. The file might still be run in parallel with itself (ex test_ops)
selected_tests_parallel = [x for x in selected_tests if not must_serial(x)]
selected_tests_serial = [x for x in selected_tests if x not in selected_tests_parallel]
print_to_stderr("parallel (file granularity) tests:\n {}".format("\n ".join(selected_tests_parallel)))
print_to_stderr("serial (file granularity) tests:\n {}".format("\n ".join(selected_tests_serial)))
selected_tests_serial = [
x for x in selected_tests if x not in selected_tests_parallel
]
print_to_stderr(
"parallel (file granularity) tests:\n {}".format(
"\n ".join(selected_tests_parallel)
)
)
print_to_stderr(
"serial (file granularity) tests:\n {}".format(
"\n ".join(selected_tests_serial)
)
)
# See Note [ROCm parallel CI testing]
pool = get_context("spawn").Pool(NUM_PROCS, maxtasksperchild=None if torch.version.hip else 1)
pool = get_context("spawn").Pool(
NUM_PROCS, maxtasksperchild=None if torch.version.hip else 1
)
os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)
def success_callback(err_message):
@ -1321,20 +1403,24 @@ def main():
return False
try:
os.environ['PARALLEL_TESTING'] = '1'
os.environ["PARALLEL_TESTING"] = "1"
for test in selected_tests_parallel:
options_clone = copy.deepcopy(options)
if can_run_in_pytest(test):
options_clone.pytest = True
pool.apply_async(run_test_module, args=(test, test_directory, options_clone), callback=success_callback)
pool.apply_async(
run_test_module,
args=(test, test_directory, options_clone),
callback=success_callback,
)
pool.close()
pool.join()
del os.environ['PARALLEL_TESTING']
del os.environ["PARALLEL_TESTING"]
if not options.continue_through_error and len(failure_messages) != 0:
raise RuntimeError(
"\n".join(failure_messages) +
"\n\nTip: You can keep running tests even on failure by "
"\n".join(failure_messages)
+ "\n\nTip: You can keep running tests even on failure by "
"passing --keep-going to run_test.py.\n"
"If running on CI, add the 'keep-going' label to "
"your PR and rerun your jobs."