mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] Apply ufmt to run_test and GitHub Python util scripts (#97588)
This has been bugging me for a while as I'm working on these Python scripts and they are not tracked by ufmt linter. So I add these script into that linter. ``` [[linter]] code = 'UFMT' include_patterns = [ '.github/**/*.py', 'test/run_test.py', ``` This change should just work and not break anything as ufmt (black + usort) linter is very safe to use for standalone util scripts. Pull Request resolved: https://github.com/pytorch/pytorch/pull/97588 Approved by: https://github.com/kit1980
This commit is contained in:
97
.github/scripts/build_triton_wheel.py
vendored
97
.github/scripts/build_triton_wheel.py
vendored
@ -1,13 +1,15 @@
|
||||
#!/usr/bin/env python3
|
||||
from subprocess import check_call
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from subprocess import check_call
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Optional
|
||||
import sys
|
||||
import shutil
|
||||
|
||||
SCRIPT_DIR = Path(__file__).parent
|
||||
REPO_DIR = SCRIPT_DIR.parent.parent
|
||||
|
||||
|
||||
def read_triton_pin() -> str:
|
||||
with open(REPO_DIR / ".ci" / "docker" / "ci_commit_pins" / "triton.txt") as f:
|
||||
return f.read().strip()
|
||||
@ -19,7 +21,7 @@ def read_triton_version() -> str:
|
||||
|
||||
|
||||
def check_and_replace(inp: str, src: str, dst: str) -> str:
|
||||
""" Checks that `src` can be found in `input` and replaces it with `dst` """
|
||||
"""Checks that `src` can be found in `input` and replaces it with `dst`"""
|
||||
if src not in inp:
|
||||
raise RuntimeError(f"Can't find ${src} in the input")
|
||||
return inp.replace(src, dst)
|
||||
@ -29,9 +31,11 @@ def patch_setup_py(path: Path, *, version: str, name: str = "triton") -> None:
|
||||
with open(path) as f:
|
||||
orig = f.read()
|
||||
# Replace name
|
||||
orig = check_and_replace(orig, "name=\"triton\",", f"name=\"{name}\",")
|
||||
orig = check_and_replace(orig, 'name="triton",', f'name="{name}",')
|
||||
# Replace version
|
||||
orig = check_and_replace(orig, f"version=\"{read_triton_version()}\",", f"version=\"{version}\",")
|
||||
orig = check_and_replace(
|
||||
orig, f'version="{read_triton_version()}",', f'version="{version}",'
|
||||
)
|
||||
with open(path, "w") as f:
|
||||
f.write(orig)
|
||||
|
||||
@ -40,12 +44,20 @@ def patch_init_py(path: Path, *, version: str) -> None:
|
||||
with open(path) as f:
|
||||
orig = f.read()
|
||||
# Replace version
|
||||
orig = check_and_replace(orig, f"__version__ = '{read_triton_version()}'", f"__version__ = \"{version}\"")
|
||||
orig = check_and_replace(
|
||||
orig, f"__version__ = '{read_triton_version()}'", f'__version__ = "{version}"'
|
||||
)
|
||||
with open(path, "w") as f:
|
||||
f.write(orig)
|
||||
|
||||
|
||||
def build_triton(*, version: str, commit_hash: str, build_conda: bool = False, py_version : Optional[str] = None) -> Path:
|
||||
def build_triton(
|
||||
*,
|
||||
version: str,
|
||||
commit_hash: str,
|
||||
build_conda: bool = False,
|
||||
py_version: Optional[str] = None,
|
||||
) -> Path:
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
triton_basedir = Path(tmpdir) / "triton"
|
||||
triton_pythondir = triton_basedir / "python"
|
||||
@ -53,26 +65,60 @@ def build_triton(*, version: str, commit_hash: str, build_conda: bool = False, p
|
||||
check_call(["git", "checkout", commit_hash], cwd=triton_basedir)
|
||||
if build_conda:
|
||||
with open(triton_basedir / "meta.yaml", "w") as meta:
|
||||
print(f"package:\n name: torchtriton\n version: {version}+{commit_hash[:10]}\n", file=meta)
|
||||
print(
|
||||
f"package:\n name: torchtriton\n version: {version}+{commit_hash[:10]}\n",
|
||||
file=meta,
|
||||
)
|
||||
print("source:\n path: .\n", file=meta)
|
||||
print("build:\n string: py{{py}}\n number: 1\n script: cd python; "
|
||||
"python setup.py install --single-version-externally-managed --record=record.txt\n", file=meta)
|
||||
print("requirements:\n host:\n - python\n - setuptools\n run:\n - python\n"
|
||||
" - filelock\n - pytorch\n", file=meta)
|
||||
print("about:\n home: https://github.com/openai/triton\n license: MIT\n summary:"
|
||||
" 'A language and compiler for custom Deep Learning operation'", file=meta)
|
||||
print(
|
||||
"build:\n string: py{{py}}\n number: 1\n script: cd python; "
|
||||
"python setup.py install --single-version-externally-managed --record=record.txt\n",
|
||||
file=meta,
|
||||
)
|
||||
print(
|
||||
"requirements:\n host:\n - python\n - setuptools\n run:\n - python\n"
|
||||
" - filelock\n - pytorch\n",
|
||||
file=meta,
|
||||
)
|
||||
print(
|
||||
"about:\n home: https://github.com/openai/triton\n license: MIT\n summary:"
|
||||
" 'A language and compiler for custom Deep Learning operation'",
|
||||
file=meta,
|
||||
)
|
||||
|
||||
patch_init_py(triton_pythondir / "triton" / "__init__.py", version=f"{version}+{commit_hash[:10]}")
|
||||
patch_init_py(
|
||||
triton_pythondir / "triton" / "__init__.py",
|
||||
version=f"{version}+{commit_hash[:10]}",
|
||||
)
|
||||
if py_version is None:
|
||||
py_version = f"{sys.version_info.major}.{sys.version_info.minor}"
|
||||
check_call(["conda", "build", "--python", py_version,
|
||||
"-c", "pytorch-nightly", "--output-folder", tmpdir, "."], cwd=triton_basedir)
|
||||
check_call(
|
||||
[
|
||||
"conda",
|
||||
"build",
|
||||
"--python",
|
||||
py_version,
|
||||
"-c",
|
||||
"pytorch-nightly",
|
||||
"--output-folder",
|
||||
tmpdir,
|
||||
".",
|
||||
],
|
||||
cwd=triton_basedir,
|
||||
)
|
||||
conda_path = list(Path(tmpdir).glob("linux-64/torchtriton*.bz2"))[0]
|
||||
shutil.copy(conda_path, Path.cwd())
|
||||
return Path.cwd() / conda_path.name
|
||||
|
||||
patch_setup_py(triton_pythondir / "setup.py", name="pytorch-triton", version=f"{version}+{commit_hash[:10]}")
|
||||
patch_init_py(triton_pythondir / "triton" / "__init__.py", version=f"{version}+{commit_hash[:10]}")
|
||||
patch_setup_py(
|
||||
triton_pythondir / "setup.py",
|
||||
name="pytorch-triton",
|
||||
version=f"{version}+{commit_hash[:10]}",
|
||||
)
|
||||
patch_init_py(
|
||||
triton_pythondir / "triton" / "__init__.py",
|
||||
version=f"{version}+{commit_hash[:10]}",
|
||||
)
|
||||
check_call([sys.executable, "setup.py", "bdist_wheel"], cwd=triton_pythondir)
|
||||
whl_path = list((triton_pythondir / "dist").glob("*.whl"))[0]
|
||||
shutil.copy(whl_path, Path.cwd())
|
||||
@ -81,16 +127,19 @@ def build_triton(*, version: str, commit_hash: str, build_conda: bool = False, p
|
||||
|
||||
def main() -> None:
|
||||
from argparse import ArgumentParser
|
||||
|
||||
parser = ArgumentParser("Build Triton binaries")
|
||||
parser.add_argument("--build-conda", action="store_true")
|
||||
parser.add_argument("--py-version", type=str)
|
||||
parser.add_argument("--commit-hash", type=str, default=read_triton_pin())
|
||||
parser.add_argument("--triton-version", type=str, default=read_triton_version())
|
||||
args = parser.parse_args()
|
||||
build_triton(commit_hash=args.commit_hash,
|
||||
version=args.triton_version,
|
||||
build_conda=args.build_conda,
|
||||
py_version=args.py_version)
|
||||
build_triton(
|
||||
commit_hash=args.commit_hash,
|
||||
version=args.triton_version,
|
||||
build_conda=args.build_conda,
|
||||
py_version=args.py_version,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
20
.github/scripts/check_labels.py
vendored
20
.github/scripts/check_labels.py
vendored
@ -3,21 +3,12 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from gitutils import (
|
||||
get_git_remote_name,
|
||||
get_git_repo_dir,
|
||||
GitRepo,
|
||||
)
|
||||
from github_utils import gh_delete_comment, gh_post_pr_comment
|
||||
|
||||
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
|
||||
from label_utils import has_required_labels, is_label_err_comment, LABEL_ERR_MSG
|
||||
from trymerge import GitHubPR
|
||||
from github_utils import (
|
||||
gh_delete_comment,
|
||||
gh_post_pr_comment,
|
||||
)
|
||||
from label_utils import (
|
||||
LABEL_ERR_MSG,
|
||||
is_label_err_comment,
|
||||
has_required_labels,
|
||||
)
|
||||
|
||||
|
||||
def delete_all_label_err_comments(pr: "GitHubPR") -> None:
|
||||
for comment in pr.get_comments():
|
||||
@ -33,6 +24,7 @@ def add_label_err_comment(pr: "GitHubPR") -> None:
|
||||
|
||||
def parse_args() -> Any:
|
||||
from argparse import ArgumentParser
|
||||
|
||||
parser = ArgumentParser("Check PR labels")
|
||||
parser.add_argument("pr_num", type=int)
|
||||
|
||||
|
25
.github/scripts/collect_ciflow_labels.py
vendored
25
.github/scripts/collect_ciflow_labels.py
vendored
@ -1,11 +1,13 @@
|
||||
#!/usr/bin/env python3
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Set, cast
|
||||
import yaml
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, cast, Dict, List, Set
|
||||
|
||||
import yaml
|
||||
|
||||
GITHUB_DIR = Path(__file__).parent.parent
|
||||
|
||||
|
||||
def get_workflows_push_tags() -> Set[str]:
|
||||
"Extract all known push tags from workflows"
|
||||
rc: Set[str] = set()
|
||||
@ -22,8 +24,10 @@ def get_workflows_push_tags() -> Set[str]:
|
||||
|
||||
|
||||
def filter_ciflow_tags(tags: Set[str]) -> List[str]:
|
||||
" Return sorted list of ciflow tags"
|
||||
return sorted(tag[:-2] for tag in tags if tag.startswith("ciflow/") and tag.endswith("/*"))
|
||||
"Return sorted list of ciflow tags"
|
||||
return sorted(
|
||||
tag[:-2] for tag in tags if tag.startswith("ciflow/") and tag.endswith("/*")
|
||||
)
|
||||
|
||||
|
||||
def read_probot_config() -> Dict[str, Any]:
|
||||
@ -40,6 +44,7 @@ def update_probot_config(labels: Set[str]) -> None:
|
||||
|
||||
if __name__ == "__main__":
|
||||
from argparse import ArgumentParser
|
||||
|
||||
parser = ArgumentParser("Validate or update list of tags")
|
||||
parser.add_argument("--validate-tags", action="store_true")
|
||||
args = parser.parse_args()
|
||||
@ -51,9 +56,15 @@ if __name__ == "__main__":
|
||||
if config_tags != ciflow_tags:
|
||||
print("Tags mismatch!")
|
||||
if ciflow_tags.difference(config_tags):
|
||||
print("Reference in workflows but not in config", ciflow_tags.difference(config_tags))
|
||||
print(
|
||||
"Reference in workflows but not in config",
|
||||
ciflow_tags.difference(config_tags),
|
||||
)
|
||||
if config_tags.difference(ciflow_tags):
|
||||
print("Reference in config, but not in workflows", config_tags.difference(ciflow_tags))
|
||||
print(
|
||||
"Reference in config, but not in workflows",
|
||||
config_tags.difference(ciflow_tags),
|
||||
)
|
||||
print(f"Please run {__file__} to remediate the difference")
|
||||
sys.exit(-1)
|
||||
print("All tags are listed in pytorch-probot.yml")
|
||||
|
3
.github/scripts/comment_on_pr.py
vendored
3
.github/scripts/comment_on_pr.py
vendored
@ -1,8 +1,9 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from github_utils import gh_post_pr_comment
|
||||
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
|
||||
from trymerge_explainer import BOT_COMMANDS_WIKI
|
||||
import os
|
||||
|
||||
|
||||
def parse_args() -> Any:
|
||||
|
@ -6,6 +6,7 @@ from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
|
||||
# From: https://docs.github.com/en/rest/reference/checks
|
||||
class GitHubAnnotationLevel(str, Enum):
|
||||
NOTICE = "notice"
|
||||
@ -24,7 +25,12 @@ class GitHubAnnotation(NamedTuple):
|
||||
title: Optional[str]
|
||||
raw_details: Optional[str]
|
||||
|
||||
PYTORCH_ROOT = Path(subprocess.check_output(['git', 'rev-parse', '--show-toplevel']).decode('ascii').strip())
|
||||
|
||||
PYTORCH_ROOT = Path(
|
||||
subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
|
||||
.decode("ascii")
|
||||
.strip()
|
||||
)
|
||||
|
||||
annotations = []
|
||||
for line in sys.stdin:
|
||||
@ -33,7 +39,6 @@ for line in sys.stdin:
|
||||
path = lint_message.get("path")
|
||||
line = lint_message.get("line")
|
||||
|
||||
|
||||
code = lint_message["code"]
|
||||
severity = lint_message["severity"]
|
||||
name = lint_message["name"]
|
||||
@ -48,16 +53,18 @@ for line in sys.stdin:
|
||||
# normalize path relative to git root
|
||||
path = Path(path).relative_to(PYTORCH_ROOT)
|
||||
|
||||
annotations.append(GitHubAnnotation(
|
||||
path=str(path),
|
||||
start_line=int(line),
|
||||
end_line=int(line),
|
||||
start_column=None,
|
||||
end_column=None,
|
||||
annotation_level=GitHubAnnotationLevel.FAILURE,
|
||||
message=description,
|
||||
title=f"({code}) {name}",
|
||||
raw_details=None,
|
||||
)._asdict())
|
||||
annotations.append(
|
||||
GitHubAnnotation(
|
||||
path=str(path),
|
||||
start_line=int(line),
|
||||
end_line=int(line),
|
||||
start_column=None,
|
||||
end_column=None,
|
||||
annotation_level=GitHubAnnotationLevel.FAILURE,
|
||||
message=description,
|
||||
title=f"({code}) {name}",
|
||||
raw_details=None,
|
||||
)._asdict()
|
||||
)
|
||||
|
||||
print(json.dumps(annotations), flush=True)
|
||||
|
@ -2,15 +2,18 @@
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import yaml
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
|
||||
WORKFLOWS = REPO_ROOT / ".github" / "workflows"
|
||||
EXPECTED_GROUP = "${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}" \
|
||||
EXPECTED_GROUP = (
|
||||
"${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}"
|
||||
"-${{ github.event_name == 'workflow_dispatch' }}"
|
||||
)
|
||||
|
||||
|
||||
def should_check(filename: Path) -> bool:
|
||||
|
14
.github/scripts/export_pytorch_labels.py
vendored
14
.github/scripts/export_pytorch_labels.py
vendored
@ -1,5 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
'''
|
||||
"""
|
||||
Test ownership was introduced in https://github.com/pytorch/pytorch/issues/66232.
|
||||
|
||||
As a part of enforcing test ownership, we want to maintain a list of existing PyTorch labels
|
||||
@ -8,17 +8,19 @@ pytorch/pytorch labels so that the file could be uploaded to S3.
|
||||
|
||||
This script assumes the correct env vars are set for AWS permissions.
|
||||
|
||||
'''
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import boto3 # type: ignore[import]
|
||||
import json
|
||||
|
||||
from label_utils import gh_get_labels
|
||||
from typing import Any
|
||||
|
||||
|
||||
def parse_args() -> Any:
|
||||
from argparse import ArgumentParser
|
||||
|
||||
parser = ArgumentParser("Export PR labels")
|
||||
parser.add_argument("org", type=str)
|
||||
parser.add_argument("repo", type=str)
|
||||
@ -30,9 +32,9 @@ def main() -> None:
|
||||
args = parse_args()
|
||||
print(f"Exporting labels for {args.org}/{args.repo}")
|
||||
labels_file_name = "pytorch_labels.json"
|
||||
obj = boto3.resource('s3').Object('ossci-metrics', labels_file_name)
|
||||
obj = boto3.resource("s3").Object("ossci-metrics", labels_file_name)
|
||||
obj.put(Body=json.dumps(gh_get_labels(args.org, args.repo)).encode())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
63
.github/scripts/fetch_latest_green_commit.py
vendored
63
.github/scripts/fetch_latest_green_commit.py
vendored
@ -1,20 +1,23 @@
|
||||
import sys
|
||||
from typing import Any, Dict, List, NamedTuple, Tuple, cast
|
||||
from gitutils import _check_output
|
||||
|
||||
import rockset # type: ignore[import]
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from typing import Any, cast, Dict, List, NamedTuple, Tuple
|
||||
|
||||
import rockset # type: ignore[import]
|
||||
from gitutils import _check_output
|
||||
|
||||
|
||||
def eprint(msg: str) -> None:
|
||||
print(msg, file=sys.stderr)
|
||||
|
||||
|
||||
class WorkflowCheck(NamedTuple):
|
||||
workflowName: str
|
||||
name: str
|
||||
jobName: str
|
||||
conclusion: str
|
||||
|
||||
|
||||
def get_latest_commits() -> List[str]:
|
||||
latest_viable_commit = _check_output(
|
||||
[
|
||||
@ -39,42 +42,46 @@ def get_latest_commits() -> List[str]:
|
||||
|
||||
return commits
|
||||
|
||||
|
||||
def query_commits(commits: List[str]) -> List[Dict[str, Any]]:
|
||||
rs = rockset.RocksetClient(
|
||||
host="api.usw2a1.rockset.com", api_key=os.environ["ROCKSET_API_KEY"]
|
||||
)
|
||||
params = [{
|
||||
"name": "shas",
|
||||
"type": "string",
|
||||
"value": ",".join(commits)
|
||||
}]
|
||||
params = [{"name": "shas", "type": "string", "value": ",".join(commits)}]
|
||||
res = rs.QueryLambdas.execute_query_lambda(
|
||||
query_lambda='commit_jobs_batch_query',
|
||||
version='8003fdfd18b64696',
|
||||
workspace='commons',
|
||||
parameters=params
|
||||
query_lambda="commit_jobs_batch_query",
|
||||
version="8003fdfd18b64696",
|
||||
workspace="commons",
|
||||
parameters=params,
|
||||
)
|
||||
|
||||
return cast(List[Dict[str, Any]], res.results)
|
||||
|
||||
|
||||
def print_commit_status(commit: str, results: Dict[str, Any]) -> None:
|
||||
print(commit)
|
||||
for check in results['results']:
|
||||
if check['sha'] == commit:
|
||||
for check in results["results"]:
|
||||
if check["sha"] == commit:
|
||||
print(f"\t{check['conclusion']:>10}: {check['name']}")
|
||||
|
||||
def get_commit_results(commit: str, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
|
||||
def get_commit_results(
|
||||
commit: str, results: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
workflow_checks = []
|
||||
for check in results:
|
||||
if check['sha'] == commit:
|
||||
workflow_checks.append(WorkflowCheck(
|
||||
workflowName=check['workflowName'],
|
||||
name=check['name'],
|
||||
jobName=check['jobName'],
|
||||
conclusion=check['conclusion'],
|
||||
)._asdict())
|
||||
if check["sha"] == commit:
|
||||
workflow_checks.append(
|
||||
WorkflowCheck(
|
||||
workflowName=check["workflowName"],
|
||||
name=check["name"],
|
||||
jobName=check["jobName"],
|
||||
conclusion=check["conclusion"],
|
||||
)._asdict()
|
||||
)
|
||||
return workflow_checks
|
||||
|
||||
|
||||
def isGreen(commit: str, results: List[Dict[str, Any]]) -> Tuple[bool, str]:
|
||||
workflow_checks = get_commit_results(commit, results)
|
||||
|
||||
@ -87,8 +94,8 @@ def isGreen(commit: str, results: List[Dict[str, Any]]) -> Tuple[bool, str]:
|
||||
}
|
||||
|
||||
for check in workflow_checks:
|
||||
workflowName = check['workflowName']
|
||||
conclusion = check['conclusion']
|
||||
workflowName = check["workflowName"]
|
||||
conclusion = check["conclusion"]
|
||||
for required_check in regex:
|
||||
if re.match(required_check, workflowName, flags=re.IGNORECASE):
|
||||
if conclusion not in ["success", "skipped"]:
|
||||
@ -102,6 +109,7 @@ def isGreen(commit: str, results: List[Dict[str, Any]]) -> Tuple[bool, str]:
|
||||
|
||||
return (True, "")
|
||||
|
||||
|
||||
def get_latest_green_commit(commits: List[str], results: List[Dict[str, Any]]) -> Any:
|
||||
for commit in commits:
|
||||
eprint(f"Checking {commit}")
|
||||
@ -113,13 +121,14 @@ def get_latest_green_commit(commits: List[str], results: List[Dict[str, Any]]) -
|
||||
eprint("RED: " + msg)
|
||||
return None
|
||||
|
||||
def main() -> None:
|
||||
|
||||
def main() -> None:
|
||||
commits = get_latest_commits()
|
||||
results = query_commits(commits)
|
||||
|
||||
latest_viable_commit = get_latest_green_commit(commits, results)
|
||||
print(latest_viable_commit)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
43
.github/scripts/generate_binary_build_matrix.py
vendored
43
.github/scripts/generate_binary_build_matrix.py
vendored
@ -10,7 +10,7 @@ architectures:
|
||||
* Latest ROCM
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
CUDA_ARCHES = ["11.7", "11.8"]
|
||||
@ -19,7 +19,8 @@ CUDA_ARCHES = ["11.7", "11.8"]
|
||||
ROCM_ARCHES = ["5.3", "5.4.2"]
|
||||
|
||||
|
||||
CPU_CXX11_ABI_ARCH = ['cpu-cxx11-abi']
|
||||
CPU_CXX11_ABI_ARCH = ["cpu-cxx11-abi"]
|
||||
|
||||
|
||||
def arch_type(arch_version: str) -> str:
|
||||
if arch_version in CUDA_ARCHES:
|
||||
@ -121,9 +122,12 @@ def generate_conda_matrix(os: str) -> List[Dict[str, str]]:
|
||||
return ret
|
||||
|
||||
|
||||
def generate_libtorch_matrix(os: str, abi_version: str,
|
||||
arches: Optional[List[str]] = None,
|
||||
libtorch_variants: Optional[List[str]] = None) -> List[Dict[str, str]]:
|
||||
def generate_libtorch_matrix(
|
||||
os: str,
|
||||
abi_version: str,
|
||||
arches: Optional[List[str]] = None,
|
||||
libtorch_variants: Optional[List[str]] = None,
|
||||
) -> List[Dict[str, str]]:
|
||||
if arches is None:
|
||||
arches = ["cpu"]
|
||||
if os == "linux":
|
||||
@ -163,7 +167,9 @@ def generate_libtorch_matrix(os: str, abi_version: str,
|
||||
"devtoolset": abi_version if os != "windows" else "",
|
||||
"container_image": LIBTORCH_CONTAINER_IMAGES[
|
||||
(arch_version, abi_version)
|
||||
] if os != "windows" else "",
|
||||
]
|
||||
if os != "windows"
|
||||
else "",
|
||||
"package_type": "libtorch",
|
||||
"build_name": f"libtorch-{gpu_arch_type}{gpu_arch_version}-{libtorch_variant}-{abi_version}".replace(
|
||||
".", "_"
|
||||
@ -173,9 +179,11 @@ def generate_libtorch_matrix(os: str, abi_version: str,
|
||||
return ret
|
||||
|
||||
|
||||
def generate_wheels_matrix(os: str,
|
||||
arches: Optional[List[str]] = None,
|
||||
python_versions: Optional[List[str]] = None) -> List[Dict[str, str]]:
|
||||
def generate_wheels_matrix(
|
||||
os: str,
|
||||
arches: Optional[List[str]] = None,
|
||||
python_versions: Optional[List[str]] = None,
|
||||
) -> List[Dict[str, str]]:
|
||||
package_type = "wheel"
|
||||
if os == "linux":
|
||||
# NOTE: We only build manywheel packages for linux
|
||||
@ -196,7 +204,11 @@ def generate_wheels_matrix(os: str,
|
||||
for python_version in python_versions:
|
||||
for arch_version in arches:
|
||||
gpu_arch_type = arch_type(arch_version)
|
||||
gpu_arch_version = "" if arch_version == "cpu" or arch_version == "cpu-cxx11-abi" else arch_version
|
||||
gpu_arch_version = (
|
||||
""
|
||||
if arch_version == "cpu" or arch_version == "cpu-cxx11-abi"
|
||||
else arch_version
|
||||
)
|
||||
# Skip rocm 3.11 binaries for now as the docker image are not correct
|
||||
if python_version == "3.11" and gpu_arch_type == "rocm":
|
||||
continue
|
||||
@ -215,8 +227,7 @@ def generate_wheels_matrix(os: str,
|
||||
"devtoolset": "",
|
||||
"container_image": WHEEL_CONTAINER_IMAGES[arch_version],
|
||||
"package_type": package_type,
|
||||
"pytorch_extra_install_requirements":
|
||||
"nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | "
|
||||
"pytorch_extra_install_requirements": "nvidia-cuda-nvrtc-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | " # noqa: B950
|
||||
"nvidia-cuda-runtime-cu11==11.7.99; platform_system == 'Linux' and platform_machine == 'x86_64' | "
|
||||
"nvidia-cuda-cupti-cu11==11.7.101; platform_system == 'Linux' and platform_machine == 'x86_64' | "
|
||||
"nvidia-cudnn-cu11==8.5.0.96; platform_system == 'Linux' and platform_machine == 'x86_64' | "
|
||||
@ -227,9 +238,7 @@ def generate_wheels_matrix(os: str,
|
||||
"nvidia-cusparse-cu11==11.7.4.91; platform_system == 'Linux' and platform_machine == 'x86_64' | "
|
||||
"nvidia-nccl-cu11==2.14.3; platform_system == 'Linux' and platform_machine == 'x86_64' | "
|
||||
"nvidia-nvtx-cu11==11.7.91; platform_system == 'Linux' and platform_machine == 'x86_64'",
|
||||
"build_name":
|
||||
f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}-with-pypi-cudnn"
|
||||
.replace(
|
||||
"build_name": f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}-with-pypi-cudnn".replace( # noqa: B950
|
||||
".", "_"
|
||||
),
|
||||
}
|
||||
@ -243,7 +252,9 @@ def generate_wheels_matrix(os: str,
|
||||
"desired_cuda": translate_desired_cuda(
|
||||
gpu_arch_type, gpu_arch_version
|
||||
),
|
||||
"devtoolset": "cxx11-abi" if arch_version == "cpu-cxx11-abi" else "",
|
||||
"devtoolset": "cxx11-abi"
|
||||
if arch_version == "cpu-cxx11-abi"
|
||||
else "",
|
||||
"container_image": WHEEL_CONTAINER_IMAGES[arch_version],
|
||||
"package_type": package_type,
|
||||
"build_name": f"{package_type}-py{python_version}-{gpu_arch_type}{gpu_arch_version}".replace(
|
||||
|
118
.github/scripts/generate_ci_workflows.py
vendored
118
.github/scripts/generate_ci_workflows.py
vendored
@ -1,17 +1,16 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Dict, Set, List, Literal, Iterable
|
||||
|
||||
import jinja2
|
||||
|
||||
import os
|
||||
import sys
|
||||
from typing_extensions import TypedDict # Python 3.11+
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, List, Literal, Set
|
||||
|
||||
import generate_binary_build_matrix # type: ignore[import]
|
||||
|
||||
import jinja2
|
||||
from typing_extensions import TypedDict # Python 3.11+
|
||||
|
||||
Arch = Literal["windows", "linux", "macos"]
|
||||
|
||||
GITHUB_DIR = Path(__file__).resolve().parent.parent
|
||||
@ -23,6 +22,7 @@ LABEL_CIFLOW_BINARIES_LIBTORCH = "ciflow/binaries_libtorch"
|
||||
LABEL_CIFLOW_BINARIES_CONDA = "ciflow/binaries_conda"
|
||||
LABEL_CIFLOW_BINARIES_WHEEL = "ciflow/binaries_wheel"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CIFlowConfig:
|
||||
# For use to enable workflows to run on pytorch/pytorch-canary
|
||||
@ -36,10 +36,12 @@ class CIFlowConfig:
|
||||
if LABEL_CIFLOW_PERIODIC not in self.labels:
|
||||
self.labels.add(LABEL_CIFLOW_TRUNK)
|
||||
|
||||
|
||||
class Config(TypedDict):
|
||||
num_shards: int
|
||||
runner: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class BinaryBuildWorkflow:
|
||||
os: str
|
||||
@ -47,23 +49,28 @@ class BinaryBuildWorkflow:
|
||||
package_type: str
|
||||
|
||||
# Optional fields
|
||||
build_environment: str = ''
|
||||
abi_version: str = ''
|
||||
build_environment: str = ""
|
||||
abi_version: str = ""
|
||||
ciflow_config: CIFlowConfig = field(default_factory=CIFlowConfig)
|
||||
is_scheduled: str = ''
|
||||
branches: str = 'nightly'
|
||||
is_scheduled: str = ""
|
||||
branches: str = "nightly"
|
||||
# Mainly for macos
|
||||
cross_compile_arm64: bool = False
|
||||
xcode_version: str = ''
|
||||
xcode_version: str = ""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.abi_version:
|
||||
self.build_environment = f"{self.os}-binary-{self.package_type}-{self.abi_version}"
|
||||
self.build_environment = (
|
||||
f"{self.os}-binary-{self.package_type}-{self.abi_version}"
|
||||
)
|
||||
else:
|
||||
self.build_environment = f"{self.os}-binary-{self.package_type}"
|
||||
|
||||
def generate_workflow_file(self, workflow_template: jinja2.Template) -> None:
|
||||
output_file_path = GITHUB_DIR / f"workflows/generated-{self.build_environment}-{self.branches}.yml"
|
||||
output_file_path = (
|
||||
GITHUB_DIR
|
||||
/ f"workflows/generated-{self.build_environment}-{self.branches}.yml"
|
||||
)
|
||||
with open(output_file_path, "w") as output_file:
|
||||
GENERATED = "generated" # Note that please keep the variable GENERATED otherwise phabricator will hide the whole file
|
||||
output_file.writelines([f"# @{GENERATED} DO NOT EDIT MANUALLY\n"])
|
||||
@ -77,17 +84,21 @@ class BinaryBuildWorkflow:
|
||||
output_file.write("\n")
|
||||
print(output_file_path)
|
||||
|
||||
|
||||
class OperatingSystem:
|
||||
LINUX = "linux"
|
||||
WINDOWS = "windows"
|
||||
MACOS = "macos"
|
||||
MACOS_ARM64 = "macos-arm64"
|
||||
|
||||
|
||||
LINUX_BINARY_BUILD_WORFKLOWS = [
|
||||
BinaryBuildWorkflow(
|
||||
os=OperatingSystem.LINUX,
|
||||
package_type="manywheel",
|
||||
build_configs=generate_binary_build_matrix.generate_wheels_matrix(OperatingSystem.LINUX),
|
||||
build_configs=generate_binary_build_matrix.generate_wheels_matrix(
|
||||
OperatingSystem.LINUX
|
||||
),
|
||||
ciflow_config=CIFlowConfig(
|
||||
labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL},
|
||||
isolated_workflow=True,
|
||||
@ -96,7 +107,9 @@ LINUX_BINARY_BUILD_WORFKLOWS = [
|
||||
BinaryBuildWorkflow(
|
||||
os=OperatingSystem.LINUX,
|
||||
package_type="conda",
|
||||
build_configs=generate_binary_build_matrix.generate_conda_matrix(OperatingSystem.LINUX),
|
||||
build_configs=generate_binary_build_matrix.generate_conda_matrix(
|
||||
OperatingSystem.LINUX
|
||||
),
|
||||
ciflow_config=CIFlowConfig(
|
||||
labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_CONDA},
|
||||
isolated_workflow=True,
|
||||
@ -133,18 +146,16 @@ LINUX_BINARY_SMOKE_WORKFLOWS = [
|
||||
os=OperatingSystem.LINUX,
|
||||
package_type="manywheel",
|
||||
build_configs=generate_binary_build_matrix.generate_wheels_matrix(
|
||||
OperatingSystem.LINUX,
|
||||
arches=["11.8"],
|
||||
python_versions=["3.8"]),
|
||||
OperatingSystem.LINUX, arches=["11.8"], python_versions=["3.8"]
|
||||
),
|
||||
branches="master",
|
||||
),
|
||||
BinaryBuildWorkflow(
|
||||
os=OperatingSystem.LINUX,
|
||||
package_type="manywheel",
|
||||
build_configs=generate_binary_build_matrix.generate_wheels_matrix(
|
||||
OperatingSystem.LINUX,
|
||||
arches=["11.7"],
|
||||
python_versions=["3.8"]),
|
||||
OperatingSystem.LINUX, arches=["11.7"], python_versions=["3.8"]
|
||||
),
|
||||
branches="master",
|
||||
),
|
||||
BinaryBuildWorkflow(
|
||||
@ -152,7 +163,8 @@ LINUX_BINARY_SMOKE_WORKFLOWS = [
|
||||
package_type="libtorch",
|
||||
abi_version=generate_binary_build_matrix.CXX11_ABI,
|
||||
build_configs=generate_binary_build_matrix.generate_libtorch_matrix(
|
||||
OperatingSystem.LINUX, generate_binary_build_matrix.CXX11_ABI,
|
||||
OperatingSystem.LINUX,
|
||||
generate_binary_build_matrix.CXX11_ABI,
|
||||
arches=["cpu"],
|
||||
libtorch_variants=["shared-with-deps"],
|
||||
),
|
||||
@ -163,7 +175,8 @@ LINUX_BINARY_SMOKE_WORKFLOWS = [
|
||||
package_type="libtorch",
|
||||
abi_version=generate_binary_build_matrix.PRE_CXX11_ABI,
|
||||
build_configs=generate_binary_build_matrix.generate_libtorch_matrix(
|
||||
OperatingSystem.LINUX, generate_binary_build_matrix.PRE_CXX11_ABI,
|
||||
OperatingSystem.LINUX,
|
||||
generate_binary_build_matrix.PRE_CXX11_ABI,
|
||||
arches=["cpu"],
|
||||
libtorch_variants=["shared-with-deps"],
|
||||
),
|
||||
@ -175,7 +188,9 @@ WINDOWS_BINARY_BUILD_WORKFLOWS = [
|
||||
BinaryBuildWorkflow(
|
||||
os=OperatingSystem.WINDOWS,
|
||||
package_type="wheel",
|
||||
build_configs=generate_binary_build_matrix.generate_wheels_matrix(OperatingSystem.WINDOWS),
|
||||
build_configs=generate_binary_build_matrix.generate_wheels_matrix(
|
||||
OperatingSystem.WINDOWS
|
||||
),
|
||||
ciflow_config=CIFlowConfig(
|
||||
labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL},
|
||||
isolated_workflow=True,
|
||||
@ -184,7 +199,9 @@ WINDOWS_BINARY_BUILD_WORKFLOWS = [
|
||||
BinaryBuildWorkflow(
|
||||
os=OperatingSystem.WINDOWS,
|
||||
package_type="conda",
|
||||
build_configs=generate_binary_build_matrix.generate_conda_matrix(OperatingSystem.WINDOWS),
|
||||
build_configs=generate_binary_build_matrix.generate_conda_matrix(
|
||||
OperatingSystem.WINDOWS
|
||||
),
|
||||
ciflow_config=CIFlowConfig(
|
||||
labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_CONDA},
|
||||
isolated_workflow=True,
|
||||
@ -221,7 +238,8 @@ WINDOWS_BINARY_SMOKE_WORKFLOWS = [
|
||||
package_type="libtorch",
|
||||
abi_version=generate_binary_build_matrix.RELEASE,
|
||||
build_configs=generate_binary_build_matrix.generate_libtorch_matrix(
|
||||
OperatingSystem.WINDOWS, generate_binary_build_matrix.RELEASE,
|
||||
OperatingSystem.WINDOWS,
|
||||
generate_binary_build_matrix.RELEASE,
|
||||
arches=["cpu"],
|
||||
libtorch_variants=["shared-with-deps"],
|
||||
),
|
||||
@ -232,7 +250,8 @@ WINDOWS_BINARY_SMOKE_WORKFLOWS = [
|
||||
package_type="libtorch",
|
||||
abi_version=generate_binary_build_matrix.DEBUG,
|
||||
build_configs=generate_binary_build_matrix.generate_libtorch_matrix(
|
||||
OperatingSystem.WINDOWS, generate_binary_build_matrix.DEBUG,
|
||||
OperatingSystem.WINDOWS,
|
||||
generate_binary_build_matrix.DEBUG,
|
||||
arches=["cpu"],
|
||||
libtorch_variants=["shared-with-deps"],
|
||||
),
|
||||
@ -244,7 +263,9 @@ MACOS_BINARY_BUILD_WORKFLOWS = [
|
||||
BinaryBuildWorkflow(
|
||||
os=OperatingSystem.MACOS,
|
||||
package_type="wheel",
|
||||
build_configs=generate_binary_build_matrix.generate_wheels_matrix(OperatingSystem.MACOS),
|
||||
build_configs=generate_binary_build_matrix.generate_wheels_matrix(
|
||||
OperatingSystem.MACOS
|
||||
),
|
||||
ciflow_config=CIFlowConfig(
|
||||
labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL},
|
||||
isolated_workflow=True,
|
||||
@ -253,7 +274,9 @@ MACOS_BINARY_BUILD_WORKFLOWS = [
|
||||
BinaryBuildWorkflow(
|
||||
os=OperatingSystem.MACOS,
|
||||
package_type="conda",
|
||||
build_configs=generate_binary_build_matrix.generate_conda_matrix(OperatingSystem.MACOS),
|
||||
build_configs=generate_binary_build_matrix.generate_conda_matrix(
|
||||
OperatingSystem.MACOS
|
||||
),
|
||||
ciflow_config=CIFlowConfig(
|
||||
labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_CONDA},
|
||||
isolated_workflow=True,
|
||||
@ -274,7 +297,9 @@ MACOS_BINARY_BUILD_WORKFLOWS = [
|
||||
BinaryBuildWorkflow(
|
||||
os=OperatingSystem.MACOS_ARM64,
|
||||
package_type="wheel",
|
||||
build_configs=generate_binary_build_matrix.generate_wheels_matrix(OperatingSystem.MACOS_ARM64),
|
||||
build_configs=generate_binary_build_matrix.generate_wheels_matrix(
|
||||
OperatingSystem.MACOS_ARM64
|
||||
),
|
||||
cross_compile_arm64=True,
|
||||
ciflow_config=CIFlowConfig(
|
||||
labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL},
|
||||
@ -285,7 +310,9 @@ MACOS_BINARY_BUILD_WORKFLOWS = [
|
||||
os=OperatingSystem.MACOS_ARM64,
|
||||
package_type="conda",
|
||||
cross_compile_arm64=True,
|
||||
build_configs=generate_binary_build_matrix.generate_conda_matrix(OperatingSystem.MACOS_ARM64),
|
||||
build_configs=generate_binary_build_matrix.generate_conda_matrix(
|
||||
OperatingSystem.MACOS_ARM64
|
||||
),
|
||||
ciflow_config=CIFlowConfig(
|
||||
labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_CONDA},
|
||||
isolated_workflow=True,
|
||||
@ -293,6 +320,7 @@ MACOS_BINARY_BUILD_WORKFLOWS = [
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def main() -> None:
|
||||
jinja_env = jinja2.Environment(
|
||||
variable_start_string="!{{",
|
||||
@ -302,11 +330,26 @@ def main() -> None:
|
||||
|
||||
# not ported yet
|
||||
template_and_workflows = [
|
||||
(jinja_env.get_template("linux_binary_build_workflow.yml.j2"), LINUX_BINARY_BUILD_WORFKLOWS),
|
||||
(jinja_env.get_template("linux_binary_build_workflow.yml.j2"), LINUX_BINARY_SMOKE_WORKFLOWS),
|
||||
(jinja_env.get_template("windows_binary_build_workflow.yml.j2"), WINDOWS_BINARY_BUILD_WORKFLOWS),
|
||||
(jinja_env.get_template("windows_binary_build_workflow.yml.j2"), WINDOWS_BINARY_SMOKE_WORKFLOWS),
|
||||
(jinja_env.get_template("macos_binary_build_workflow.yml.j2"), MACOS_BINARY_BUILD_WORKFLOWS),
|
||||
(
|
||||
jinja_env.get_template("linux_binary_build_workflow.yml.j2"),
|
||||
LINUX_BINARY_BUILD_WORFKLOWS,
|
||||
),
|
||||
(
|
||||
jinja_env.get_template("linux_binary_build_workflow.yml.j2"),
|
||||
LINUX_BINARY_SMOKE_WORKFLOWS,
|
||||
),
|
||||
(
|
||||
jinja_env.get_template("windows_binary_build_workflow.yml.j2"),
|
||||
WINDOWS_BINARY_BUILD_WORKFLOWS,
|
||||
),
|
||||
(
|
||||
jinja_env.get_template("windows_binary_build_workflow.yml.j2"),
|
||||
WINDOWS_BINARY_SMOKE_WORKFLOWS,
|
||||
),
|
||||
(
|
||||
jinja_env.get_template("macos_binary_build_workflow.yml.j2"),
|
||||
MACOS_BINARY_BUILD_WORKFLOWS,
|
||||
),
|
||||
]
|
||||
# Delete the existing generated files first, this should align with .gitattributes file description.
|
||||
existing_workflows = GITHUB_DIR.glob("workflows/generated-*")
|
||||
@ -323,5 +366,6 @@ def main() -> None:
|
||||
for workflow in workflows:
|
||||
workflow.generate_workflow_file(workflow_template=template)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
40
.github/scripts/generate_pytorch_version.py
vendored
40
.github/scripts/generate_pytorch_version.py
vendored
@ -2,8 +2,8 @@
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
import re
|
||||
import subprocess
|
||||
|
||||
from datetime import datetime
|
||||
from distutils.util import strtobool
|
||||
@ -13,21 +13,27 @@ LEADING_V_PATTERN = re.compile("^v")
|
||||
TRAILING_RC_PATTERN = re.compile("-rc[0-9]*$")
|
||||
LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile("a0$")
|
||||
|
||||
|
||||
class NoGitTagException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def get_pytorch_root() -> Path:
|
||||
return Path(subprocess.check_output(
|
||||
['git', 'rev-parse', '--show-toplevel']
|
||||
).decode('ascii').strip())
|
||||
return Path(
|
||||
subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
|
||||
.decode("ascii")
|
||||
.strip()
|
||||
)
|
||||
|
||||
|
||||
def get_tag() -> str:
|
||||
root = get_pytorch_root()
|
||||
try:
|
||||
dirty_tag = subprocess.check_output(
|
||||
['git', 'describe', '--tags', '--exact'],
|
||||
cwd=root
|
||||
).decode('ascii').strip()
|
||||
dirty_tag = (
|
||||
subprocess.check_output(["git", "describe", "--tags", "--exact"], cwd=root)
|
||||
.decode("ascii")
|
||||
.strip()
|
||||
)
|
||||
except subprocess.CalledProcessError:
|
||||
return ""
|
||||
# Strip leading v that we typically do when we tag branches
|
||||
@ -41,13 +47,15 @@ def get_tag() -> str:
|
||||
return ""
|
||||
return tag
|
||||
|
||||
|
||||
def get_base_version() -> str:
|
||||
root = get_pytorch_root()
|
||||
dirty_version = open(root / 'version.txt', 'r').read().strip()
|
||||
dirty_version = open(root / "version.txt", "r").read().strip()
|
||||
# Strips trailing a0 from version.txt, not too sure why it's there in the
|
||||
# first place
|
||||
return re.sub(LEGACY_BASE_VERSION_SUFFIX_PATTERN, "", dirty_version)
|
||||
|
||||
|
||||
class PytorchVersion:
|
||||
def __init__(
|
||||
self,
|
||||
@ -74,10 +82,11 @@ class PytorchVersion:
|
||||
return f"{get_tag()}{self.get_post_build_suffix()}"
|
||||
|
||||
def get_nightly_version(self) -> str:
|
||||
date_str = datetime.today().strftime('%Y%m%d')
|
||||
date_str = datetime.today().strftime("%Y%m%d")
|
||||
build_suffix = self.get_post_build_suffix()
|
||||
return f"{get_base_version()}.dev{date_str}{build_suffix}"
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate pytorch version for binary builds"
|
||||
@ -86,30 +95,29 @@ def main() -> None:
|
||||
"--no-build-suffix",
|
||||
action="store_true",
|
||||
help="Whether or not to add a build suffix typically (+cpu)",
|
||||
default=strtobool(os.environ.get("NO_BUILD_SUFFIX", "False"))
|
||||
default=strtobool(os.environ.get("NO_BUILD_SUFFIX", "False")),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu-arch-type",
|
||||
type=str,
|
||||
help="GPU arch you are building for, typically (cpu, cuda, rocm)",
|
||||
default=os.environ.get("GPU_ARCH_TYPE", "cpu")
|
||||
default=os.environ.get("GPU_ARCH_TYPE", "cpu"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu-arch-version",
|
||||
type=str,
|
||||
help="GPU arch version, typically (10.2, 4.0), leave blank for CPU",
|
||||
default=os.environ.get("GPU_ARCH_VERSION", "")
|
||||
default=os.environ.get("GPU_ARCH_VERSION", ""),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
version_obj = PytorchVersion(
|
||||
args.gpu_arch_type,
|
||||
args.gpu_arch_version,
|
||||
args.no_build_suffix
|
||||
args.gpu_arch_type, args.gpu_arch_version, args.no_build_suffix
|
||||
)
|
||||
try:
|
||||
print(version_obj.get_release_version())
|
||||
except NoGitTagException:
|
||||
print(version_obj.get_nightly_version())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
41
.github/scripts/get_workflow_job_id.py
vendored
41
.github/scripts/get_workflow_job_id.py
vendored
@ -11,9 +11,10 @@ import time
|
||||
import urllib
|
||||
import urllib.parse
|
||||
|
||||
from typing import Any, Callable, Dict, List, Tuple, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
|
||||
def parse_json_and_links(conn: Any) -> Tuple[Any, Dict[str, Dict[str, str]]]:
|
||||
links = {}
|
||||
# Extract links which GH uses for pagination
|
||||
@ -26,18 +27,26 @@ def parse_json_and_links(conn: Any) -> Tuple[Any, Dict[str, Dict[str, str]]]:
|
||||
continue
|
||||
url = urllib.parse.unquote(url.strip("<> "))
|
||||
qparams = urllib.parse.parse_qs(params_.strip(), separator=";")
|
||||
params = {k: v[0].strip('"') for k, v in qparams.items() if type(v) is list and len(v) > 0}
|
||||
params = {
|
||||
k: v[0].strip('"')
|
||||
for k, v in qparams.items()
|
||||
if type(v) is list and len(v) > 0
|
||||
}
|
||||
params["url"] = url
|
||||
if "rel" in params:
|
||||
links[params["rel"]] = params
|
||||
|
||||
return json.load(conn), links
|
||||
|
||||
def fetch_url(url: str, *,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
reader: Callable[[Any], Any] = lambda x: x.read(),
|
||||
retries: Optional[int] = 3,
|
||||
backoff_timeout: float = .5) -> Any:
|
||||
|
||||
def fetch_url(
|
||||
url: str,
|
||||
*,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
reader: Callable[[Any], Any] = lambda x: x.read(),
|
||||
retries: Optional[int] = 3,
|
||||
backoff_timeout: float = 0.5,
|
||||
) -> Any:
|
||||
if headers is None:
|
||||
headers = {}
|
||||
try:
|
||||
@ -46,14 +55,21 @@ def fetch_url(url: str, *,
|
||||
except urllib.error.HTTPError as err:
|
||||
if isinstance(retries, (int, float)) and retries > 0:
|
||||
time.sleep(backoff_timeout)
|
||||
return fetch_url(url, headers=headers, reader=reader, retries=retries - 1, backoff_timeout=backoff_timeout)
|
||||
return fetch_url(
|
||||
url,
|
||||
headers=headers,
|
||||
reader=reader,
|
||||
retries=retries - 1,
|
||||
backoff_timeout=backoff_timeout,
|
||||
)
|
||||
exception_message = (
|
||||
"Is github alright?",
|
||||
f"Recieved status code '{err.code}' when attempting to retrieve {url}:\n",
|
||||
f"{err.reason}\n\nheaders={err.headers}"
|
||||
f"{err.reason}\n\nheaders={err.headers}",
|
||||
)
|
||||
raise RuntimeError(exception_message) from err
|
||||
|
||||
|
||||
def parse_args() -> Any:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
@ -72,7 +88,9 @@ def fetch_jobs(url: str, headers: Dict[str, str]) -> List[Dict[str, str]]:
|
||||
jobs = response["jobs"]
|
||||
assert type(jobs) is list
|
||||
while "next" in links.keys():
|
||||
response, links = fetch_url(links["next"]["url"], headers=headers, reader=parse_json_and_links)
|
||||
response, links = fetch_url(
|
||||
links["next"]["url"], headers=headers, reader=parse_json_and_links
|
||||
)
|
||||
jobs.extend(response["jobs"])
|
||||
|
||||
return jobs
|
||||
@ -92,6 +110,7 @@ def fetch_jobs(url: str, headers: Dict[str, str]) -> List[Dict[str, str]]:
|
||||
# looking for RUNNER_NAME will uniquely identify the job we're currently
|
||||
# running.
|
||||
|
||||
|
||||
def find_job_id(args: Any) -> str:
|
||||
# From https://docs.github.com/en/actions/learn-github-actions/environment-variables
|
||||
PYTORCH_REPO = os.environ.get("GITHUB_REPOSITORY", "pytorch/pytorch")
|
||||
@ -115,6 +134,7 @@ def find_job_id(args: Any) -> str:
|
||||
|
||||
raise RuntimeError(f"Can't find job id for runner {args.runner_name}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
try:
|
||||
@ -123,5 +143,6 @@ def main() -> None:
|
||||
print(repr(e), file=sys.stderr)
|
||||
print(f"workflow-{args.workflow_run_id}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
71
.github/scripts/github_utils.py
vendored
71
.github/scripts/github_utils.py
vendored
@ -21,56 +21,69 @@ class GitHubComment:
|
||||
|
||||
|
||||
def gh_fetch_url(
|
||||
url: str, *,
|
||||
url: str,
|
||||
*,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
method: Optional[str] = None,
|
||||
reader: Callable[[Any], Any] = lambda x: x.read()
|
||||
reader: Callable[[Any], Any] = lambda x: x.read(),
|
||||
) -> Any:
|
||||
if headers is None:
|
||||
headers = {}
|
||||
token = os.environ.get("GITHUB_TOKEN")
|
||||
if token is not None and url.startswith('https://api.github.com/'):
|
||||
headers['Authorization'] = f'token {token}'
|
||||
if token is not None and url.startswith("https://api.github.com/"):
|
||||
headers["Authorization"] = f"token {token}"
|
||||
data_ = json.dumps(data).encode() if data is not None else None
|
||||
try:
|
||||
with urlopen(Request(url, headers=headers, data=data_, method=method)) as conn:
|
||||
return reader(conn)
|
||||
except HTTPError as err:
|
||||
if err.code == 403 and all(key in err.headers for key in ['X-RateLimit-Limit', 'X-RateLimit-Used']):
|
||||
print(f"""Rate limit exceeded:
|
||||
if err.code == 403 and all(
|
||||
key in err.headers for key in ["X-RateLimit-Limit", "X-RateLimit-Used"]
|
||||
):
|
||||
print(
|
||||
f"""Rate limit exceeded:
|
||||
Used: {err.headers['X-RateLimit-Used']}
|
||||
Limit: {err.headers['X-RateLimit-Limit']}
|
||||
Remaining: {err.headers['X-RateLimit-Remaining']}
|
||||
Resets at: {err.headers['x-RateLimit-Reset']}""")
|
||||
Resets at: {err.headers['x-RateLimit-Reset']}"""
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def gh_fetch_json(
|
||||
url: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
headers = {'Accept': 'application/vnd.github.v3+json'}
|
||||
headers = {"Accept": "application/vnd.github.v3+json"}
|
||||
if params is not None and len(params) > 0:
|
||||
url += '?' + '&'.join(f"{name}={quote(str(val))}" for name, val in params.items())
|
||||
return cast(List[Dict[str, Any]], gh_fetch_url(url, headers=headers, data=data, reader=json.load))
|
||||
url += "?" + "&".join(
|
||||
f"{name}={quote(str(val))}" for name, val in params.items()
|
||||
)
|
||||
return cast(
|
||||
List[Dict[str, Any]],
|
||||
gh_fetch_url(url, headers=headers, data=data, reader=json.load),
|
||||
)
|
||||
|
||||
|
||||
def _gh_fetch_json_any(
|
||||
url: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> Any:
|
||||
headers = {'Accept': 'application/vnd.github.v3+json'}
|
||||
headers = {"Accept": "application/vnd.github.v3+json"}
|
||||
if params is not None and len(params) > 0:
|
||||
url += '?' + '&'.join(f"{name}={quote(str(val))}" for name, val in params.items())
|
||||
url += "?" + "&".join(
|
||||
f"{name}={quote(str(val))}" for name, val in params.items()
|
||||
)
|
||||
return gh_fetch_url(url, headers=headers, data=data, reader=json.load)
|
||||
|
||||
|
||||
def gh_fetch_json_list(
|
||||
url: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
return cast(List[Dict[str, Any]], _gh_fetch_json_any(url, params, data))
|
||||
|
||||
@ -78,24 +91,38 @@ def gh_fetch_json_list(
|
||||
def gh_fetch_json_dict(
|
||||
url: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any] :
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
return cast(Dict[str, Any], _gh_fetch_json_any(url, params, data))
|
||||
|
||||
|
||||
def _gh_post_comment(url: str, comment: str, dry_run: bool = False) -> List[Dict[str, Any]]:
|
||||
def _gh_post_comment(
|
||||
url: str, comment: str, dry_run: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
if dry_run:
|
||||
print(comment)
|
||||
return []
|
||||
return gh_fetch_json_list(url, data={"body": comment})
|
||||
|
||||
|
||||
def gh_post_pr_comment(org: str, repo: str, pr_num: int, comment: str, dry_run: bool = False) -> List[Dict[str, Any]]:
|
||||
return _gh_post_comment(f'https://api.github.com/repos/{org}/{repo}/issues/{pr_num}/comments', comment, dry_run)
|
||||
def gh_post_pr_comment(
|
||||
org: str, repo: str, pr_num: int, comment: str, dry_run: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
return _gh_post_comment(
|
||||
f"https://api.github.com/repos/{org}/{repo}/issues/{pr_num}/comments",
|
||||
comment,
|
||||
dry_run,
|
||||
)
|
||||
|
||||
|
||||
def gh_post_commit_comment(org: str, repo: str, sha: str, comment: str, dry_run: bool = False) -> List[Dict[str, Any]]:
|
||||
return _gh_post_comment(f'https://api.github.com/repos/{org}/{repo}/commits/{sha}/comments', comment, dry_run)
|
||||
def gh_post_commit_comment(
|
||||
org: str, repo: str, sha: str, comment: str, dry_run: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
return _gh_post_comment(
|
||||
f"https://api.github.com/repos/{org}/{repo}/commits/{sha}/comments",
|
||||
comment,
|
||||
dry_run,
|
||||
)
|
||||
|
||||
|
||||
def gh_delete_comment(org: str, repo: str, comment_id: int) -> None:
|
||||
|
117
.github/scripts/gitutils.py
vendored
117
.github/scripts/gitutils.py
vendored
@ -5,7 +5,7 @@ import re
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import cast, Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
from typing import Any, cast, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
|
||||
RE_GITHUB_URL_MATCH = re.compile("^https://.*@?github.com/(.+)/(.+)$")
|
||||
@ -17,6 +17,7 @@ def get_git_remote_name() -> str:
|
||||
|
||||
def get_git_repo_dir() -> str:
|
||||
from pathlib import Path
|
||||
|
||||
return os.getenv("GIT_REPO_DIR", str(Path(__file__).resolve().parent.parent.parent))
|
||||
|
||||
|
||||
@ -25,13 +26,14 @@ def fuzzy_list_to_dict(items: List[Tuple[str, str]]) -> Dict[str, List[str]]:
|
||||
Converts list to dict preserving elements with duplicate keys
|
||||
"""
|
||||
rc: Dict[str, List[str]] = defaultdict(lambda: [])
|
||||
for (key, val) in items:
|
||||
for key, val in items:
|
||||
rc[key].append(val)
|
||||
return dict(rc)
|
||||
|
||||
|
||||
def _check_output(items: List[str], encoding: str = "utf-8") -> str:
|
||||
from subprocess import check_output, CalledProcessError, STDOUT
|
||||
from subprocess import CalledProcessError, check_output, STDOUT
|
||||
|
||||
try:
|
||||
return check_output(items, stderr=STDOUT).decode(encoding)
|
||||
except CalledProcessError as e:
|
||||
@ -53,13 +55,15 @@ class GitCommit:
|
||||
author_date: datetime
|
||||
commit_date: Optional[datetime]
|
||||
|
||||
def __init__(self,
|
||||
commit_hash: str,
|
||||
author: str,
|
||||
author_date: datetime,
|
||||
title: str,
|
||||
body: str,
|
||||
commit_date: Optional[datetime] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
commit_hash: str,
|
||||
author: str,
|
||||
author_date: datetime,
|
||||
title: str,
|
||||
body: str,
|
||||
commit_date: Optional[datetime] = None,
|
||||
) -> None:
|
||||
self.commit_hash = commit_hash
|
||||
self.author = author
|
||||
self.author_date = author_date
|
||||
@ -100,13 +104,14 @@ def parse_fuller_format(lines: Union[str, List[str]]) -> GitCommit:
|
||||
assert lines[3].startswith("Commit: ")
|
||||
assert lines[4].startswith("CommitDate: ")
|
||||
assert len(lines[5]) == 0
|
||||
return GitCommit(commit_hash=lines[0].split()[1].strip(),
|
||||
author=lines[1].split(":", 1)[1].strip(),
|
||||
author_date=datetime.fromtimestamp(int(lines[2].split(":", 1)[1].strip())),
|
||||
commit_date=datetime.fromtimestamp(int(lines[4].split(":", 1)[1].strip())),
|
||||
title=lines[6].strip(),
|
||||
body="\n".join(lines[7:]),
|
||||
)
|
||||
return GitCommit(
|
||||
commit_hash=lines[0].split()[1].strip(),
|
||||
author=lines[1].split(":", 1)[1].strip(),
|
||||
author_date=datetime.fromtimestamp(int(lines[2].split(":", 1)[1].strip())),
|
||||
commit_date=datetime.fromtimestamp(int(lines[4].split(":", 1)[1].strip())),
|
||||
title=lines[6].strip(),
|
||||
body="\n".join(lines[7:]),
|
||||
)
|
||||
|
||||
|
||||
class GitRepo:
|
||||
@ -139,16 +144,16 @@ class GitRepo:
|
||||
self._run_git("fetch", self.remote, f"{ref}:{branch}")
|
||||
|
||||
def show_ref(self, name: str) -> str:
|
||||
refs = self._run_git('show-ref', '-s', name).strip().split('\n')
|
||||
refs = self._run_git("show-ref", "-s", name).strip().split("\n")
|
||||
if not all(refs[i] == refs[0] for i in range(1, len(refs))):
|
||||
raise RuntimeError(f"referce {name} is ambigous")
|
||||
return refs[0]
|
||||
|
||||
def rev_parse(self, name: str) -> str:
|
||||
return self._run_git('rev-parse', '--verify', name).strip()
|
||||
return self._run_git("rev-parse", "--verify", name).strip()
|
||||
|
||||
def get_merge_base(self, from_ref: str, to_ref: str) -> str:
|
||||
return self._run_git('merge-base', from_ref, to_ref).strip()
|
||||
return self._run_git("merge-base", from_ref, to_ref).strip()
|
||||
|
||||
def patch_id(self, ref: Union[str, List[str]]) -> List[Tuple[str, str]]:
|
||||
is_list = isinstance(ref, list)
|
||||
@ -156,25 +161,31 @@ class GitRepo:
|
||||
if len(ref) == 0:
|
||||
return []
|
||||
ref = " ".join(ref)
|
||||
rc = _check_output(['sh', '-c', f'git -C {self.repo_dir} show {ref}|git patch-id --stable']).strip()
|
||||
rc = _check_output(
|
||||
["sh", "-c", f"git -C {self.repo_dir} show {ref}|git patch-id --stable"]
|
||||
).strip()
|
||||
return [cast(Tuple[str, str], x.split(" ", 1)) for x in rc.split("\n")]
|
||||
|
||||
def commits_resolving_gh_pr(self, pr_num: int) -> List[str]:
|
||||
owner, name = self.gh_owner_and_name()
|
||||
msg = f"Pull Request resolved: https://github.com/{owner}/{name}/pull/{pr_num}"
|
||||
rc = self._run_git('log', '--format=%H', '--grep', msg).strip()
|
||||
rc = self._run_git("log", "--format=%H", "--grep", msg).strip()
|
||||
return rc.split("\n") if len(rc) > 0 else []
|
||||
|
||||
def get_commit(self, ref: str) -> GitCommit:
|
||||
return parse_fuller_format(self._run_git('show', '--format=fuller', '--date=unix', '--shortstat', ref))
|
||||
return parse_fuller_format(
|
||||
self._run_git("show", "--format=fuller", "--date=unix", "--shortstat", ref)
|
||||
)
|
||||
|
||||
def cherry_pick(self, ref: str) -> None:
|
||||
self._run_git('cherry-pick', '-x', ref)
|
||||
self._run_git("cherry-pick", "-x", ref)
|
||||
|
||||
def revert(self, ref: str) -> None:
|
||||
self._run_git("revert", "--no-edit", ref)
|
||||
|
||||
def compute_branch_diffs(self, from_branch: str, to_branch: str) -> Tuple[List[str], List[str]]:
|
||||
def compute_branch_diffs(
|
||||
self, from_branch: str, to_branch: str
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
Returns list of commmits that are missing in each other branch since their merge base
|
||||
Might be slow if merge base is between two branches is pretty far off
|
||||
@ -182,8 +193,8 @@ class GitRepo:
|
||||
from_ref = self.rev_parse(from_branch)
|
||||
to_ref = self.rev_parse(to_branch)
|
||||
merge_base = self.get_merge_base(from_ref, to_ref)
|
||||
from_commits = self.revlist(f'{merge_base}..{from_ref}')
|
||||
to_commits = self.revlist(f'{merge_base}..{to_ref}')
|
||||
from_commits = self.revlist(f"{merge_base}..{from_ref}")
|
||||
to_commits = self.revlist(f"{merge_base}..{to_ref}")
|
||||
from_ids = fuzzy_list_to_dict(self.patch_id(from_commits))
|
||||
to_ids = fuzzy_list_to_dict(self.patch_id(to_commits))
|
||||
for patch_id in set(from_ids).intersection(set(to_ids)):
|
||||
@ -199,14 +210,19 @@ class GitRepo:
|
||||
# HACK: Same commit were merged, reverted and landed again
|
||||
# which creates a tracking problem
|
||||
if (
|
||||
"pytorch/pytorch" not in self.remote_url() or
|
||||
frc.commit_hash not in {"0a6a1b27a464ba5be5f587cce2ee12ab8c504dbf",
|
||||
"6d0f4a1d545a8f161df459e8d4ccafd4b9017dbe",
|
||||
"edf909e58f06150f7be41da2f98a3b9de3167bca",
|
||||
"a58c6aea5a0c9f8759a4154e46f544c8b03b8db1",
|
||||
"7106d216c29ca16a3504aa2bedad948ebcf4abc2"}
|
||||
"pytorch/pytorch" not in self.remote_url()
|
||||
or frc.commit_hash
|
||||
not in {
|
||||
"0a6a1b27a464ba5be5f587cce2ee12ab8c504dbf",
|
||||
"6d0f4a1d545a8f161df459e8d4ccafd4b9017dbe",
|
||||
"edf909e58f06150f7be41da2f98a3b9de3167bca",
|
||||
"a58c6aea5a0c9f8759a4154e46f544c8b03b8db1",
|
||||
"7106d216c29ca16a3504aa2bedad948ebcf4abc2",
|
||||
}
|
||||
):
|
||||
raise RuntimeError(f"Unexpected differences between {frc} and {toc}")
|
||||
raise RuntimeError(
|
||||
f"Unexpected differences between {frc} and {toc}"
|
||||
)
|
||||
from_commits.remove(frc.commit_hash)
|
||||
to_commits.remove(toc.commit_hash)
|
||||
continue
|
||||
@ -217,11 +233,13 @@ class GitRepo:
|
||||
# Another HACK: Patch-id is not stable for commits with binary files or for big changes across commits
|
||||
# I.e. cherry-picking those from one branch into another will change patchid
|
||||
if "pytorch/pytorch" in self.remote_url():
|
||||
for excluded_commit in {"8e09e20c1dafcdbdb45c2d1574da68a32e54a3a5",
|
||||
"5f37e5c2a39c3acb776756a17730b865f0953432",
|
||||
"b5222584e6d6990c6585981a936defd1af14c0ba",
|
||||
"84d9a2e42d5ed30ec3b8b4140c38dd83abbce88d",
|
||||
"f211ec90a6cdc8a2a5795478b5b5c8d7d7896f7e"}:
|
||||
for excluded_commit in {
|
||||
"8e09e20c1dafcdbdb45c2d1574da68a32e54a3a5",
|
||||
"5f37e5c2a39c3acb776756a17730b865f0953432",
|
||||
"b5222584e6d6990c6585981a936defd1af14c0ba",
|
||||
"84d9a2e42d5ed30ec3b8b4140c38dd83abbce88d",
|
||||
"f211ec90a6cdc8a2a5795478b5b5c8d7d7896f7e",
|
||||
}:
|
||||
if excluded_commit in from_commits:
|
||||
from_commits.remove(excluded_commit)
|
||||
|
||||
@ -281,7 +299,14 @@ class GitRepo:
|
||||
|
||||
def clone_repo(username: str, password: str, org: str, project: str) -> GitRepo:
|
||||
path = tempfile.mkdtemp()
|
||||
_check_output(['git', 'clone', f'https://{username}:{password}@github.com/{org}/{project}', path]).strip()
|
||||
_check_output(
|
||||
[
|
||||
"git",
|
||||
"clone",
|
||||
f"https://{username}:{password}@github.com/{org}/{project}",
|
||||
path,
|
||||
]
|
||||
).strip()
|
||||
return GitRepo(path=path)
|
||||
|
||||
|
||||
@ -337,17 +362,21 @@ def patterns_to_regex(allowed_patterns: List[str]) -> Any:
|
||||
rc += ")"
|
||||
return re.compile(rc)
|
||||
|
||||
|
||||
def _shasum(value: str) -> str:
|
||||
import hashlib
|
||||
|
||||
m = hashlib.sha256()
|
||||
m.update(value.encode("utf-8"))
|
||||
return m.hexdigest()
|
||||
|
||||
|
||||
def are_ghstack_branches_in_sync(repo: GitRepo, head_ref: str) -> bool:
|
||||
""" Checks that diff between base and head is the same as diff between orig and its parent """
|
||||
orig_ref = re.sub(r'/head$', '/orig', head_ref)
|
||||
base_ref = re.sub(r'/head$', '/base', head_ref)
|
||||
"""Checks that diff between base and head is the same as diff between orig and its parent"""
|
||||
orig_ref = re.sub(r"/head$", "/orig", head_ref)
|
||||
base_ref = re.sub(r"/head$", "/base", head_ref)
|
||||
orig_diff_sha = _shasum(repo.diff(f"{repo.remote}/{orig_ref}"))
|
||||
head_diff_sha = _shasum(repo.diff(f"{repo.remote}/{base_ref}", f"{repo.remote}/{head_ref}"))
|
||||
head_diff_sha = _shasum(
|
||||
repo.diff(f"{repo.remote}/{base_ref}", f"{repo.remote}/{head_ref}")
|
||||
)
|
||||
return orig_diff_sha == head_diff_sha
|
||||
|
51
.github/scripts/label_utils.py
vendored
51
.github/scripts/label_utils.py
vendored
@ -3,13 +3,10 @@
|
||||
import json
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import List, Any, Tuple, TYPE_CHECKING, Union
|
||||
from urllib.request import urlopen, Request
|
||||
from typing import Any, List, Tuple, TYPE_CHECKING, Union
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
from github_utils import (
|
||||
GitHubComment,
|
||||
gh_fetch_json,
|
||||
)
|
||||
from github_utils import gh_fetch_json, GitHubComment
|
||||
|
||||
# TODO: this is a temp workaround to avoid circular dependencies,
|
||||
# and should be removed once GitHubPR is refactored out of trymerge script.
|
||||
@ -31,14 +28,15 @@ For more information, see
|
||||
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.
|
||||
"""
|
||||
|
||||
|
||||
# Modified from https://github.com/pytorch/pytorch/blob/b00206d4737d1f1e7a442c9f8a1cadccd272a386/torch/hub.py#L129
|
||||
def _read_url(url: Request) -> Tuple[Any, Any]:
|
||||
with urlopen(url) as r:
|
||||
return r.headers, r.read().decode(r.headers.get_content_charset('utf-8'))
|
||||
return r.headers, r.read().decode(r.headers.get_content_charset("utf-8"))
|
||||
|
||||
|
||||
def request_for_labels(url: str) -> Tuple[Any, Any]:
|
||||
headers = {'Accept': 'application/vnd.github.v3+json'}
|
||||
headers = {"Accept": "application/vnd.github.v3+json"}
|
||||
return _read_url(Request(url, headers=headers))
|
||||
|
||||
|
||||
@ -50,10 +48,12 @@ def update_labels(labels: List[str], info: str) -> None:
|
||||
def get_last_page_num_from_header(header: Any) -> int:
|
||||
# Link info looks like: <https://api.github.com/repositories/65600975/labels?per_page=100&page=2>;
|
||||
# rel="next", <https://api.github.com/repositories/65600975/labels?per_page=100&page=3>; rel="last"
|
||||
link_info = header['link']
|
||||
link_info = header["link"]
|
||||
prefix = "&page="
|
||||
suffix = ">;"
|
||||
return int(link_info[link_info.rindex(prefix) + len(prefix):link_info.rindex(suffix)])
|
||||
return int(
|
||||
link_info[link_info.rindex(prefix) + len(prefix) : link_info.rindex(suffix)]
|
||||
)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
@ -64,7 +64,9 @@ def gh_get_labels(org: str, repo: str) -> List[str]:
|
||||
update_labels(labels, info)
|
||||
|
||||
last_page = get_last_page_num_from_header(header)
|
||||
assert last_page > 0, "Error reading header info to determine total number of pages of labels"
|
||||
assert (
|
||||
last_page > 0
|
||||
), "Error reading header info to determine total number of pages of labels"
|
||||
for page_number in range(2, last_page + 1): # skip page 1
|
||||
_, info = request_for_labels(prefix + f"&page={page_number}")
|
||||
update_labels(labels, info)
|
||||
@ -72,26 +74,37 @@ def gh_get_labels(org: str, repo: str) -> List[str]:
|
||||
return labels
|
||||
|
||||
|
||||
def gh_add_labels(org: str, repo: str, pr_num: int, labels: Union[str, List[str]]) -> None:
|
||||
def gh_add_labels(
|
||||
org: str, repo: str, pr_num: int, labels: Union[str, List[str]]
|
||||
) -> None:
|
||||
gh_fetch_json(
|
||||
f'https://api.github.com/repos/{org}/{repo}/issues/{pr_num}/labels',
|
||||
f"https://api.github.com/repos/{org}/{repo}/issues/{pr_num}/labels",
|
||||
data={"labels": labels},
|
||||
)
|
||||
|
||||
|
||||
def get_release_notes_labels(org: str, repo: str) -> List[str]:
|
||||
return [label for label in gh_get_labels(org, repo) if label.lstrip().startswith("release notes:")]
|
||||
return [
|
||||
label
|
||||
for label in gh_get_labels(org, repo)
|
||||
if label.lstrip().startswith("release notes:")
|
||||
]
|
||||
|
||||
|
||||
def has_required_labels(pr: "GitHubPR") -> bool:
|
||||
pr_labels = pr.get_labels()
|
||||
# Check if PR is not user facing
|
||||
is_not_user_facing_pr = any(label.strip() == "topic: not user facing" for label in pr_labels)
|
||||
return (
|
||||
is_not_user_facing_pr or
|
||||
any(label.strip() in get_release_notes_labels(pr.org, pr.project) for label in pr_labels)
|
||||
is_not_user_facing_pr = any(
|
||||
label.strip() == "topic: not user facing" for label in pr_labels
|
||||
)
|
||||
return is_not_user_facing_pr or any(
|
||||
label.strip() in get_release_notes_labels(pr.org, pr.project)
|
||||
for label in pr_labels
|
||||
)
|
||||
|
||||
|
||||
def is_label_err_comment(comment: GitHubComment) -> bool:
|
||||
return comment.body_text.lstrip(" #").startswith(LABEL_ERR_MSG_TITLE) and comment.author_login in BOT_AUTHORS
|
||||
return (
|
||||
comment.body_text.lstrip(" #").startswith(LABEL_ERR_MSG_TITLE)
|
||||
and comment.author_login in BOT_AUTHORS
|
||||
)
|
||||
|
28
.github/scripts/lint_native_functions.py
vendored
28
.github/scripts/lint_native_functions.py
vendored
@ -1,5 +1,5 @@
|
||||
#!/usr/bin/env python3
|
||||
'''
|
||||
"""
|
||||
Verify that it is possible to round-trip native_functions.yaml via ruamel under some
|
||||
configuration. Keeping native_functions.yaml consistent in this way allows us to
|
||||
run codemods on the file using ruamel without introducing line noise. Note that we don't
|
||||
@ -12,24 +12,27 @@ you may find that you want to use some format that is not what ruamel prefers.
|
||||
it is OK to modify this script (instead of reformatting native_functions.yaml)--the point
|
||||
is simply to make sure that there is *some* configuration of ruamel that can round trip
|
||||
the YAML, not to be prescriptive about it.
|
||||
'''
|
||||
"""
|
||||
|
||||
import ruamel.yaml # type: ignore[import]
|
||||
import difflib
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
|
||||
import ruamel.yaml # type: ignore[import]
|
||||
|
||||
|
||||
def fn(base: str) -> str:
|
||||
return str(base / Path("aten/src/ATen/native/native_functions.yaml"))
|
||||
|
||||
with open(Path(__file__).parent.parent.parent / fn('.'), "r") as f:
|
||||
|
||||
with open(Path(__file__).parent.parent.parent / fn("."), "r") as f:
|
||||
contents = f.read()
|
||||
|
||||
yaml = ruamel.yaml.YAML() # type: ignore[attr-defined]
|
||||
yaml.preserve_quotes = True # type: ignore[assignment]
|
||||
yaml.width = 1000 # type: ignore[assignment]
|
||||
yaml.boolean_representation = ['False', 'True'] # type: ignore[attr-defined]
|
||||
yaml.boolean_representation = ["False", "True"] # type: ignore[attr-defined]
|
||||
r = yaml.load(contents)
|
||||
|
||||
# Cuz ruamel's author intentionally didn't include conversion to string
|
||||
@ -40,12 +43,19 @@ new_contents = string_stream.getvalue()
|
||||
string_stream.close()
|
||||
|
||||
if contents != new_contents:
|
||||
print("""\
|
||||
print(
|
||||
"""\
|
||||
|
||||
## LINT FAILURE: native_functions.yaml ##
|
||||
|
||||
native_functions.yaml failed lint; please apply the diff below to fix lint.
|
||||
If you think this is in error, please see .github/scripts/lint_native_functions.py
|
||||
""", file=sys.stderr)
|
||||
sys.stdout.writelines(difflib.unified_diff(contents.splitlines(True), new_contents.splitlines(True), fn('a'), fn('b')))
|
||||
""",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.stdout.writelines(
|
||||
difflib.unified_diff(
|
||||
contents.splitlines(True), new_contents.splitlines(True), fn("a"), fn("b")
|
||||
)
|
||||
)
|
||||
sys.exit(1)
|
||||
|
2
.github/scripts/parse_ref.py
vendored
2
.github/scripts/parse_ref.py
vendored
@ -14,7 +14,7 @@ def set_output(name: str, val: str) -> None:
|
||||
|
||||
def main() -> None:
|
||||
ref = os.environ["GITHUB_REF"]
|
||||
m = re.match(r'^refs/(\w+)/(.*)$', ref)
|
||||
m = re.match(r"^refs/(\w+)/(.*)$", ref)
|
||||
if m:
|
||||
category, stripped = m.groups()
|
||||
if category == "heads":
|
||||
|
206
.github/scripts/run_torchbench.py
vendored
206
.github/scripts/run_torchbench.py
vendored
@ -9,19 +9,21 @@ Testing environment:
|
||||
- Python 3.8
|
||||
- CUDA 11.3
|
||||
"""
|
||||
import argparse
|
||||
|
||||
# Known issues:
|
||||
# 1. Does not reuse the build artifact in other CI workflows
|
||||
# 2. CI jobs are serialized because there is only one worker
|
||||
import os
|
||||
import boto3 # type: ignore[import]
|
||||
import git # type: ignore[import]
|
||||
import pathlib
|
||||
import argparse
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import boto3 # type: ignore[import]
|
||||
import git # type: ignore[import]
|
||||
|
||||
TORCHBENCH_CONFIG_NAME = "config.yaml"
|
||||
TORCHBENCH_USERBENCHMARK_CONFIG_NAME = "ub-config.yaml"
|
||||
MAGIC_PREFIX = "RUN_TORCHBENCH:"
|
||||
@ -37,15 +39,18 @@ S3_BUCKET = "ossci-metrics"
|
||||
S3_PREFIX = "torchbench-pr-test"
|
||||
S3_URL_BASE = f"https://{S3_BUCKET}.s3.amazonaws.com/"
|
||||
|
||||
|
||||
class S3Client:
|
||||
def __init__(self, bucket: str = S3_BUCKET, prefix: str = S3_PREFIX):
|
||||
self.s3 = boto3.client('s3')
|
||||
self.resource = boto3.resource('s3')
|
||||
self.s3 = boto3.client("s3")
|
||||
self.resource = boto3.resource("s3")
|
||||
self.bucket = bucket
|
||||
self.prefix = prefix
|
||||
|
||||
def upload_file(self, file_path: Path, filekey_prefix: str) -> None:
|
||||
assert file_path.is_file(), f"Specified file path {file_path} does not exist or not file."
|
||||
assert (
|
||||
file_path.is_file()
|
||||
), f"Specified file path {file_path} does not exist or not file."
|
||||
file_name = file_path.name
|
||||
s3_key = f"{self.prefix}/{filekey_prefix}/{file_name}"
|
||||
print(f"Uploading file {file_name} to S3 with key: {s3_key}")
|
||||
@ -53,6 +58,7 @@ class S3Client:
|
||||
# output the result URL
|
||||
print(f"Uploaded the result file {file_name} to {S3_URL_BASE}{s3_key}")
|
||||
|
||||
|
||||
def gen_abtest_config(control: str, treatment: str, models: List[str]) -> str:
|
||||
d = {}
|
||||
d["control"] = control
|
||||
@ -65,18 +71,23 @@ def gen_abtest_config(control: str, treatment: str, models: List[str]) -> str:
|
||||
config = config + "\n"
|
||||
return config
|
||||
|
||||
|
||||
def setup_gha_env(name: str, val: str) -> None:
|
||||
fname = os.environ["GITHUB_ENV"]
|
||||
content = f"{name}={val}\n"
|
||||
with open(fname, "a") as fo:
|
||||
fo.write(content)
|
||||
|
||||
|
||||
def find_current_branch(repo_path: str) -> str:
|
||||
repo = git.Repo(repo_path)
|
||||
name: str = repo.active_branch.name
|
||||
return name
|
||||
|
||||
def deploy_torchbench_config(output_dir: str, config: str, config_name: str = TORCHBENCH_CONFIG_NAME) -> None:
|
||||
|
||||
def deploy_torchbench_config(
|
||||
output_dir: str, config: str, config_name: str = TORCHBENCH_CONFIG_NAME
|
||||
) -> None:
|
||||
# Create test dir if needed
|
||||
pathlib.Path(output_dir).mkdir(exist_ok=True)
|
||||
# TorchBench config file name
|
||||
@ -84,20 +95,37 @@ def deploy_torchbench_config(output_dir: str, config: str, config_name: str = TO
|
||||
with open(config_path, "w") as fp:
|
||||
fp.write(config)
|
||||
|
||||
|
||||
def get_valid_models(torchbench_path: str) -> List[str]:
|
||||
benchmark_path = os.path.join(torchbench_path, "torchbenchmark", "models")
|
||||
valid_models = [model for model in os.listdir(benchmark_path) if os.path.isdir(os.path.join(benchmark_path, model))]
|
||||
valid_models = [
|
||||
model
|
||||
for model in os.listdir(benchmark_path)
|
||||
if os.path.isdir(os.path.join(benchmark_path, model))
|
||||
]
|
||||
return valid_models
|
||||
|
||||
|
||||
def get_valid_userbenchmarks(torchbench_path: str) -> List[str]:
|
||||
def is_valid_ub_dir(ub_path: str) -> bool:
|
||||
return os.path.isdir(ub_path) and os.path.exists(os.path.join(ub_path, "__init__.py"))
|
||||
return os.path.isdir(ub_path) and os.path.exists(
|
||||
os.path.join(ub_path, "__init__.py")
|
||||
)
|
||||
|
||||
ub_path = os.path.join(os.path.abspath(torchbench_path), "userbenchmark")
|
||||
ubs = list(filter(is_valid_ub_dir, [os.path.join(ub_path, ubdir) for ubdir in os.listdir(ub_path)]))
|
||||
ubs = list(
|
||||
filter(
|
||||
is_valid_ub_dir,
|
||||
[os.path.join(ub_path, ubdir) for ubdir in os.listdir(ub_path)],
|
||||
)
|
||||
)
|
||||
valid_ubs = list(map(lambda x: os.path.basename(x), ubs))
|
||||
return valid_ubs
|
||||
|
||||
def extract_models_from_pr(torchbench_path: str, prbody_file: str) -> Tuple[List[str], List[str]]:
|
||||
|
||||
def extract_models_from_pr(
|
||||
torchbench_path: str, prbody_file: str
|
||||
) -> Tuple[List[str], List[str]]:
|
||||
model_list = []
|
||||
userbenchmark_list = []
|
||||
pr_list = []
|
||||
@ -106,7 +134,9 @@ def extract_models_from_pr(torchbench_path: str, prbody_file: str) -> Tuple[List
|
||||
magic_lines = list(filter(lambda x: x.startswith(MAGIC_PREFIX), lines))
|
||||
if magic_lines:
|
||||
# Only the first magic line will be recognized.
|
||||
pr_list = list(map(lambda x: x.strip(), magic_lines[0][len(MAGIC_PREFIX):].split(",")))
|
||||
pr_list = list(
|
||||
map(lambda x: x.strip(), magic_lines[0][len(MAGIC_PREFIX) :].split(","))
|
||||
)
|
||||
valid_models = get_valid_models(torchbench_path)
|
||||
valid_ubs = get_valid_userbenchmarks(torchbench_path)
|
||||
for pr_bm in pr_list:
|
||||
@ -115,53 +145,87 @@ def extract_models_from_pr(torchbench_path: str, prbody_file: str) -> Tuple[List
|
||||
elif pr_bm in valid_ubs:
|
||||
userbenchmark_list.append(pr_bm)
|
||||
else:
|
||||
print(f"The model or benchmark {pr_bm} you specified does not exist in TorchBench suite. Please double check.")
|
||||
print(
|
||||
f"The model or benchmark {pr_bm} you specified does not exist in TorchBench suite. Please double check."
|
||||
)
|
||||
exit(-1)
|
||||
# Shortcut: if pr_list is ["ALL"], run all the model tests
|
||||
if "ALL" in model_list:
|
||||
model_list = ["ALL"]
|
||||
return model_list, userbenchmark_list
|
||||
|
||||
|
||||
def find_torchbench_branch(prbody_file: str) -> str:
|
||||
branch_name: str = ""
|
||||
with open(prbody_file, "r") as pf:
|
||||
lines = map(lambda x: x.strip(), pf.read().splitlines())
|
||||
magic_lines = list(filter(lambda x: x.startswith(MAGIC_TORCHBENCH_PREFIX), lines))
|
||||
magic_lines = list(
|
||||
filter(lambda x: x.startswith(MAGIC_TORCHBENCH_PREFIX), lines)
|
||||
)
|
||||
if magic_lines:
|
||||
# Only the first magic line will be recognized.
|
||||
branch_name = magic_lines[0][len(MAGIC_TORCHBENCH_PREFIX):].strip()
|
||||
branch_name = magic_lines[0][len(MAGIC_TORCHBENCH_PREFIX) :].strip()
|
||||
# If not specified, use main as the default branch
|
||||
if not branch_name:
|
||||
branch_name = "main"
|
||||
return branch_name
|
||||
|
||||
|
||||
def run_torchbench(pytorch_path: str, torchbench_path: str, output_dir: str) -> None:
|
||||
# Copy system environment so that we will not override
|
||||
env = dict(os.environ)
|
||||
command = ["python", "bisection.py", "--work-dir", output_dir,
|
||||
"--pytorch-src", pytorch_path, "--torchbench-src", torchbench_path,
|
||||
"--config", os.path.join(output_dir, TORCHBENCH_CONFIG_NAME),
|
||||
"--output", os.path.join(output_dir, "result.txt")]
|
||||
command = [
|
||||
"python",
|
||||
"bisection.py",
|
||||
"--work-dir",
|
||||
output_dir,
|
||||
"--pytorch-src",
|
||||
pytorch_path,
|
||||
"--torchbench-src",
|
||||
torchbench_path,
|
||||
"--config",
|
||||
os.path.join(output_dir, TORCHBENCH_CONFIG_NAME),
|
||||
"--output",
|
||||
os.path.join(output_dir, "result.txt"),
|
||||
]
|
||||
print(f"Running torchbench command: {command}")
|
||||
subprocess.check_call(command, cwd=torchbench_path, env=env)
|
||||
|
||||
def run_userbenchmarks(pytorch_path: str, torchbench_path: str, base_sha: str, head_sha: str,
|
||||
userbenchmark: str, output_dir: str) -> None:
|
||||
|
||||
def run_userbenchmarks(
|
||||
pytorch_path: str,
|
||||
torchbench_path: str,
|
||||
base_sha: str,
|
||||
head_sha: str,
|
||||
userbenchmark: str,
|
||||
output_dir: str,
|
||||
) -> None:
|
||||
# Copy system environment so that we will not override
|
||||
env = dict(os.environ)
|
||||
command = ["python", "./.github/scripts/abtest.py",
|
||||
"--pytorch-repo", pytorch_path,
|
||||
"--base", base_sha,
|
||||
"--head", head_sha,
|
||||
"--userbenchmark", userbenchmark,
|
||||
"--output-dir", output_dir]
|
||||
command = [
|
||||
"python",
|
||||
"./.github/scripts/abtest.py",
|
||||
"--pytorch-repo",
|
||||
pytorch_path,
|
||||
"--base",
|
||||
base_sha,
|
||||
"--head",
|
||||
head_sha,
|
||||
"--userbenchmark",
|
||||
userbenchmark,
|
||||
"--output-dir",
|
||||
output_dir,
|
||||
]
|
||||
print(f"Running torchbench userbenchmark command: {command}")
|
||||
subprocess.check_call(command, cwd=torchbench_path, env=env)
|
||||
|
||||
|
||||
def process_upload_s3(result_dir: str) -> None:
|
||||
# validate result directory
|
||||
result_dir_path = Path(result_dir)
|
||||
assert result_dir_path.exists(), f"Specified result directory {result_dir} doesn't exist."
|
||||
assert (
|
||||
result_dir_path.exists()
|
||||
), f"Specified result directory {result_dir} doesn't exist."
|
||||
# upload all files to S3 bucket oss-ci-metrics
|
||||
files = [x for x in result_dir_path.iterdir() if x.is_file()]
|
||||
# upload file to S3 bucket
|
||||
@ -170,54 +234,92 @@ def process_upload_s3(result_dir: str) -> None:
|
||||
for f in files:
|
||||
s3_client.upload_file(f, filekey_prefix)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='Run TorchBench tests based on PR')
|
||||
parser.add_argument('--pr-body', help="The file that contains body of a Pull Request")
|
||||
|
||||
subparsers = parser.add_subparsers(dest='command')
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run TorchBench tests based on PR")
|
||||
parser.add_argument(
|
||||
"--pr-body", help="The file that contains body of a Pull Request"
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest="command")
|
||||
# parser for setup the torchbench branch name env
|
||||
branch_parser = subparsers.add_parser("set-torchbench-branch")
|
||||
# parser to run the torchbench branch
|
||||
run_parser = subparsers.add_parser("run")
|
||||
run_parser.add_argument('--pr-num', required=True, type=str, help="The Pull Request number")
|
||||
run_parser.add_argument('--pr-base-sha', required=True, type=str, help="The Pull Request base hash")
|
||||
run_parser.add_argument('--pr-head-sha', required=True, type=str, help="The Pull Request head hash")
|
||||
run_parser.add_argument('--pytorch-path', required=True, type=str, help="Path to pytorch repository")
|
||||
run_parser.add_argument('--torchbench-path', required=True, type=str, help="Path to TorchBench repository")
|
||||
run_parser.add_argument(
|
||||
"--pr-num", required=True, type=str, help="The Pull Request number"
|
||||
)
|
||||
run_parser.add_argument(
|
||||
"--pr-base-sha", required=True, type=str, help="The Pull Request base hash"
|
||||
)
|
||||
run_parser.add_argument(
|
||||
"--pr-head-sha", required=True, type=str, help="The Pull Request head hash"
|
||||
)
|
||||
run_parser.add_argument(
|
||||
"--pytorch-path", required=True, type=str, help="Path to pytorch repository"
|
||||
)
|
||||
run_parser.add_argument(
|
||||
"--torchbench-path",
|
||||
required=True,
|
||||
type=str,
|
||||
help="Path to TorchBench repository",
|
||||
)
|
||||
# parser to upload results to S3
|
||||
upload_parser = subparsers.add_parser("upload-s3")
|
||||
upload_parser.add_argument('--result-dir', required=True, type=str, help="Path to benchmark output")
|
||||
upload_parser.add_argument(
|
||||
"--result-dir", required=True, type=str, help="Path to benchmark output"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == 'set-torchbench-branch':
|
||||
if args.command == "set-torchbench-branch":
|
||||
branch_name = find_torchbench_branch(args.pr_body)
|
||||
# env name: "TORCHBENCH_BRANCH"
|
||||
setup_gha_env(MAGIC_TORCHBENCH_PREFIX[:-1], branch_name)
|
||||
elif args.command == 'run':
|
||||
output_dir: str = os.path.join(os.environ["HOME"], ".torchbench", "bisection", f"pr{args.pr_num}")
|
||||
elif args.command == "run":
|
||||
output_dir: str = os.path.join(
|
||||
os.environ["HOME"], ".torchbench", "bisection", f"pr{args.pr_num}"
|
||||
)
|
||||
# Assert the current branch in args.torchbench_path is the same as the one specified in pr body
|
||||
branch_name = find_torchbench_branch(args.pr_body)
|
||||
current_branch = find_current_branch(args.torchbench_path)
|
||||
assert branch_name == current_branch, f"Torchbench repo {args.torchbench_path} is on branch {current_branch}, \
|
||||
assert (
|
||||
branch_name == current_branch
|
||||
), f"Torchbench repo {args.torchbench_path} is on branch {current_branch}, \
|
||||
but user specified to run on branch {branch_name}."
|
||||
print(f"Ready to run TorchBench with benchmark. Result will be saved in the directory: {output_dir}.")
|
||||
print(
|
||||
f"Ready to run TorchBench with benchmark. Result will be saved in the directory: {output_dir}."
|
||||
)
|
||||
# Identify the specified models and userbenchmarks
|
||||
models, userbenchmarks = extract_models_from_pr(args.torchbench_path, args.pr_body)
|
||||
models, userbenchmarks = extract_models_from_pr(
|
||||
args.torchbench_path, args.pr_body
|
||||
)
|
||||
if models:
|
||||
torchbench_config = gen_abtest_config(args.pr_base_sha, args.pr_head_sha, models)
|
||||
torchbench_config = gen_abtest_config(
|
||||
args.pr_base_sha, args.pr_head_sha, models
|
||||
)
|
||||
deploy_torchbench_config(output_dir, torchbench_config)
|
||||
run_torchbench(pytorch_path=args.pytorch_path, torchbench_path=args.torchbench_path, output_dir=output_dir)
|
||||
run_torchbench(
|
||||
pytorch_path=args.pytorch_path,
|
||||
torchbench_path=args.torchbench_path,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
if userbenchmarks:
|
||||
assert len(userbenchmarks) == 1, \
|
||||
"We don't support running multiple userbenchmarks in single workflow yet." \
|
||||
assert len(userbenchmarks) == 1, (
|
||||
"We don't support running multiple userbenchmarks in single workflow yet."
|
||||
"If you need, please submit a feature request."
|
||||
run_userbenchmarks(pytorch_path=args.pytorch_path, torchbench_path=args.torchbench_path,
|
||||
base_sha=args.pr_base_sha, head_sha=args.pr_head_sha,
|
||||
userbenchmark=userbenchmarks[0], output_dir=output_dir)
|
||||
)
|
||||
run_userbenchmarks(
|
||||
pytorch_path=args.pytorch_path,
|
||||
torchbench_path=args.torchbench_path,
|
||||
base_sha=args.pr_base_sha,
|
||||
head_sha=args.pr_head_sha,
|
||||
userbenchmark=userbenchmarks[0],
|
||||
output_dir=output_dir,
|
||||
)
|
||||
if not models and not userbenchmarks:
|
||||
print("Can't parse valid models or userbenchmarks from the pr body. Quit.")
|
||||
exit(-1)
|
||||
elif args.command == 'upload-s3':
|
||||
elif args.command == "upload-s3":
|
||||
process_upload_s3(args.result_dir)
|
||||
else:
|
||||
print(f"The command {args.command} is not supported.")
|
||||
|
60
.github/scripts/test_check_labels.py
vendored
60
.github/scripts/test_check_labels.py
vendored
@ -1,30 +1,35 @@
|
||||
"""test_check_labels.py"""
|
||||
|
||||
from typing import Any, List
|
||||
from unittest import TestCase, mock, main
|
||||
from unittest import main, mock, TestCase
|
||||
|
||||
from check_labels import (
|
||||
main as check_labels_main,
|
||||
add_label_err_comment,
|
||||
delete_all_label_err_comments,
|
||||
main as check_labels_main,
|
||||
)
|
||||
from github_utils import GitHubComment
|
||||
from label_utils import BOT_AUTHORS, LABEL_ERR_MSG_TITLE
|
||||
from test_trymerge import mocked_gh_graphql, mock_gh_get_info
|
||||
from test_trymerge import mock_gh_get_info, mocked_gh_graphql
|
||||
from trymerge import GitHubPR
|
||||
|
||||
|
||||
def mock_parse_args() -> object:
|
||||
class Object(object):
|
||||
def __init__(self) -> None:
|
||||
self.pr_num = 76123
|
||||
|
||||
return Object()
|
||||
|
||||
|
||||
def mock_add_label_err_comment(pr: "GitHubPR") -> None:
|
||||
pass
|
||||
|
||||
|
||||
def mock_delete_all_label_err_comments(pr: "GitHubPR") -> None:
|
||||
pass
|
||||
|
||||
|
||||
def mock_get_comments() -> List[GitHubComment]:
|
||||
return [
|
||||
# Case 1 - a non label err comment
|
||||
@ -49,9 +54,9 @@ def mock_get_comments() -> List[GitHubComment]:
|
||||
|
||||
|
||||
class TestCheckLabels(TestCase):
|
||||
@mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql)
|
||||
@mock.patch('trymerge.GitHubPR.get_comments', return_value=[mock_get_comments()[0]])
|
||||
@mock.patch('check_labels.gh_post_pr_comment')
|
||||
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
|
||||
@mock.patch("trymerge.GitHubPR.get_comments", return_value=[mock_get_comments()[0]])
|
||||
@mock.patch("check_labels.gh_post_pr_comment")
|
||||
def test_correctly_add_label_err_comment(
|
||||
self, mock_gh_post_pr_comment: Any, mock_get_comments: Any, mock_gh_grphql: Any
|
||||
) -> None:
|
||||
@ -60,9 +65,9 @@ class TestCheckLabels(TestCase):
|
||||
add_label_err_comment(pr)
|
||||
mock_gh_post_pr_comment.assert_called_once()
|
||||
|
||||
@mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql)
|
||||
@mock.patch('trymerge.GitHubPR.get_comments', return_value=[mock_get_comments()[1]])
|
||||
@mock.patch('check_labels.gh_post_pr_comment')
|
||||
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
|
||||
@mock.patch("trymerge.GitHubPR.get_comments", return_value=[mock_get_comments()[1]])
|
||||
@mock.patch("check_labels.gh_post_pr_comment")
|
||||
def test_not_add_label_err_comment(
|
||||
self, mock_gh_post_pr_comment: Any, mock_get_comments: Any, mock_gh_grphql: Any
|
||||
) -> None:
|
||||
@ -71,9 +76,9 @@ class TestCheckLabels(TestCase):
|
||||
add_label_err_comment(pr)
|
||||
mock_gh_post_pr_comment.assert_not_called()
|
||||
|
||||
@mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql)
|
||||
@mock.patch('trymerge.GitHubPR.get_comments', return_value=mock_get_comments())
|
||||
@mock.patch('check_labels.gh_delete_comment')
|
||||
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
|
||||
@mock.patch("trymerge.GitHubPR.get_comments", return_value=mock_get_comments())
|
||||
@mock.patch("check_labels.gh_delete_comment")
|
||||
def test_correctly_delete_all_label_err_comments(
|
||||
self, mock_gh_delete_comment: Any, mock_get_comments: Any, mock_gh_grphql: Any
|
||||
) -> None:
|
||||
@ -82,11 +87,16 @@ class TestCheckLabels(TestCase):
|
||||
delete_all_label_err_comments(pr)
|
||||
mock_gh_delete_comment.assert_called_once_with("pytorch", "pytorch", 2)
|
||||
|
||||
@mock.patch('trymerge.gh_get_pr_info', return_value=mock_gh_get_info())
|
||||
@mock.patch('check_labels.parse_args', return_value=mock_parse_args())
|
||||
@mock.patch('check_labels.has_required_labels', return_value=False)
|
||||
@mock.patch('check_labels.delete_all_label_err_comments', side_effect=mock_delete_all_label_err_comments)
|
||||
@mock.patch('check_labels.add_label_err_comment', side_effect=mock_add_label_err_comment)
|
||||
@mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
|
||||
@mock.patch("check_labels.parse_args", return_value=mock_parse_args())
|
||||
@mock.patch("check_labels.has_required_labels", return_value=False)
|
||||
@mock.patch(
|
||||
"check_labels.delete_all_label_err_comments",
|
||||
side_effect=mock_delete_all_label_err_comments,
|
||||
)
|
||||
@mock.patch(
|
||||
"check_labels.add_label_err_comment", side_effect=mock_add_label_err_comment
|
||||
)
|
||||
def test_ci_comments_and_exit0_without_required_labels(
|
||||
self,
|
||||
mock_add_label_err_comment: Any,
|
||||
@ -101,11 +111,16 @@ class TestCheckLabels(TestCase):
|
||||
mock_add_label_err_comment.assert_called_once()
|
||||
mock_delete_all_label_err_comments.assert_not_called()
|
||||
|
||||
@mock.patch('trymerge.gh_get_pr_info', return_value=mock_gh_get_info())
|
||||
@mock.patch('check_labels.parse_args', return_value=mock_parse_args())
|
||||
@mock.patch('check_labels.has_required_labels', return_value=True)
|
||||
@mock.patch('check_labels.delete_all_label_err_comments', side_effect=mock_delete_all_label_err_comments)
|
||||
@mock.patch('check_labels.add_label_err_comment', side_effect=mock_add_label_err_comment)
|
||||
@mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
|
||||
@mock.patch("check_labels.parse_args", return_value=mock_parse_args())
|
||||
@mock.patch("check_labels.has_required_labels", return_value=True)
|
||||
@mock.patch(
|
||||
"check_labels.delete_all_label_err_comments",
|
||||
side_effect=mock_delete_all_label_err_comments,
|
||||
)
|
||||
@mock.patch(
|
||||
"check_labels.add_label_err_comment", side_effect=mock_add_label_err_comment
|
||||
)
|
||||
def test_ci_exit0_with_required_labels(
|
||||
self,
|
||||
mock_add_label_err_comment: Any,
|
||||
@ -120,5 +135,6 @@ class TestCheckLabels(TestCase):
|
||||
mock_add_label_err_comment.assert_not_called()
|
||||
mock_delete_all_label_err_comments.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
106
.github/scripts/test_fetch_latest_green_commit.py
vendored
106
.github/scripts/test_fetch_latest_green_commit.py
vendored
@ -1,5 +1,6 @@
|
||||
from unittest import TestCase, main, mock
|
||||
from typing import Any, List, Dict
|
||||
from typing import Any, Dict, List
|
||||
from unittest import main, mock, TestCase
|
||||
|
||||
from fetch_latest_green_commit import isGreen, WorkflowCheck
|
||||
|
||||
workflowNames = [
|
||||
@ -15,46 +16,72 @@ workflowNames = [
|
||||
"pr-labels",
|
||||
"Close stale pull requests",
|
||||
"Update S3 HTML indices for download.pytorch.org",
|
||||
"Create Release"
|
||||
"Create Release",
|
||||
]
|
||||
|
||||
def set_workflow_job_status(workflow: List[Dict[str, Any]], name: str, status: str) -> List[Dict[str, Any]]:
|
||||
|
||||
def set_workflow_job_status(
|
||||
workflow: List[Dict[str, Any]], name: str, status: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
for check in workflow:
|
||||
if check['workflowName'] == name:
|
||||
check['conclusion'] = status
|
||||
if check["workflowName"] == name:
|
||||
check["conclusion"] = status
|
||||
return workflow
|
||||
|
||||
|
||||
class TestChecks:
|
||||
def make_test_checks(self) -> List[Dict[str, Any]]:
|
||||
workflow_checks = []
|
||||
for i in range(len(workflowNames)):
|
||||
workflow_checks.append(WorkflowCheck(
|
||||
workflowName=workflowNames[i],
|
||||
name="test/job",
|
||||
jobName="job",
|
||||
conclusion="success",
|
||||
)._asdict())
|
||||
workflow_checks.append(
|
||||
WorkflowCheck(
|
||||
workflowName=workflowNames[i],
|
||||
name="test/job",
|
||||
jobName="job",
|
||||
conclusion="success",
|
||||
)._asdict()
|
||||
)
|
||||
return workflow_checks
|
||||
|
||||
|
||||
class TestPrintCommits(TestCase):
|
||||
@mock.patch('fetch_latest_green_commit.get_commit_results', return_value=TestChecks().make_test_checks())
|
||||
@mock.patch(
|
||||
"fetch_latest_green_commit.get_commit_results",
|
||||
return_value=TestChecks().make_test_checks(),
|
||||
)
|
||||
def test_all_successful(self, mock_get_commit_results: Any) -> None:
|
||||
"Test with workflows are successful"
|
||||
workflow_checks = mock_get_commit_results()
|
||||
self.assertTrue(isGreen("sha", workflow_checks)[0])
|
||||
|
||||
@mock.patch('fetch_latest_green_commit.get_commit_results', return_value=TestChecks().make_test_checks())
|
||||
@mock.patch(
|
||||
"fetch_latest_green_commit.get_commit_results",
|
||||
return_value=TestChecks().make_test_checks(),
|
||||
)
|
||||
def test_necessary_successful(self, mock_get_commit_results: Any) -> None:
|
||||
"Test with necessary workflows are successful"
|
||||
workflow_checks = mock_get_commit_results()
|
||||
workflow_checks = set_workflow_job_status(workflow_checks, workflowNames[8], "failed")
|
||||
workflow_checks = set_workflow_job_status(workflow_checks, workflowNames[9], "failed")
|
||||
workflow_checks = set_workflow_job_status(workflow_checks, workflowNames[10], "failed")
|
||||
workflow_checks = set_workflow_job_status(workflow_checks, workflowNames[11], "failed")
|
||||
workflow_checks = set_workflow_job_status(workflow_checks, workflowNames[12], "failed")
|
||||
workflow_checks = set_workflow_job_status(
|
||||
workflow_checks, workflowNames[8], "failed"
|
||||
)
|
||||
workflow_checks = set_workflow_job_status(
|
||||
workflow_checks, workflowNames[9], "failed"
|
||||
)
|
||||
workflow_checks = set_workflow_job_status(
|
||||
workflow_checks, workflowNames[10], "failed"
|
||||
)
|
||||
workflow_checks = set_workflow_job_status(
|
||||
workflow_checks, workflowNames[11], "failed"
|
||||
)
|
||||
workflow_checks = set_workflow_job_status(
|
||||
workflow_checks, workflowNames[12], "failed"
|
||||
)
|
||||
self.assertTrue(isGreen("sha", workflow_checks)[0])
|
||||
|
||||
@mock.patch('fetch_latest_green_commit.get_commit_results', return_value=TestChecks().make_test_checks())
|
||||
@mock.patch(
|
||||
"fetch_latest_green_commit.get_commit_results",
|
||||
return_value=TestChecks().make_test_checks(),
|
||||
)
|
||||
def test_necessary_skipped(self, mock_get_commit_results: Any) -> None:
|
||||
"Test with necessary job (ex: pull) skipped"
|
||||
workflow_checks = mock_get_commit_results()
|
||||
@ -62,15 +89,25 @@ class TestPrintCommits(TestCase):
|
||||
result = isGreen("sha", workflow_checks)
|
||||
self.assertTrue(result[0])
|
||||
|
||||
@mock.patch('fetch_latest_green_commit.get_commit_results', return_value=TestChecks().make_test_checks())
|
||||
@mock.patch(
|
||||
"fetch_latest_green_commit.get_commit_results",
|
||||
return_value=TestChecks().make_test_checks(),
|
||||
)
|
||||
def test_skippable_skipped(self, mock_get_commit_results: Any) -> None:
|
||||
"Test with skippable jobs (periodic and docker-release-builds skipped"
|
||||
workflow_checks = mock_get_commit_results()
|
||||
workflow_checks = set_workflow_job_status(workflow_checks, "periodic", "skipped")
|
||||
workflow_checks = set_workflow_job_status(workflow_checks, "docker-release-builds", "skipped")
|
||||
workflow_checks = set_workflow_job_status(
|
||||
workflow_checks, "periodic", "skipped"
|
||||
)
|
||||
workflow_checks = set_workflow_job_status(
|
||||
workflow_checks, "docker-release-builds", "skipped"
|
||||
)
|
||||
self.assertTrue(isGreen("sha", workflow_checks))
|
||||
|
||||
@mock.patch('fetch_latest_green_commit.get_commit_results', return_value=TestChecks().make_test_checks())
|
||||
@mock.patch(
|
||||
"fetch_latest_green_commit.get_commit_results",
|
||||
return_value=TestChecks().make_test_checks(),
|
||||
)
|
||||
def test_necessary_failed(self, mock_get_commit_results: Any) -> None:
|
||||
"Test with necessary job (ex: Lint) failed"
|
||||
workflow_checks = mock_get_commit_results()
|
||||
@ -79,22 +116,33 @@ class TestPrintCommits(TestCase):
|
||||
self.assertFalse(result[0])
|
||||
self.assertEqual(result[1], "Lint checks were not successful")
|
||||
|
||||
@mock.patch('fetch_latest_green_commit.get_commit_results', return_value=TestChecks().make_test_checks())
|
||||
@mock.patch(
|
||||
"fetch_latest_green_commit.get_commit_results",
|
||||
return_value=TestChecks().make_test_checks(),
|
||||
)
|
||||
def test_skippable_failed(self, mock_get_commit_results: Any) -> None:
|
||||
"Test with failing skippable jobs (ex: docker-release-builds) should pass"
|
||||
workflow_checks = mock_get_commit_results()
|
||||
workflow_checks = set_workflow_job_status(workflow_checks, "periodic", "skipped")
|
||||
workflow_checks = set_workflow_job_status(workflow_checks, "docker-release-builds", "failed")
|
||||
workflow_checks = set_workflow_job_status(
|
||||
workflow_checks, "periodic", "skipped"
|
||||
)
|
||||
workflow_checks = set_workflow_job_status(
|
||||
workflow_checks, "docker-release-builds", "failed"
|
||||
)
|
||||
result = isGreen("sha", workflow_checks)
|
||||
self.assertTrue(result[0])
|
||||
|
||||
@mock.patch('fetch_latest_green_commit.get_commit_results', return_value={})
|
||||
@mock.patch("fetch_latest_green_commit.get_commit_results", return_value={})
|
||||
def test_no_workflows(self, mock_get_commit_results: Any) -> None:
|
||||
"Test with missing workflows"
|
||||
workflow_checks = mock_get_commit_results()
|
||||
result = isGreen("sha", workflow_checks)
|
||||
self.assertFalse(result[0])
|
||||
self.assertEqual(result[1], "missing required workflows: pull, trunk, lint, linux-binary, windows-binary")
|
||||
self.assertEqual(
|
||||
result[1],
|
||||
"missing required workflows: pull, trunk, lint, linux-binary, windows-binary",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
24
.github/scripts/test_gitutils.py
vendored
24
.github/scripts/test_gitutils.py
vendored
@ -1,7 +1,14 @@
|
||||
#!/usr/bin/env python3
|
||||
from gitutils import PeekableIterator, patterns_to_regex, GitRepo, are_ghstack_branches_in_sync, _shasum
|
||||
from unittest import TestCase, main, SkipTest
|
||||
from pathlib import Path
|
||||
from unittest import main, SkipTest, TestCase
|
||||
|
||||
from gitutils import (
|
||||
_shasum,
|
||||
are_ghstack_branches_in_sync,
|
||||
GitRepo,
|
||||
patterns_to_regex,
|
||||
PeekableIterator,
|
||||
)
|
||||
|
||||
|
||||
BASE_DIR = Path(__file__).parent
|
||||
@ -15,6 +22,7 @@ class TestPeekableIterator(TestCase):
|
||||
|
||||
def test_is_iterable(self) -> None:
|
||||
from collections.abc import Iterator
|
||||
|
||||
iter_ = PeekableIterator("")
|
||||
self.assertTrue(isinstance(iter_, Iterator))
|
||||
|
||||
@ -35,7 +43,8 @@ class TestPattern(TestCase):
|
||||
patterns_re = patterns_to_regex(allowed_patterns)
|
||||
fnames = [
|
||||
"aten/src/ATen/native/LinearAlgebra.cpp",
|
||||
"aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp"]
|
||||
"aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp",
|
||||
]
|
||||
for filename in fnames:
|
||||
self.assertTrue(patterns_re.match(filename))
|
||||
|
||||
@ -44,11 +53,13 @@ class TestGitRepo(TestCase):
|
||||
def setUp(self) -> None:
|
||||
repo_dir = BASE_DIR.parent.parent.absolute()
|
||||
if not (repo_dir / ".git").is_dir():
|
||||
raise SkipTest("Can't find git directory, make sure to run this test on real repo checkout")
|
||||
raise SkipTest(
|
||||
"Can't find git directory, make sure to run this test on real repo checkout"
|
||||
)
|
||||
self.repo = GitRepo(str(repo_dir))
|
||||
|
||||
def _skip_if_ref_does_not_exist(self, ref: str) -> None:
|
||||
""" Skip test if ref is missing as stale branches are deleted with time """
|
||||
"""Skip test if ref is missing as stale branches are deleted with time"""
|
||||
try:
|
||||
self.repo.show_ref(ref)
|
||||
except RuntimeError as e:
|
||||
@ -69,5 +80,6 @@ class TestGitRepo(TestCase):
|
||||
self._skip_if_ref_does_not_exist(head_ref)
|
||||
self.assertFalse(are_ghstack_branches_in_sync(self.repo, head_ref))
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
48
.github/scripts/test_label_utils.py
vendored
48
.github/scripts/test_label_utils.py
vendored
@ -1,29 +1,37 @@
|
||||
from typing import Any
|
||||
from unittest import TestCase, mock, main
|
||||
from unittest import main, mock, TestCase
|
||||
|
||||
from label_utils import (
|
||||
get_last_page_num_from_header,
|
||||
gh_get_labels,
|
||||
has_required_labels,
|
||||
)
|
||||
from trymerge import GitHubPR
|
||||
from test_trymerge import mocked_gh_graphql
|
||||
from trymerge import GitHubPR
|
||||
|
||||
|
||||
release_notes_labels = [
|
||||
"release notes: nn",
|
||||
]
|
||||
|
||||
|
||||
class TestLabelUtils(TestCase):
|
||||
MOCK_HEADER_LINKS_TO_PAGE_NUMS = {
|
||||
1: {"link": "<https://api.github.com/dummy/labels?per_page=10&page=1>; rel='last'"},
|
||||
1: {
|
||||
"link": "<https://api.github.com/dummy/labels?per_page=10&page=1>; rel='last'"
|
||||
},
|
||||
2: {"link": "<https://api.github.com/dummy/labels?per_page=1&page=2>;"},
|
||||
3: {"link": "<https://api.github.com/dummy/labels?per_page=1&page=2&page=3>;"},
|
||||
}
|
||||
|
||||
def test_get_last_page_num_from_header(self) -> None:
|
||||
for expected_page_num, mock_header in self.MOCK_HEADER_LINKS_TO_PAGE_NUMS.items():
|
||||
self.assertEqual(get_last_page_num_from_header(mock_header), expected_page_num)
|
||||
for (
|
||||
expected_page_num,
|
||||
mock_header,
|
||||
) in self.MOCK_HEADER_LINKS_TO_PAGE_NUMS.items():
|
||||
self.assertEqual(
|
||||
get_last_page_num_from_header(mock_header), expected_page_num
|
||||
)
|
||||
|
||||
MOCK_LABEL_INFO = '[{"name": "foo"}]'
|
||||
|
||||
@ -49,23 +57,35 @@ class TestLabelUtils(TestCase):
|
||||
gh_get_labels("foo", "bar")
|
||||
self.assertIn("number of pages of labels", str(err.exception))
|
||||
|
||||
@mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql)
|
||||
@mock.patch('label_utils.get_release_notes_labels', return_value=release_notes_labels)
|
||||
def test_pr_with_missing_labels(self, mocked_rn_labels: Any, mocked_gql: Any) -> None:
|
||||
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
|
||||
@mock.patch(
|
||||
"label_utils.get_release_notes_labels", return_value=release_notes_labels
|
||||
)
|
||||
def test_pr_with_missing_labels(
|
||||
self, mocked_rn_labels: Any, mocked_gql: Any
|
||||
) -> None:
|
||||
"Test PR with no 'release notes:' label or 'topic: not user facing' label"
|
||||
pr = GitHubPR("pytorch", "pytorch", 82169)
|
||||
self.assertFalse(has_required_labels(pr))
|
||||
|
||||
@mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql)
|
||||
@mock.patch('label_utils.get_release_notes_labels', return_value=release_notes_labels)
|
||||
def test_pr_with_release_notes_label(self, mocked_rn_labels: Any, mocked_gql: Any) -> None:
|
||||
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
|
||||
@mock.patch(
|
||||
"label_utils.get_release_notes_labels", return_value=release_notes_labels
|
||||
)
|
||||
def test_pr_with_release_notes_label(
|
||||
self, mocked_rn_labels: Any, mocked_gql: Any
|
||||
) -> None:
|
||||
"Test PR with 'release notes: nn' label"
|
||||
pr = GitHubPR("pytorch", "pytorch", 71759)
|
||||
self.assertTrue(has_required_labels(pr))
|
||||
|
||||
@mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql)
|
||||
@mock.patch('label_utils.get_release_notes_labels', return_value=release_notes_labels)
|
||||
def test_pr_with_not_user_facing_label(self, mocked_rn_labels: Any, mocked_gql: Any) -> None:
|
||||
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
|
||||
@mock.patch(
|
||||
"label_utils.get_release_notes_labels", return_value=release_notes_labels
|
||||
)
|
||||
def test_pr_with_not_user_facing_label(
|
||||
self, mocked_rn_labels: Any, mocked_gql: Any
|
||||
) -> None:
|
||||
"Test PR with 'topic: not user facing' label"
|
||||
pr = GitHubPR("pytorch", "pytorch", 75095)
|
||||
self.assertTrue(has_required_labels(pr))
|
||||
|
296
.github/scripts/test_trymerge.py
vendored
296
.github/scripts/test_trymerge.py
vendored
@ -10,30 +10,32 @@
|
||||
import json
|
||||
import os
|
||||
from hashlib import sha256
|
||||
|
||||
from trymerge import (
|
||||
find_matching_merge_rule,
|
||||
gh_graphql,
|
||||
gh_get_team_members,
|
||||
read_merge_rules,
|
||||
validate_revert,
|
||||
GitHubPR,
|
||||
MergeRule,
|
||||
MandatoryChecksMissingError,
|
||||
PostCommentError,
|
||||
FlakyRule,
|
||||
categorize_checks,
|
||||
get_rockset_results,
|
||||
main as trymerge_main,
|
||||
get_classifications,
|
||||
)
|
||||
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
|
||||
from typing import Any, Dict, List, Optional
|
||||
from unittest import TestCase, main, mock
|
||||
from unittest import main, mock, TestCase
|
||||
from urllib.error import HTTPError
|
||||
|
||||
if 'GIT_REMOTE_URL' not in os.environ:
|
||||
os.environ['GIT_REMOTE_URL'] = "https://github.com/pytorch/pytorch"
|
||||
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
|
||||
|
||||
from trymerge import (
|
||||
categorize_checks,
|
||||
find_matching_merge_rule,
|
||||
FlakyRule,
|
||||
get_classifications,
|
||||
get_rockset_results,
|
||||
gh_get_team_members,
|
||||
gh_graphql,
|
||||
GitHubPR,
|
||||
main as trymerge_main,
|
||||
MandatoryChecksMissingError,
|
||||
MergeRule,
|
||||
PostCommentError,
|
||||
read_merge_rules,
|
||||
validate_revert,
|
||||
)
|
||||
|
||||
if "GIT_REMOTE_URL" not in os.environ:
|
||||
os.environ["GIT_REMOTE_URL"] = "https://github.com/pytorch/pytorch"
|
||||
|
||||
|
||||
def mock_query(
|
||||
fallback_function: Any,
|
||||
@ -67,8 +69,13 @@ def mock_query(
|
||||
err_msg = f"If you are seeing this message during workflow run, please make sure to update {file_name}"
|
||||
err_msg += f" locally, by deleting it and running {os.path.basename(__file__)} with "
|
||||
err_msg += " GitHub Personal Access Token passed via GITHUB_TOKEN environment variable"
|
||||
err_msg += " the rockset api key passed via ROCKSET_API_KEY environment variable"
|
||||
if os.getenv("GITHUB_TOKEN") is None or os.getenv("ROCKSET_API_KEY") is None:
|
||||
err_msg += (
|
||||
" the rockset api key passed via ROCKSET_API_KEY environment variable"
|
||||
)
|
||||
if (
|
||||
os.getenv("GITHUB_TOKEN") is None
|
||||
or os.getenv("ROCKSET_API_KEY") is None
|
||||
):
|
||||
err_msg = (
|
||||
"Failed to update cached GraphQL queries as GITHUB_TOKEN or ROCKSET_API_KEY is not defined."
|
||||
+ err_msg
|
||||
@ -89,8 +96,10 @@ def mocked_gh_graphql(query: str, **kwargs: Any) -> Any:
|
||||
|
||||
def gh_graphql_wrapper(query: str, kwargs: Any) -> Any:
|
||||
return gh_graphql(query, **kwargs)
|
||||
|
||||
return mock_query(gh_graphql_wrapper, "gql_mocks.json", key_function, query, kwargs)
|
||||
|
||||
|
||||
def mocked_rockset_results(head_sha: str, merge_base: str, num_retries: int = 3) -> Any:
|
||||
return mock_query(
|
||||
get_rockset_results,
|
||||
@ -100,6 +109,7 @@ def mocked_rockset_results(head_sha: str, merge_base: str, num_retries: int = 3)
|
||||
merge_base,
|
||||
)
|
||||
|
||||
|
||||
def mock_parse_args(revert: bool = False, force: bool = False) -> Any:
|
||||
class Object(object):
|
||||
def __init__(self) -> None:
|
||||
@ -108,69 +118,87 @@ def mock_parse_args(revert: bool = False, force: bool = False) -> Any:
|
||||
self.pr_num = 76123
|
||||
self.dry_run = True
|
||||
self.comment_id = 0
|
||||
self.reason = 'this is for testing'
|
||||
self.reason = "this is for testing"
|
||||
self.ignore_current = False
|
||||
|
||||
return Object()
|
||||
|
||||
def mock_revert(repo: GitRepo, pr: GitHubPR, *,
|
||||
dry_run: bool = False,
|
||||
comment_id: Optional[int] = None,
|
||||
reason: Optional[str] = None) -> None:
|
||||
|
||||
def mock_revert(
|
||||
repo: GitRepo,
|
||||
pr: GitHubPR,
|
||||
*,
|
||||
dry_run: bool = False,
|
||||
comment_id: Optional[int] = None,
|
||||
reason: Optional[str] = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def mock_merge(pr_num: int, repo: GitRepo,
|
||||
dry_run: bool = False,
|
||||
skip_mandatory_checks: bool = False,
|
||||
comment_id: Optional[int] = None,
|
||||
timeout_minutes: int = 400,
|
||||
stale_pr_days: int = 3,
|
||||
ignore_current: bool = False) -> None:
|
||||
|
||||
def mock_merge(
|
||||
pr_num: int,
|
||||
repo: GitRepo,
|
||||
dry_run: bool = False,
|
||||
skip_mandatory_checks: bool = False,
|
||||
comment_id: Optional[int] = None,
|
||||
timeout_minutes: int = 400,
|
||||
stale_pr_days: int = 3,
|
||||
ignore_current: bool = False,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def mock_gh_get_info() -> Any:
|
||||
return {"closed": False,
|
||||
"isCrossRepository": False,
|
||||
"files": {"nodes": [], "pageInfo": {"hasNextPage": False}}, "changedFiles": 0}
|
||||
return {
|
||||
"closed": False,
|
||||
"isCrossRepository": False,
|
||||
"files": {"nodes": [], "pageInfo": {"hasNextPage": False}},
|
||||
"changedFiles": 0,
|
||||
}
|
||||
|
||||
|
||||
def mocked_read_merge_rules_NE(repo: Any, org: str, project: str) -> List[MergeRule]:
|
||||
return [
|
||||
MergeRule(name="mock with nonexistent check",
|
||||
patterns=["*"],
|
||||
approved_by=[],
|
||||
mandatory_checks_name=["Lint",
|
||||
"Facebook CLA Check",
|
||||
"nonexistent"],
|
||||
),
|
||||
MergeRule(
|
||||
name="mock with nonexistent check",
|
||||
patterns=["*"],
|
||||
approved_by=[],
|
||||
mandatory_checks_name=["Lint", "Facebook CLA Check", "nonexistent"],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def mocked_read_merge_rules(repo: Any, org: str, project: str) -> List[MergeRule]:
|
||||
return [
|
||||
MergeRule(name="super",
|
||||
patterns=["*"],
|
||||
approved_by=["pytorch/metamates"],
|
||||
mandatory_checks_name=["Lint",
|
||||
"Facebook CLA Check",
|
||||
"pull / linux-xenial-cuda11.3-py3.7-gcc7 / build",
|
||||
],
|
||||
),
|
||||
MergeRule(
|
||||
name="super",
|
||||
patterns=["*"],
|
||||
approved_by=["pytorch/metamates"],
|
||||
mandatory_checks_name=[
|
||||
"Lint",
|
||||
"Facebook CLA Check",
|
||||
"pull / linux-xenial-cuda11.3-py3.7-gcc7 / build",
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def mocked_read_merge_rules_raise(repo: Any, org: str, project: str) -> List[MergeRule]:
|
||||
raise RuntimeError("testing")
|
||||
|
||||
|
||||
def empty_flaky_rules() -> List[FlakyRule]:
|
||||
return []
|
||||
|
||||
|
||||
def empty_rockset_results(head_sha: str, merge_base: str) -> List[Dict[str, Any]]:
|
||||
return []
|
||||
|
||||
|
||||
def dummy_merge_base() -> str:
|
||||
return "dummy"
|
||||
|
||||
|
||||
class DummyGitRepo(GitRepo):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(get_git_repo_dir(), get_git_remote_name())
|
||||
@ -185,7 +213,7 @@ class DummyGitRepo(GitRepo):
|
||||
@mock.patch("trymerge.read_flaky_rules", side_effect=empty_flaky_rules)
|
||||
@mock.patch("trymerge.get_rockset_results", side_effect=empty_rockset_results)
|
||||
@mock.patch("trymerge.GitHubPR.get_merge_base", side_effect=dummy_merge_base)
|
||||
@mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql)
|
||||
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
|
||||
class TestTryMerge(TestCase):
|
||||
def test_merge_rules_valid(self, *args: Any) -> None:
|
||||
"Test that merge_rules.yaml can be parsed"
|
||||
@ -193,21 +221,23 @@ class TestTryMerge(TestCase):
|
||||
merge_rules = read_merge_rules(repo, "pytorch", "pytorch")
|
||||
self.assertGreater(len(merge_rules), 1)
|
||||
|
||||
@mock.patch('trymerge.read_merge_rules', side_effect=mocked_read_merge_rules)
|
||||
@mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules)
|
||||
def test_match_rules(self, *args: Any) -> None:
|
||||
"Tests that PR passes merge rules"
|
||||
pr = GitHubPR("pytorch", "pytorch", 77700)
|
||||
repo = DummyGitRepo()
|
||||
self.assertTrue(find_matching_merge_rule(pr, repo) is not None)
|
||||
|
||||
@mock.patch('trymerge.read_merge_rules', side_effect=mocked_read_merge_rules_raise)
|
||||
@mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules_raise)
|
||||
def test_read_merge_rules_fails(self, *args: Any) -> None:
|
||||
"Tests that PR fails to read the merge rules"
|
||||
pr = GitHubPR("pytorch", "pytorch", 77700)
|
||||
repo = DummyGitRepo()
|
||||
self.assertRaisesRegex(RuntimeError, "testing", lambda: find_matching_merge_rule(pr, repo))
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError, "testing", lambda: find_matching_merge_rule(pr, repo)
|
||||
)
|
||||
|
||||
@mock.patch('trymerge.read_merge_rules', side_effect=mocked_read_merge_rules)
|
||||
@mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules)
|
||||
def test_lint_fails(self, *args: Any) -> None:
|
||||
"Tests that PR fails mandatory lint check"
|
||||
pr = GitHubPR("pytorch", "pytorch", 90791)
|
||||
@ -223,8 +253,8 @@ class TestTryMerge(TestCase):
|
||||
self.assertTrue("You've committed this PR" in comment.body_text)
|
||||
|
||||
def test_get_author_null(self, *args: Any) -> None:
|
||||
""" Tests that PR author can be computed
|
||||
If reply contains NULL
|
||||
"""Tests that PR author can be computed
|
||||
If reply contains NULL
|
||||
"""
|
||||
pr = GitHubPR("pytorch", "pytorch", 71759)
|
||||
author = pr.get_author()
|
||||
@ -239,8 +269,7 @@ class TestTryMerge(TestCase):
|
||||
self.assertTrue(author is not None)
|
||||
|
||||
def test_last_pushed_at(self, *args: Any) -> None:
|
||||
""" Tests that last_pushed_at will return None on merge commits.
|
||||
"""
|
||||
"""Tests that last_pushed_at will return None on merge commits."""
|
||||
pr = GitHubPR("pytorch", "pytorch", 71759)
|
||||
self.assertIsNotNone(pr.last_pushed_at())
|
||||
|
||||
@ -295,27 +324,26 @@ class TestTryMerge(TestCase):
|
||||
self.assertEqual(len(non_existing_team), 0)
|
||||
|
||||
def test_get_author_many_commits(self, *args: Any) -> None:
|
||||
""" Tests that authors for all commits can be fetched
|
||||
"""
|
||||
"""Tests that authors for all commits can be fetched"""
|
||||
pr = GitHubPR("pytorch", "pytorch", 76118)
|
||||
authors = pr.get_authors()
|
||||
self.assertGreater(pr.get_commit_count(), 100)
|
||||
self.assertGreater(len(authors), 50)
|
||||
self.assertTrue("@" in pr.get_author())
|
||||
|
||||
@mock.patch('trymerge.read_merge_rules', side_effect=mocked_read_merge_rules_NE)
|
||||
@mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules_NE)
|
||||
def test_pending_status_check(self, *args: Any) -> None:
|
||||
""" Tests that PR with nonexistent/pending status checks fails with the right reason.
|
||||
"""
|
||||
"""Tests that PR with nonexistent/pending status checks fails with the right reason."""
|
||||
pr = GitHubPR("pytorch", "pytorch", 76118)
|
||||
repo = DummyGitRepo()
|
||||
self.assertRaisesRegex(MandatoryChecksMissingError,
|
||||
".*are pending/not yet run.*",
|
||||
lambda: find_matching_merge_rule(pr, repo))
|
||||
self.assertRaisesRegex(
|
||||
MandatoryChecksMissingError,
|
||||
".*are pending/not yet run.*",
|
||||
lambda: find_matching_merge_rule(pr, repo),
|
||||
)
|
||||
|
||||
def test_get_author_many_reviews(self, *args: Any) -> None:
|
||||
""" Tests that all reviews can be fetched
|
||||
"""
|
||||
"""Tests that all reviews can be fetched"""
|
||||
pr = GitHubPR("pytorch", "pytorch", 76123)
|
||||
approved_by = pr.get_approved_by()
|
||||
self.assertGreater(len(approved_by), 0)
|
||||
@ -323,56 +351,62 @@ class TestTryMerge(TestCase):
|
||||
self.assertGreater(len(pr._reviews), 100)
|
||||
|
||||
def test_get_checkruns_many_runs(self, *args: Any) -> None:
|
||||
""" Tests that all checkruns can be fetched
|
||||
"""
|
||||
"""Tests that all checkruns can be fetched"""
|
||||
pr = GitHubPR("pytorch", "pytorch", 77700)
|
||||
conclusions = pr.get_checkrun_conclusions()
|
||||
self.assertEqual(len(conclusions), 79)
|
||||
self.assertTrue("pull / linux-docs / build-docs (cpp)" in conclusions.keys())
|
||||
|
||||
def test_cancelled_gets_ignored(self, *args: Any) -> None:
|
||||
""" Tests that cancelled workflow does not override existing successfull status
|
||||
"""
|
||||
"""Tests that cancelled workflow does not override existing successfull status"""
|
||||
pr = GitHubPR("pytorch", "pytorch", 82169)
|
||||
conclusions = pr.get_checkrun_conclusions()
|
||||
lint_checks = [name for name in conclusions.keys() if "Lint" in name]
|
||||
self.assertTrue(len(lint_checks) > 0)
|
||||
self.assertTrue(all([conclusions[name].status == "SUCCESS" for name in lint_checks]))
|
||||
self.assertTrue(
|
||||
all([conclusions[name].status == "SUCCESS" for name in lint_checks])
|
||||
)
|
||||
|
||||
@mock.patch('trymerge.gh_get_pr_info', return_value=mock_gh_get_info())
|
||||
@mock.patch('trymerge.parse_args', return_value=mock_parse_args(True, False))
|
||||
@mock.patch('trymerge.try_revert', side_effect=mock_revert)
|
||||
@mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
|
||||
@mock.patch("trymerge.parse_args", return_value=mock_parse_args(True, False))
|
||||
@mock.patch("trymerge.try_revert", side_effect=mock_revert)
|
||||
def test_main_revert(self, mock_revert: Any, *args: Any) -> None:
|
||||
trymerge_main()
|
||||
mock_revert.assert_called_once()
|
||||
|
||||
@mock.patch('trymerge.gh_get_pr_info', return_value=mock_gh_get_info())
|
||||
@mock.patch('trymerge.parse_args', return_value=mock_parse_args(False, True))
|
||||
@mock.patch('trymerge.merge', side_effect=mock_merge)
|
||||
def test_main_force(self, mock_merge: Any, mock_parse_args: Any, *args: Any) -> None:
|
||||
@mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
|
||||
@mock.patch("trymerge.parse_args", return_value=mock_parse_args(False, True))
|
||||
@mock.patch("trymerge.merge", side_effect=mock_merge)
|
||||
def test_main_force(
|
||||
self, mock_merge: Any, mock_parse_args: Any, *args: Any
|
||||
) -> None:
|
||||
trymerge_main()
|
||||
mock_merge.assert_called_once_with(mock.ANY,
|
||||
mock.ANY,
|
||||
dry_run=mock.ANY,
|
||||
skip_mandatory_checks=True,
|
||||
comment_id=mock.ANY,
|
||||
ignore_current=False)
|
||||
mock_merge.assert_called_once_with(
|
||||
mock.ANY,
|
||||
mock.ANY,
|
||||
dry_run=mock.ANY,
|
||||
skip_mandatory_checks=True,
|
||||
comment_id=mock.ANY,
|
||||
ignore_current=False,
|
||||
)
|
||||
|
||||
@mock.patch('trymerge.gh_get_pr_info', return_value=mock_gh_get_info())
|
||||
@mock.patch('trymerge.parse_args', return_value=mock_parse_args(False, False))
|
||||
@mock.patch('trymerge.merge', side_effect=mock_merge)
|
||||
@mock.patch("trymerge.gh_get_pr_info", return_value=mock_gh_get_info())
|
||||
@mock.patch("trymerge.parse_args", return_value=mock_parse_args(False, False))
|
||||
@mock.patch("trymerge.merge", side_effect=mock_merge)
|
||||
def test_main_merge(self, mock_merge: Any, *args: Any) -> None:
|
||||
trymerge_main()
|
||||
mock_merge.assert_called_once_with(mock.ANY,
|
||||
mock.ANY,
|
||||
dry_run=mock.ANY,
|
||||
skip_mandatory_checks=False,
|
||||
comment_id=mock.ANY,
|
||||
ignore_current=False)
|
||||
mock_merge.assert_called_once_with(
|
||||
mock.ANY,
|
||||
mock.ANY,
|
||||
dry_run=mock.ANY,
|
||||
skip_mandatory_checks=False,
|
||||
comment_id=mock.ANY,
|
||||
ignore_current=False,
|
||||
)
|
||||
|
||||
@mock.patch('trymerge.read_merge_rules', side_effect=mocked_read_merge_rules)
|
||||
@mock.patch("trymerge.read_merge_rules", side_effect=mocked_read_merge_rules)
|
||||
def test_revert_rules(self, *args: Any) -> None:
|
||||
""" Tests that reverts from collaborators are allowed """
|
||||
"""Tests that reverts from collaborators are allowed"""
|
||||
pr = GitHubPR("pytorch", "pytorch", 79694)
|
||||
repo = DummyGitRepo()
|
||||
self.assertIsNotNone(validate_revert(repo, pr, comment_id=1189459845))
|
||||
@ -403,7 +437,11 @@ class TestTryMerge(TestCase):
|
||||
return pr.get_body()
|
||||
|
||||
repo = GitRepoCoDev()
|
||||
self.assertRaisesRegex(PostCommentError, "landed via phabricator", lambda: validate_revert(repo, pr, comment_id=1372496233))
|
||||
self.assertRaisesRegex(
|
||||
PostCommentError,
|
||||
"landed via phabricator",
|
||||
lambda: validate_revert(repo, pr, comment_id=1372496233),
|
||||
)
|
||||
|
||||
def test_pr_changed_submodule_detection(self, *args: Any) -> None:
|
||||
# Updates submodule during dev-cycle but reverts it later
|
||||
@ -426,10 +464,14 @@ class TestTryMerge(TestCase):
|
||||
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
|
||||
class TestBypassFailures(TestCase):
|
||||
def test_get_classifications(self, *args: Any) -> None:
|
||||
flaky_rules = [FlakyRule("distributed", ["##[error]The operation was canceled."])]
|
||||
flaky_rules = [
|
||||
FlakyRule("distributed", ["##[error]The operation was canceled."])
|
||||
]
|
||||
pr = GitHubPR("pytorch", "pytorch", 92863)
|
||||
checks = pr.get_checkrun_conclusions()
|
||||
checks = get_classifications(checks, pr.last_commit()['oid'], pr.get_merge_base(), flaky_rules, [])
|
||||
checks = get_classifications(
|
||||
checks, pr.last_commit()["oid"], pr.get_merge_base(), flaky_rules, []
|
||||
)
|
||||
self.assertTrue(
|
||||
checks[
|
||||
"pull / linux-bionic-py3_7-clang8-xla / test (xla, 1, 1, linux.4xlarge)"
|
||||
@ -442,10 +484,14 @@ class TestBypassFailures(TestCase):
|
||||
].classification
|
||||
== "FLAKY"
|
||||
)
|
||||
pending, failed = categorize_checks(checks, list(checks.keys()), ok_failed_checks_threshold=2)
|
||||
pending, failed = categorize_checks(
|
||||
checks, list(checks.keys()), ok_failed_checks_threshold=2
|
||||
)
|
||||
self.assertTrue(len(pending) == 0)
|
||||
self.assertTrue(len(failed) == 0)
|
||||
pending, failed = categorize_checks(checks, list(checks.keys()), ok_failed_checks_threshold=1)
|
||||
pending, failed = categorize_checks(
|
||||
checks, list(checks.keys()), ok_failed_checks_threshold=1
|
||||
)
|
||||
self.assertTrue(len(pending) == 0)
|
||||
self.assertTrue(len(failed) == 2)
|
||||
|
||||
@ -454,32 +500,52 @@ class TestBypassFailures(TestCase):
|
||||
# ignore current checks takes precedence over classifications for flaky
|
||||
# or broken trunk
|
||||
|
||||
flaky_rules = [FlakyRule("distributed", ["##[error]The operation was canceled."])]
|
||||
flaky = "pull / linux-focal-py3.7-gcc7 / test (distributed, 1, 2, linux.2xlarge)"
|
||||
broken_trunk = "pull / linux-bionic-py3_7-clang8-xla / test (xla, 1, 1, linux.4xlarge)"
|
||||
flaky_rules = [
|
||||
FlakyRule("distributed", ["##[error]The operation was canceled."])
|
||||
]
|
||||
flaky = (
|
||||
"pull / linux-focal-py3.7-gcc7 / test (distributed, 1, 2, linux.2xlarge)"
|
||||
)
|
||||
broken_trunk = (
|
||||
"pull / linux-bionic-py3_7-clang8-xla / test (xla, 1, 1, linux.4xlarge)"
|
||||
)
|
||||
|
||||
pr = GitHubPR("pytorch", "pytorch", 92863)
|
||||
checks = pr.get_checkrun_conclusions()
|
||||
|
||||
# No broken trunk or flaky rules
|
||||
checks = get_classifications(checks, pr.last_commit()['oid'], None, [], [flaky])
|
||||
checks = get_classifications(checks, pr.last_commit()["oid"], None, [], [flaky])
|
||||
self.assertTrue(checks[flaky].classification == "IGNORE_CURRENT_CHECK")
|
||||
self.assertTrue(checks[broken_trunk].classification is None)
|
||||
_, failed = categorize_checks(checks, list(checks.keys()), ok_failed_checks_threshold=0)
|
||||
_, failed = categorize_checks(
|
||||
checks, list(checks.keys()), ok_failed_checks_threshold=0
|
||||
)
|
||||
self.assertTrue(len(failed) == 1)
|
||||
|
||||
# No flaky rules
|
||||
checks = get_classifications(checks, pr.last_commit()['oid'], pr.get_merge_base(), [], [flaky])
|
||||
checks = get_classifications(
|
||||
checks, pr.last_commit()["oid"], pr.get_merge_base(), [], [flaky]
|
||||
)
|
||||
self.assertTrue(checks[flaky].classification == "IGNORE_CURRENT_CHECK")
|
||||
self.assertTrue(checks[broken_trunk].classification == "BROKEN_TRUNK")
|
||||
_, failed = categorize_checks(checks, list(checks.keys()), ok_failed_checks_threshold=1)
|
||||
_, failed = categorize_checks(
|
||||
checks, list(checks.keys()), ok_failed_checks_threshold=1
|
||||
)
|
||||
self.assertTrue(len(failed) == 0)
|
||||
|
||||
# No broken_trunk
|
||||
checks = get_classifications(checks, pr.last_commit()['oid'], pr.get_merge_base(), flaky_rules, [broken_trunk])
|
||||
checks = get_classifications(
|
||||
checks,
|
||||
pr.last_commit()["oid"],
|
||||
pr.get_merge_base(),
|
||||
flaky_rules,
|
||||
[broken_trunk],
|
||||
)
|
||||
self.assertTrue(checks[flaky].classification == "FLAKY")
|
||||
self.assertTrue(checks[broken_trunk].classification == "IGNORE_CURRENT_CHECK")
|
||||
_, failed = categorize_checks(checks, list(checks.keys()), ok_failed_checks_threshold=1)
|
||||
_, failed = categorize_checks(
|
||||
checks, list(checks.keys()), ok_failed_checks_threshold=1
|
||||
)
|
||||
self.assertTrue(len(failed) == 0)
|
||||
|
||||
|
||||
|
142
.github/scripts/test_tryrebase.py
vendored
142
.github/scripts/test_tryrebase.py
vendored
@ -1,75 +1,129 @@
|
||||
from unittest import TestCase, mock, main
|
||||
from typing import Any
|
||||
from unittest import main, mock, TestCase
|
||||
|
||||
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
|
||||
from test_trymerge import mocked_gh_graphql
|
||||
from trymerge import GitHubPR
|
||||
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
|
||||
from typing import Any
|
||||
from tryrebase import rebase_onto, rebase_ghstack_onto
|
||||
from tryrebase import rebase_ghstack_onto, rebase_onto
|
||||
|
||||
|
||||
def mocked_rev_parse(branch: str) -> str:
|
||||
return branch
|
||||
|
||||
class TestRebase(TestCase):
|
||||
|
||||
@mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql)
|
||||
@mock.patch('gitutils.GitRepo._run_git')
|
||||
@mock.patch('gitutils.GitRepo.rev_parse', side_effect=mocked_rev_parse)
|
||||
@mock.patch('tryrebase.gh_post_comment')
|
||||
def test_rebase(self, mocked_post_comment: Any, mocked_rp: Any, mocked_run_git: Any, mocked_gql: Any) -> None:
|
||||
class TestRebase(TestCase):
|
||||
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
|
||||
@mock.patch("gitutils.GitRepo._run_git")
|
||||
@mock.patch("gitutils.GitRepo.rev_parse", side_effect=mocked_rev_parse)
|
||||
@mock.patch("tryrebase.gh_post_comment")
|
||||
def test_rebase(
|
||||
self,
|
||||
mocked_post_comment: Any,
|
||||
mocked_rp: Any,
|
||||
mocked_run_git: Any,
|
||||
mocked_gql: Any,
|
||||
) -> None:
|
||||
"Tests rebase successfully"
|
||||
pr = GitHubPR("pytorch", "pytorch", 31093)
|
||||
repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
|
||||
rebase_onto(pr, repo, 'master')
|
||||
calls = [mock.call('fetch', 'origin', 'pull/31093/head:pull/31093/head'),
|
||||
mock.call('rebase', 'refs/remotes/origin/master', 'pull/31093/head'),
|
||||
mock.call('push', '-f', 'https://github.com/mingxiaoh/pytorch.git', 'pull/31093/head:master')]
|
||||
rebase_onto(pr, repo, "master")
|
||||
calls = [
|
||||
mock.call("fetch", "origin", "pull/31093/head:pull/31093/head"),
|
||||
mock.call("rebase", "refs/remotes/origin/master", "pull/31093/head"),
|
||||
mock.call(
|
||||
"push",
|
||||
"-f",
|
||||
"https://github.com/mingxiaoh/pytorch.git",
|
||||
"pull/31093/head:master",
|
||||
),
|
||||
]
|
||||
mocked_run_git.assert_has_calls(calls)
|
||||
self.assertTrue(
|
||||
"Successfully rebased `master` onto `refs/remotes/origin/master`" in mocked_post_comment.call_args[0][3])
|
||||
"Successfully rebased `master` onto `refs/remotes/origin/master`"
|
||||
in mocked_post_comment.call_args[0][3]
|
||||
)
|
||||
|
||||
@mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql)
|
||||
@mock.patch('gitutils.GitRepo._run_git')
|
||||
@mock.patch('gitutils.GitRepo.rev_parse', side_effect=mocked_rev_parse)
|
||||
@mock.patch('tryrebase.gh_post_comment')
|
||||
def test_rebase_to_stable(self, mocked_post_comment: Any, mocked_rp: Any, mocked_run_git: Any, mocked_gql: Any) -> None:
|
||||
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
|
||||
@mock.patch("gitutils.GitRepo._run_git")
|
||||
@mock.patch("gitutils.GitRepo.rev_parse", side_effect=mocked_rev_parse)
|
||||
@mock.patch("tryrebase.gh_post_comment")
|
||||
def test_rebase_to_stable(
|
||||
self,
|
||||
mocked_post_comment: Any,
|
||||
mocked_rp: Any,
|
||||
mocked_run_git: Any,
|
||||
mocked_gql: Any,
|
||||
) -> None:
|
||||
"Tests rebase to viable/strict successfully"
|
||||
pr = GitHubPR("pytorch", "pytorch", 31093)
|
||||
repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
|
||||
rebase_onto(pr, repo, 'viable/strict', False)
|
||||
calls = [mock.call('fetch', 'origin', 'pull/31093/head:pull/31093/head'),
|
||||
mock.call('rebase', 'refs/remotes/origin/viable/strict', 'pull/31093/head'),
|
||||
mock.call('push', '-f', 'https://github.com/mingxiaoh/pytorch.git', 'pull/31093/head:master')]
|
||||
rebase_onto(pr, repo, "viable/strict", False)
|
||||
calls = [
|
||||
mock.call("fetch", "origin", "pull/31093/head:pull/31093/head"),
|
||||
mock.call("rebase", "refs/remotes/origin/viable/strict", "pull/31093/head"),
|
||||
mock.call(
|
||||
"push",
|
||||
"-f",
|
||||
"https://github.com/mingxiaoh/pytorch.git",
|
||||
"pull/31093/head:master",
|
||||
),
|
||||
]
|
||||
mocked_run_git.assert_has_calls(calls)
|
||||
self.assertTrue(
|
||||
"Successfully rebased `master` onto `refs/remotes/origin/viable/strict`" in mocked_post_comment.call_args[0][3])
|
||||
"Successfully rebased `master` onto `refs/remotes/origin/viable/strict`"
|
||||
in mocked_post_comment.call_args[0][3]
|
||||
)
|
||||
|
||||
@mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql)
|
||||
@mock.patch('gitutils.GitRepo._run_git', return_value="Everything up-to-date")
|
||||
@mock.patch('gitutils.GitRepo.rev_parse', side_effect=mocked_rev_parse)
|
||||
@mock.patch('tryrebase.gh_post_comment')
|
||||
def test_no_need_to_rebase(self, mocked_post_comment: Any, mocked_rp: Any, mocked_run_git: Any, mocked_gql: Any) -> None:
|
||||
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
|
||||
@mock.patch("gitutils.GitRepo._run_git", return_value="Everything up-to-date")
|
||||
@mock.patch("gitutils.GitRepo.rev_parse", side_effect=mocked_rev_parse)
|
||||
@mock.patch("tryrebase.gh_post_comment")
|
||||
def test_no_need_to_rebase(
|
||||
self,
|
||||
mocked_post_comment: Any,
|
||||
mocked_rp: Any,
|
||||
mocked_run_git: Any,
|
||||
mocked_gql: Any,
|
||||
) -> None:
|
||||
"Tests branch already up to date"
|
||||
pr = GitHubPR("pytorch", "pytorch", 31093)
|
||||
repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
|
||||
rebase_onto(pr, repo, 'master')
|
||||
calls = [mock.call('fetch', 'origin', 'pull/31093/head:pull/31093/head'),
|
||||
mock.call('rebase', 'refs/remotes/origin/master', 'pull/31093/head'),
|
||||
mock.call('push', '-f', 'https://github.com/mingxiaoh/pytorch.git', 'pull/31093/head:master')]
|
||||
rebase_onto(pr, repo, "master")
|
||||
calls = [
|
||||
mock.call("fetch", "origin", "pull/31093/head:pull/31093/head"),
|
||||
mock.call("rebase", "refs/remotes/origin/master", "pull/31093/head"),
|
||||
mock.call(
|
||||
"push",
|
||||
"-f",
|
||||
"https://github.com/mingxiaoh/pytorch.git",
|
||||
"pull/31093/head:master",
|
||||
),
|
||||
]
|
||||
mocked_run_git.assert_has_calls(calls)
|
||||
self.assertTrue(
|
||||
"Tried to rebase and push PR #31093, but it was already up to date" in mocked_post_comment.call_args[0][3])
|
||||
"Tried to rebase and push PR #31093, but it was already up to date"
|
||||
in mocked_post_comment.call_args[0][3]
|
||||
)
|
||||
|
||||
@mock.patch('trymerge.gh_graphql', side_effect=mocked_gh_graphql)
|
||||
@mock.patch('gitutils.GitRepo._run_git')
|
||||
@mock.patch('gitutils.GitRepo.rev_parse', side_effect=lambda branch: "same sha")
|
||||
@mock.patch('tryrebase.gh_post_comment')
|
||||
def test_same_sha(self, mocked_post_comment: Any, mocked_rp: Any, mocked_run_git: Any, mocked_gql: Any) -> None:
|
||||
@mock.patch("trymerge.gh_graphql", side_effect=mocked_gh_graphql)
|
||||
@mock.patch("gitutils.GitRepo._run_git")
|
||||
@mock.patch("gitutils.GitRepo.rev_parse", side_effect=lambda branch: "same sha")
|
||||
@mock.patch("tryrebase.gh_post_comment")
|
||||
def test_same_sha(
|
||||
self,
|
||||
mocked_post_comment: Any,
|
||||
mocked_rp: Any,
|
||||
mocked_run_git: Any,
|
||||
mocked_gql: Any,
|
||||
) -> None:
|
||||
"Tests rebase results in same sha"
|
||||
pr = GitHubPR("pytorch", "pytorch", 31093)
|
||||
repo = GitRepo(get_git_repo_dir(), get_git_remote_name())
|
||||
with self.assertRaisesRegex(Exception, 'same sha as the target branch'):
|
||||
rebase_onto(pr, repo, 'master')
|
||||
with self.assertRaisesRegex(Exception, 'same sha as the target branch'):
|
||||
rebase_ghstack_onto(pr, repo, 'master')
|
||||
with self.assertRaisesRegex(Exception, "same sha as the target branch"):
|
||||
rebase_onto(pr, repo, "master")
|
||||
with self.assertRaisesRegex(Exception, "same sha as the target branch"):
|
||||
rebase_ghstack_onto(pr, repo, "master")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
725
.github/scripts/trymerge.py
vendored
725
.github/scripts/trymerge.py
vendored
File diff suppressed because it is too large
Load Diff
36
.github/scripts/trymerge_explainer.py
vendored
36
.github/scripts/trymerge_explainer.py
vendored
@ -1,6 +1,6 @@
|
||||
import os
|
||||
import re
|
||||
from typing import List, Pattern, Optional, Tuple
|
||||
from typing import List, Optional, Pattern, Tuple
|
||||
|
||||
|
||||
BOT_COMMANDS_WIKI = "https://github.com/pytorch/pytorch/wiki/Bot-commands"
|
||||
@ -10,9 +10,7 @@ CIFLOW_TRUNK_LABEL = re.compile(r"^ciflow/trunk")
|
||||
|
||||
OFFICE_HOURS_LINK = "https://github.com/pytorch/pytorch/wiki/Dev-Infra-Office-Hours"
|
||||
CONTACT_US = f"Questions? Feedback? Please reach out to the [PyTorch DevX Team]({OFFICE_HOURS_LINK})"
|
||||
ALTERNATIVES = (
|
||||
f"Learn more about merging in the [wiki]({BOT_COMMANDS_WIKI})."
|
||||
)
|
||||
ALTERNATIVES = f"Learn more about merging in the [wiki]({BOT_COMMANDS_WIKI})."
|
||||
|
||||
|
||||
def has_label(labels: List[str], pattern: Pattern[str] = CIFLOW_LABEL) -> bool:
|
||||
@ -46,31 +44,35 @@ class TryMergeExplainer(object):
|
||||
self.project = project
|
||||
self.ignore_current = ignore_current
|
||||
|
||||
def _get_flag_msg(self, ignore_current_checks: Optional[List[Tuple[str, Optional[str]]]] = None) -> str:
|
||||
def _get_flag_msg(
|
||||
self, ignore_current_checks: Optional[List[Tuple[str, Optional[str]]]] = None
|
||||
) -> str:
|
||||
if self.force:
|
||||
return "Your change will be merged immediately since you used the force (-f) flag, " + \
|
||||
"**bypassing any CI checks** (ETA: 1-5 minutes)."
|
||||
return (
|
||||
"Your change will be merged immediately since you used the force (-f) flag, "
|
||||
+ "**bypassing any CI checks** (ETA: 1-5 minutes)."
|
||||
)
|
||||
elif self.ignore_current and ignore_current_checks is not None:
|
||||
msg = f"Your change will be merged while ignoring the following {len(ignore_current_checks)} checks: "
|
||||
msg += ', '.join(f"[{x[0]}]({x[1]})" for x in ignore_current_checks)
|
||||
msg += ", ".join(f"[{x[0]}]({x[1]})" for x in ignore_current_checks)
|
||||
return msg
|
||||
else:
|
||||
return "Your change will be merged once all checks pass (ETA 0-4 Hours)."
|
||||
|
||||
|
||||
def get_merge_message(
|
||||
self,
|
||||
ignore_current_checks: Optional[List[Tuple[str, Optional[str]]]] = None
|
||||
self, ignore_current_checks: Optional[List[Tuple[str, Optional[str]]]] = None
|
||||
) -> str:
|
||||
title = "### Merge started"
|
||||
main_message = self._get_flag_msg(ignore_current_checks)
|
||||
|
||||
advanced_debugging = "\n".join((
|
||||
"<details><summary>Advanced Debugging</summary>",
|
||||
"Check the merge workflow status ",
|
||||
f"<a href=\"{os.getenv('GH_RUN_URL')}\">here</a>",
|
||||
"</details>"
|
||||
))
|
||||
advanced_debugging = "\n".join(
|
||||
(
|
||||
"<details><summary>Advanced Debugging</summary>",
|
||||
"Check the merge workflow status ",
|
||||
f"<a href=\"{os.getenv('GH_RUN_URL')}\">here</a>",
|
||||
"</details>",
|
||||
)
|
||||
)
|
||||
|
||||
msg = title + "\n"
|
||||
msg += main_message + "\n\n"
|
||||
|
108
.github/scripts/tryrebase.py
vendored
108
.github/scripts/tryrebase.py
vendored
@ -1,21 +1,24 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
import re
|
||||
from typing import Any
|
||||
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
|
||||
|
||||
from github_utils import gh_post_pr_comment as gh_post_comment
|
||||
from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo
|
||||
from trymerge import GitHubPR
|
||||
|
||||
SAME_SHA_ERROR = (
|
||||
"\n```\nAborting rebase because rebasing the branch resulted in the same sha as the target branch.\n" +
|
||||
"This usually happens because the PR has already been merged. Please rebase locally and push.\n```"
|
||||
"\n```\nAborting rebase because rebasing the branch resulted in the same sha as the target branch.\n"
|
||||
+ "This usually happens because the PR has already been merged. Please rebase locally and push.\n```"
|
||||
)
|
||||
|
||||
|
||||
def parse_args() -> Any:
|
||||
from argparse import ArgumentParser
|
||||
|
||||
parser = ArgumentParser("Rebase PR into branch")
|
||||
parser.add_argument("--dry-run", action="store_true")
|
||||
parser.add_argument("--branch", type=str)
|
||||
@ -23,7 +26,9 @@ def parse_args() -> Any:
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def rebase_onto(pr: GitHubPR, repo: GitRepo, onto_branch: str, dry_run: bool = False) -> None:
|
||||
def rebase_onto(
|
||||
pr: GitHubPR, repo: GitRepo, onto_branch: str, dry_run: bool = False
|
||||
) -> None:
|
||||
branch = f"pull/{pr.pr_num}/head"
|
||||
onto_branch = f"refs/remotes/origin/{onto_branch}"
|
||||
remote_url = f"https://github.com/{pr.info['headRepository']['nameWithOwner']}.git"
|
||||
@ -40,17 +45,34 @@ def rebase_onto(pr: GitHubPR, repo: GitRepo, onto_branch: str, dry_run: bool = F
|
||||
else:
|
||||
push_result = repo._run_git("push", "-f", remote_url, refspec)
|
||||
if "Everything up-to-date" in push_result:
|
||||
gh_post_comment(pr.org, pr.project, pr.pr_num,
|
||||
f"Tried to rebase and push PR #{pr.pr_num}, but it was already up to date", dry_run=dry_run)
|
||||
gh_post_comment(
|
||||
pr.org,
|
||||
pr.project,
|
||||
pr.pr_num,
|
||||
f"Tried to rebase and push PR #{pr.pr_num}, but it was already up to date",
|
||||
dry_run=dry_run,
|
||||
)
|
||||
else:
|
||||
gh_post_comment(pr.org, pr.project, pr.pr_num,
|
||||
f"Successfully rebased `{pr.head_ref()}` onto `{onto_branch}`, please pull locally " +
|
||||
f"before adding more changes (for example, via `git checkout {pr.head_ref()} && " +
|
||||
"git pull --rebase`)", dry_run=dry_run)
|
||||
gh_post_comment(
|
||||
pr.org,
|
||||
pr.project,
|
||||
pr.pr_num,
|
||||
f"Successfully rebased `{pr.head_ref()}` onto `{onto_branch}`, please pull locally "
|
||||
+ f"before adding more changes (for example, via `git checkout {pr.head_ref()} && "
|
||||
+ "git pull --rebase`)",
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
|
||||
def rebase_ghstack_onto(pr: GitHubPR, repo: GitRepo, onto_branch: str, dry_run: bool = False) -> None:
|
||||
if subprocess.run([sys.executable, "-m", "ghstack", "--help"], capture_output=True).returncode != 0:
|
||||
def rebase_ghstack_onto(
|
||||
pr: GitHubPR, repo: GitRepo, onto_branch: str, dry_run: bool = False
|
||||
) -> None:
|
||||
if (
|
||||
subprocess.run(
|
||||
[sys.executable, "-m", "ghstack", "--help"], capture_output=True
|
||||
).returncode
|
||||
!= 0
|
||||
):
|
||||
subprocess.run([sys.executable, "-m", "pip", "install", "ghstack"])
|
||||
orig_ref = f"{re.sub(r'/head$', '/orig', pr.head_ref())}"
|
||||
onto_branch = f"refs/remotes/origin/{onto_branch}"
|
||||
@ -68,11 +90,13 @@ def rebase_ghstack_onto(pr: GitHubPR, repo: GitRepo, onto_branch: str, dry_run:
|
||||
repo._run_git("config", "--global", "user.name", name)
|
||||
|
||||
os.environ["OAUTH_TOKEN"] = os.environ["GITHUB_TOKEN"]
|
||||
with open('.ghstackrc', 'w+') as f:
|
||||
f.write('[ghstack]\n' +
|
||||
"github_url=github.com\n" +
|
||||
"github_username=pytorchmergebot\n" +
|
||||
"remote_name=origin")
|
||||
with open(".ghstackrc", "w+") as f:
|
||||
f.write(
|
||||
"[ghstack]\n"
|
||||
+ "github_url=github.com\n"
|
||||
+ "github_username=pytorchmergebot\n"
|
||||
+ "remote_name=origin"
|
||||
)
|
||||
|
||||
if dry_run:
|
||||
print("Don't know how to dry-run ghstack")
|
||||
@ -102,19 +126,37 @@ def rebase_ghstack_onto(pr: GitHubPR, repo: GitRepo, onto_branch: str, dry_run:
|
||||
if "Updated" in line:
|
||||
pr_num = int(line.split("/")[-1])
|
||||
if pr_num != pr.pr_num:
|
||||
gh_post_comment(pr.org, pr.project, pr_num,
|
||||
f"Rebased `{orig_ref}` onto `{onto_branch}` because #{pr.pr_num} was rebased, "
|
||||
"please pull locally before adding more changes (for example, via `ghstack " +
|
||||
f"checkout https://github.com/{org}/{project}/pull/{pr_num}`)", dry_run=dry_run)
|
||||
gh_post_comment(
|
||||
pr.org,
|
||||
pr.project,
|
||||
pr_num,
|
||||
f"Rebased `{orig_ref}` onto `{onto_branch}` because #{pr.pr_num} was rebased, "
|
||||
"please pull locally before adding more changes (for example, via `ghstack "
|
||||
+ f"checkout https://github.com/{org}/{project}/pull/{pr_num}`)",
|
||||
dry_run=dry_run,
|
||||
)
|
||||
else:
|
||||
gh_post_comment(pr.org, pr.project, pr_num,
|
||||
f"Successfully rebased `{orig_ref}` onto `{onto_branch}`, please pull locally " +
|
||||
"before adding more changes (for example, via `ghstack " +
|
||||
f"checkout https://github.com/{org}/{project}/pull/{pr.pr_num}`)", dry_run=dry_run)
|
||||
gh_post_comment(
|
||||
pr.org,
|
||||
pr.project,
|
||||
pr_num,
|
||||
f"Successfully rebased `{orig_ref}` onto `{onto_branch}`, please pull locally "
|
||||
+ "before adding more changes (for example, via `ghstack "
|
||||
+ f"checkout https://github.com/{org}/{project}/pull/{pr.pr_num}`)",
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
if f"Skipped https://github.com/{org}/{project}/pull/{pr.pr_num}" in push_result:
|
||||
gh_post_comment(pr.org, pr.project, pr.pr_num,
|
||||
f"Tried to rebase and push PR #{pr.pr_num}, but it was already up to date", dry_run=dry_run)
|
||||
if (
|
||||
f"Skipped https://github.com/{org}/{project}/pull/{pr.pr_num}"
|
||||
in push_result
|
||||
):
|
||||
gh_post_comment(
|
||||
pr.org,
|
||||
pr.project,
|
||||
pr.pr_num,
|
||||
f"Tried to rebase and push PR #{pr.pr_num}, but it was already up to date",
|
||||
dry_run=dry_run,
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
@ -130,7 +172,13 @@ def main() -> None:
|
||||
gh_post_comment(org, project, args.pr_num, msg, dry_run=args.dry_run)
|
||||
|
||||
if pr.is_closed():
|
||||
gh_post_comment(org, project, args.pr_num, f"PR #{args.pr_num} is closed, won't rebase", dry_run=args.dry_run)
|
||||
gh_post_comment(
|
||||
org,
|
||||
project,
|
||||
args.pr_num,
|
||||
f"PR #{args.pr_num} is closed, won't rebase",
|
||||
dry_run=args.dry_run,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
|
5
.github/scripts/update_commit_hashes.py
vendored
5
.github/scripts/update_commit_hashes.py
vendored
@ -1,9 +1,10 @@
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import requests
|
||||
from typing import Any, Dict
|
||||
from argparse import ArgumentParser
|
||||
from typing import Any, Dict
|
||||
|
||||
import requests
|
||||
|
||||
MERGEBOT_TOKEN = os.environ["MERGEBOT_TOKEN"]
|
||||
PYTORCHBOT_TOKEN = os.environ["PYTORCHBOT_TOKEN"]
|
||||
|
@ -826,6 +826,8 @@ init_command = [
|
||||
[[linter]]
|
||||
code = 'UFMT'
|
||||
include_patterns = [
|
||||
'.github/**/*.py',
|
||||
'test/run_test.py',
|
||||
'test/onnx/**/*.py',
|
||||
'test/test_dynamo_cudagraphs.py',
|
||||
'tools/**/*.py',
|
||||
|
494
test/run_test.py
494
test/run_test.py
@ -2,9 +2,9 @@
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
from datetime import datetime
|
||||
from distutils.version import LooseVersion
|
||||
import functools
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import shutil
|
||||
@ -12,23 +12,23 @@ import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import json
|
||||
import glob
|
||||
from typing import Dict, Optional, List, cast, Any
|
||||
from datetime import datetime
|
||||
from distutils.version import LooseVersion
|
||||
from typing import Any, cast, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.utils import cpp_extension
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_CI,
|
||||
FILE_SCHEMA,
|
||||
TEST_WITH_ROCM,
|
||||
shell,
|
||||
set_cwd,
|
||||
parser as common_parser,
|
||||
is_slow_gradcheck_env,
|
||||
)
|
||||
import torch.distributed as dist
|
||||
from torch.multiprocessing import current_process, get_context
|
||||
from torch.testing._internal.common_utils import (
|
||||
FILE_SCHEMA,
|
||||
IS_CI,
|
||||
is_slow_gradcheck_env,
|
||||
parser as common_parser,
|
||||
set_cwd,
|
||||
shell,
|
||||
TEST_WITH_ROCM,
|
||||
)
|
||||
from torch.utils import cpp_extension
|
||||
|
||||
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
|
||||
|
||||
@ -37,11 +37,12 @@ try:
|
||||
sys.path.append(str(REPO_ROOT))
|
||||
from tools.stats.export_test_times import TEST_TIMES_FILE
|
||||
from tools.testing.test_selections import (
|
||||
calculate_shards,
|
||||
get_reordered_tests,
|
||||
get_test_case_configs,
|
||||
calculate_shards,
|
||||
NUM_PROCS
|
||||
NUM_PROCS,
|
||||
)
|
||||
|
||||
HAVE_TEST_SELECTION_TOOLS = True
|
||||
except ImportError:
|
||||
HAVE_TEST_SELECTION_TOOLS = False
|
||||
@ -65,9 +66,9 @@ def maybe_set_hip_visible_devies():
|
||||
# Special handling of ROCm GHA runners for parallel (file granularity) tests.
|
||||
if torch.version.hip:
|
||||
p = current_process()
|
||||
if p.name != 'MainProcess':
|
||||
if p.name != "MainProcess":
|
||||
# this is a Process from a parallel Pool, not the MainProcess
|
||||
os.environ['HIP_VISIBLE_DEVICES'] = str(p._identity[0] % NUM_PROCS)
|
||||
os.environ["HIP_VISIBLE_DEVICES"] = str(p._identity[0] % NUM_PROCS)
|
||||
|
||||
|
||||
def strtobool(s):
|
||||
@ -77,13 +78,15 @@ def strtobool(s):
|
||||
|
||||
|
||||
def discover_tests(
|
||||
base_dir: Optional[pathlib.Path] = None,
|
||||
blocklisted_patterns: Optional[List[str]] = None,
|
||||
blocklisted_tests: Optional[List[str]] = None,
|
||||
extra_tests: Optional[List[str]] = None) -> List[str]:
|
||||
base_dir: Optional[pathlib.Path] = None,
|
||||
blocklisted_patterns: Optional[List[str]] = None,
|
||||
blocklisted_tests: Optional[List[str]] = None,
|
||||
extra_tests: Optional[List[str]] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Searches for all python files starting with test_ excluding one specified by patterns
|
||||
"""
|
||||
|
||||
def skip_test_p(name: str) -> bool:
|
||||
rc = False
|
||||
if blocklisted_patterns is not None:
|
||||
@ -91,13 +94,16 @@ def discover_tests(
|
||||
if blocklisted_tests is not None:
|
||||
rc |= name in blocklisted_tests
|
||||
return rc
|
||||
|
||||
cwd = pathlib.Path(__file__).resolve().parent if base_dir is None else base_dir
|
||||
# This supports symlinks, so we can link domain library tests to PyTorch test directory
|
||||
all_py_files = [pathlib.Path(p) for p in glob.glob(f"{cwd}/**/test_*.py", recursive=True)]
|
||||
all_py_files = [
|
||||
pathlib.Path(p) for p in glob.glob(f"{cwd}/**/test_*.py", recursive=True)
|
||||
]
|
||||
rc = [str(fname.relative_to(cwd))[:-3] for fname in all_py_files]
|
||||
# Invert slashes on Windows
|
||||
if sys.platform == "win32":
|
||||
rc = [name.replace('\\', '/') for name in rc]
|
||||
rc = [name.replace("\\", "/") for name in rc]
|
||||
rc = [test for test in rc if not skip_test_p(test)]
|
||||
if extra_tests is not None:
|
||||
rc += extra_tests
|
||||
@ -106,31 +112,31 @@ def discover_tests(
|
||||
|
||||
TESTS = discover_tests(
|
||||
blocklisted_patterns=[
|
||||
'ao',
|
||||
'bottleneck_test',
|
||||
'custom_backend',
|
||||
'custom_operator',
|
||||
'fx', # executed by test_fx.py
|
||||
'jit', # executed by test_jit.py
|
||||
'mobile',
|
||||
'onnx',
|
||||
'package', # executed by test_package.py
|
||||
'quantization', # executed by test_quantization.py
|
||||
'autograd', # executed by test_autograd.py
|
||||
"ao",
|
||||
"bottleneck_test",
|
||||
"custom_backend",
|
||||
"custom_operator",
|
||||
"fx", # executed by test_fx.py
|
||||
"jit", # executed by test_jit.py
|
||||
"mobile",
|
||||
"onnx",
|
||||
"package", # executed by test_package.py
|
||||
"quantization", # executed by test_quantization.py
|
||||
"autograd", # executed by test_autograd.py
|
||||
],
|
||||
blocklisted_tests=[
|
||||
'test_bundled_images',
|
||||
'test_cpp_extensions_aot',
|
||||
'test_determination',
|
||||
'test_jit_fuser',
|
||||
'test_jit_simple',
|
||||
'test_jit_string',
|
||||
'test_kernel_launch_checks',
|
||||
'test_nnapi',
|
||||
'test_segment_reductions',
|
||||
'test_static_runtime',
|
||||
'test_throughput_benchmark',
|
||||
'test_typing',
|
||||
"test_bundled_images",
|
||||
"test_cpp_extensions_aot",
|
||||
"test_determination",
|
||||
"test_jit_fuser",
|
||||
"test_jit_simple",
|
||||
"test_jit_string",
|
||||
"test_kernel_launch_checks",
|
||||
"test_nnapi",
|
||||
"test_segment_reductions",
|
||||
"test_static_runtime",
|
||||
"test_throughput_benchmark",
|
||||
"test_typing",
|
||||
"distributed/bin/test_script",
|
||||
"distributed/elastic/multiprocessing/bin/test_script",
|
||||
"distributed/launcher/bin/test_script",
|
||||
@ -138,8 +144,8 @@ TESTS = discover_tests(
|
||||
"distributed/launcher/bin/test_script_is_torchelastic_launched",
|
||||
"distributed/launcher/bin/test_script_local_rank",
|
||||
"distributed/test_c10d_spawn",
|
||||
'distributions/test_transforms',
|
||||
'distributions/test_utils',
|
||||
"distributions/test_transforms",
|
||||
"distributions/test_utils",
|
||||
],
|
||||
extra_tests=[
|
||||
"test_cpp_extensions_aot_ninja",
|
||||
@ -153,12 +159,12 @@ TESTS = discover_tests(
|
||||
"distributed/elastic/utils/util_test",
|
||||
"distributed/elastic/utils/distributed_test",
|
||||
"distributed/elastic/multiprocessing/api_test",
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
# The doctests are a special case that don't correspond to a file that discover
|
||||
# tests can enable.
|
||||
TESTS = TESTS + ['doctests']
|
||||
TESTS = TESTS + ["doctests"]
|
||||
|
||||
FSDP_TEST = [test for test in TESTS if test.startswith("distributed/fsdp")]
|
||||
|
||||
@ -243,34 +249,34 @@ RUN_PARALLEL_BLOCKLIST = [
|
||||
] + FSDP_TEST
|
||||
|
||||
CI_SERIAL_LIST = [
|
||||
'test_nn',
|
||||
'test_fake_tensor',
|
||||
'test_cpp_api_parity',
|
||||
'test_reductions',
|
||||
'test_cuda',
|
||||
'test_jit_cuda_fuser', # OOM on test_issue_1785, also profiling?
|
||||
'test_indexing',
|
||||
'test_fx_backends',
|
||||
'test_linalg',
|
||||
'test_cpp_extensions_jit',
|
||||
'test_torch',
|
||||
'test_tensor_creation_ops',
|
||||
'test_sparse_csr',
|
||||
'test_dispatch',
|
||||
'test_spectral_ops', # Cause CUDA illegal memory access https://github.com/pytorch/pytorch/issues/88916
|
||||
'nn/test_pooling',
|
||||
'nn/test_convolution', # Doesn't respect set_per_process_memory_fraction, results in OOM for other tests in slow gradcheck
|
||||
'distributions/test_distributions',
|
||||
'test_autograd', # slow gradcheck runs a test that checks the cuda memory allocator
|
||||
'test_prims', # slow gradcheck runs a test that checks the cuda memory allocator
|
||||
'test_modules', # failed test due to mismatched elements
|
||||
'functorch/test_vmap', # OOM
|
||||
'test_fx', # gets SIGKILL
|
||||
'test_dataloader', # frequently hangs for ROCm
|
||||
'test_serialization', # test_serialization_2gb_file allocates a tensor of 2GB, and could cause OOM
|
||||
'_nvfuser/test_torchscript', # OOM on test_issue_1785
|
||||
'test_schema_check', # Cause CUDA illegal memory access https://github.com/pytorch/pytorch/issues/95749
|
||||
'functorch/test_memory_efficient_fusion', # Cause CUDA OOM on ROCm
|
||||
"test_nn",
|
||||
"test_fake_tensor",
|
||||
"test_cpp_api_parity",
|
||||
"test_reductions",
|
||||
"test_cuda",
|
||||
"test_jit_cuda_fuser", # OOM on test_issue_1785, also profiling?
|
||||
"test_indexing",
|
||||
"test_fx_backends",
|
||||
"test_linalg",
|
||||
"test_cpp_extensions_jit",
|
||||
"test_torch",
|
||||
"test_tensor_creation_ops",
|
||||
"test_sparse_csr",
|
||||
"test_dispatch",
|
||||
"test_spectral_ops", # Cause CUDA illegal memory access https://github.com/pytorch/pytorch/issues/88916
|
||||
"nn/test_pooling",
|
||||
"nn/test_convolution", # Doesn't respect set_per_process_memory_fraction, results in OOM for other tests in slow gradcheck
|
||||
"distributions/test_distributions",
|
||||
"test_autograd", # slow gradcheck runs a test that checks the cuda memory allocator
|
||||
"test_prims", # slow gradcheck runs a test that checks the cuda memory allocator
|
||||
"test_modules", # failed test due to mismatched elements
|
||||
"functorch/test_vmap", # OOM
|
||||
"test_fx", # gets SIGKILL
|
||||
"test_dataloader", # frequently hangs for ROCm
|
||||
"test_serialization", # test_serialization_2gb_file allocates a tensor of 2GB, and could cause OOM
|
||||
"_nvfuser/test_torchscript", # OOM on test_issue_1785
|
||||
"test_schema_check", # Cause CUDA illegal memory access https://github.com/pytorch/pytorch/issues/95749
|
||||
"functorch/test_memory_efficient_fusion", # Cause CUDA OOM on ROCm
|
||||
]
|
||||
|
||||
# A subset of our TEST list that validates PyTorch's ops, modules, and autograd function as expected
|
||||
@ -282,7 +288,7 @@ CORE_TEST_LIST = [
|
||||
"test_ops_gradients",
|
||||
"test_ops_fwd_gradients",
|
||||
"test_ops_jit",
|
||||
"test_torch"
|
||||
"test_torch",
|
||||
]
|
||||
|
||||
# A list of distributed tests that run on multiple backends, i.e. gloo, nccl. These backends are spread out
|
||||
@ -377,6 +383,7 @@ TESTS_NOT_USING_GRADCHECK = [
|
||||
"test_quantization",
|
||||
]
|
||||
|
||||
|
||||
def print_to_stderr(message):
|
||||
print(message, file=sys.stderr)
|
||||
|
||||
@ -421,8 +428,10 @@ def run_test(
|
||||
unittest_args.extend(ci_args)
|
||||
if test_module in PYTEST_SKIP_RETRIES:
|
||||
if not options.pytest:
|
||||
raise RuntimeError("A test running without pytest cannot skip retries using "
|
||||
"the PYTEST_SKIP_RETRIES set.")
|
||||
raise RuntimeError(
|
||||
"A test running without pytest cannot skip retries using "
|
||||
"the PYTEST_SKIP_RETRIES set."
|
||||
)
|
||||
unittest_args = [arg for arg in unittest_args if "--reruns" not in arg]
|
||||
|
||||
# Extra arguments are not supported with pytest
|
||||
@ -433,9 +442,11 @@ def run_test(
|
||||
argv = [test_module + ".py"] + unittest_args
|
||||
|
||||
os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)
|
||||
log_fd, log_path = tempfile.mkstemp(dir=REPO_ROOT / "test" / "test-reports",
|
||||
prefix="{}_".format(test_module.replace("\\", "-").replace("/", "-")),
|
||||
suffix=".log")
|
||||
log_fd, log_path = tempfile.mkstemp(
|
||||
dir=REPO_ROOT / "test" / "test-reports",
|
||||
prefix="{}_".format(test_module.replace("\\", "-").replace("/", "-")),
|
||||
suffix=".log",
|
||||
)
|
||||
os.close(log_fd)
|
||||
command = (launcher_cmd or []) + executable + argv
|
||||
print_to_stderr("Executing {} ... [{}]".format(command, datetime.now()))
|
||||
@ -452,11 +463,15 @@ def test_cuda_primary_ctx(test_module, test_directory, options):
|
||||
)
|
||||
|
||||
|
||||
run_test_with_subprocess = functools.partial(run_test, extra_unittest_args=["--subprocess"])
|
||||
run_test_with_subprocess = functools.partial(
|
||||
run_test, extra_unittest_args=["--subprocess"]
|
||||
)
|
||||
|
||||
|
||||
def get_run_test_with_subprocess_fn():
|
||||
return lambda test_module, test_directory, options: run_test_with_subprocess(test_module, test_directory, options)
|
||||
return lambda test_module, test_directory, options: run_test_with_subprocess(
|
||||
test_module, test_directory, options
|
||||
)
|
||||
|
||||
|
||||
def _test_cpp_extensions_aot(test_directory, options, use_ninja):
|
||||
@ -493,7 +508,7 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja):
|
||||
python_path = os.environ.get("PYTHONPATH", "")
|
||||
from shutil import copyfile
|
||||
|
||||
os.environ['USE_NINJA'] = shell_env['USE_NINJA']
|
||||
os.environ["USE_NINJA"] = shell_env["USE_NINJA"]
|
||||
test_module = "test_cpp_extensions_aot" + ("_ninja" if use_ninja else "_no_ninja")
|
||||
copyfile(
|
||||
test_directory + "/test_cpp_extensions_aot.py",
|
||||
@ -515,7 +530,7 @@ def _test_cpp_extensions_aot(test_directory, options, use_ninja):
|
||||
os.environ["PYTHONPATH"] = python_path
|
||||
if os.path.exists(test_directory + "/" + test_module + ".py"):
|
||||
os.remove(test_directory + "/" + test_module + ".py")
|
||||
os.environ.pop('USE_NINJA')
|
||||
os.environ.pop("USE_NINJA")
|
||||
|
||||
|
||||
def test_cpp_extensions_aot_ninja(test_module, test_directory, options):
|
||||
@ -539,9 +554,15 @@ def test_distributed(test_module, test_directory, options):
|
||||
else:
|
||||
which_shard = num_shards = 1
|
||||
# Round-robin all backends to different shards
|
||||
backend_to_shard = {backend: i % num_shards + 1
|
||||
for i, backend in enumerate(DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS[test_module])}
|
||||
print_to_stderr(f"Map different backends to different shards for {test_module}: {backend_to_shard}")
|
||||
backend_to_shard = {
|
||||
backend: i % num_shards + 1
|
||||
for i, backend in enumerate(
|
||||
DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS[test_module]
|
||||
)
|
||||
}
|
||||
print_to_stderr(
|
||||
f"Map different backends to different shards for {test_module}: {backend_to_shard}"
|
||||
)
|
||||
|
||||
config = DISTRIBUTED_TESTS_CONFIG
|
||||
for backend, env_vars in config.items():
|
||||
@ -551,7 +572,9 @@ def test_distributed(test_module, test_directory, options):
|
||||
continue
|
||||
# Default to the first shard if seeing an unrecognized backend
|
||||
if which_shard != backend_to_shard.get(backend, 1):
|
||||
print_to_stderr(f"Shard {which_shard}: {backend} should be run in {backend_to_shard.get(backend, 1)}")
|
||||
print_to_stderr(
|
||||
f"Shard {which_shard}: {backend} should be run in {backend_to_shard.get(backend, 1)}"
|
||||
)
|
||||
continue
|
||||
for with_init_file in {True, False}:
|
||||
if sys.platform == "win32" and not with_init_file:
|
||||
@ -611,7 +634,12 @@ def test_distributed(test_module, test_directory, options):
|
||||
test_module, test_directory, options, launcher_cmd=mpiexec
|
||||
)
|
||||
else:
|
||||
return_code = run_test(test_module, test_directory, options, extra_unittest_args=["--subprocess"])
|
||||
return_code = run_test(
|
||||
test_module,
|
||||
test_directory,
|
||||
options,
|
||||
extra_unittest_args=["--subprocess"],
|
||||
)
|
||||
if return_code != 0:
|
||||
return return_code
|
||||
finally:
|
||||
@ -626,8 +654,10 @@ def run_doctests(test_module, test_directory, options):
|
||||
Assumes the incoming test module is called doctest, and simply executes the
|
||||
xdoctest runner on the torch library itself.
|
||||
"""
|
||||
import xdoctest
|
||||
import pathlib
|
||||
|
||||
import xdoctest
|
||||
|
||||
pkgpath = pathlib.Path(torch.__file__).parent
|
||||
|
||||
exclude_module_list = []
|
||||
@ -638,42 +668,47 @@ def run_doctests(test_module, test_directory, options):
|
||||
# 'cuda': 'auto',
|
||||
# 'cuda1': 'auto',
|
||||
# 'qengine': 'auto',
|
||||
'lapack': 0,
|
||||
'cuda': 0,
|
||||
'cuda1': 0,
|
||||
'qengine': 0,
|
||||
'autograd_profiler': 0,
|
||||
'cpp_ext': 0,
|
||||
'monitor': 0,
|
||||
"lapack": 0,
|
||||
"cuda": 0,
|
||||
"cuda1": 0,
|
||||
"qengine": 0,
|
||||
"autograd_profiler": 0,
|
||||
"cpp_ext": 0,
|
||||
"monitor": 0,
|
||||
"onnx": "auto",
|
||||
}
|
||||
|
||||
# Resolve "auto" based on a test to determine if the feature is available.
|
||||
if enabled['cuda'] == 'auto' and torch.cuda.is_available():
|
||||
enabled['cuda'] = True
|
||||
if enabled["cuda"] == "auto" and torch.cuda.is_available():
|
||||
enabled["cuda"] = True
|
||||
|
||||
if enabled['cuda1'] == 'auto' and torch.cuda.is_available() and torch.cuda.device_count() > 1:
|
||||
enabled['cuda1'] = True
|
||||
if (
|
||||
enabled["cuda1"] == "auto"
|
||||
and torch.cuda.is_available()
|
||||
and torch.cuda.device_count() > 1
|
||||
):
|
||||
enabled["cuda1"] = True
|
||||
|
||||
if enabled['lapack'] == 'auto' and torch._C.has_lapack:
|
||||
enabled['lapack'] = True
|
||||
if enabled["lapack"] == "auto" and torch._C.has_lapack:
|
||||
enabled["lapack"] = True
|
||||
|
||||
if enabled['qengine'] == 'auto':
|
||||
if enabled["qengine"] == "auto":
|
||||
try:
|
||||
# Is there a better check if quantization is enabled?
|
||||
import torch.ao.nn.quantized as nnq # NOQA
|
||||
torch.backends.quantized.engine = 'qnnpack'
|
||||
torch.backends.quantized.engine = 'fbgemm'
|
||||
|
||||
torch.backends.quantized.engine = "qnnpack"
|
||||
torch.backends.quantized.engine = "fbgemm"
|
||||
except (ImportError, RuntimeError):
|
||||
...
|
||||
else:
|
||||
enabled['qengine'] = True
|
||||
enabled["qengine"] = True
|
||||
|
||||
if enabled["onnx"] == "auto":
|
||||
try:
|
||||
import onnx # NOQA
|
||||
import onnxscript # NOQA
|
||||
import onnxruntime # NOQA
|
||||
import onnxscript # NOQA
|
||||
except ImportError:
|
||||
exclude_module_list.append("torch.onnx._internal.fx.*")
|
||||
enabled["onnx"] = False
|
||||
@ -681,69 +716,79 @@ def run_doctests(test_module, test_directory, options):
|
||||
enabled["onnx"] = True
|
||||
|
||||
# Set doctest environment variables
|
||||
if enabled['cuda']:
|
||||
os.environ['TORCH_DOCTEST_CUDA'] = '1'
|
||||
if enabled["cuda"]:
|
||||
os.environ["TORCH_DOCTEST_CUDA"] = "1"
|
||||
|
||||
if enabled['cuda1']:
|
||||
os.environ['TORCH_DOCTEST_CUDA1'] = '1'
|
||||
if enabled["cuda1"]:
|
||||
os.environ["TORCH_DOCTEST_CUDA1"] = "1"
|
||||
|
||||
if enabled['lapack']:
|
||||
os.environ['TORCH_DOCTEST_LAPACK'] = '1'
|
||||
if enabled["lapack"]:
|
||||
os.environ["TORCH_DOCTEST_LAPACK"] = "1"
|
||||
|
||||
if enabled['qengine']:
|
||||
os.environ['TORCH_DOCTEST_QENGINE'] = '1'
|
||||
if enabled["qengine"]:
|
||||
os.environ["TORCH_DOCTEST_QENGINE"] = "1"
|
||||
|
||||
if enabled['autograd_profiler']:
|
||||
os.environ['TORCH_DOCTEST_AUTOGRAD_PROFILER'] = '1'
|
||||
if enabled["autograd_profiler"]:
|
||||
os.environ["TORCH_DOCTEST_AUTOGRAD_PROFILER"] = "1"
|
||||
|
||||
if enabled['cpp_ext']:
|
||||
os.environ['TORCH_DOCTEST_CPP_EXT'] = '1'
|
||||
if enabled["cpp_ext"]:
|
||||
os.environ["TORCH_DOCTEST_CPP_EXT"] = "1"
|
||||
|
||||
if enabled['monitor']:
|
||||
os.environ['TORCH_DOCTEST_MONITOR'] = '1'
|
||||
if enabled["monitor"]:
|
||||
os.environ["TORCH_DOCTEST_MONITOR"] = "1"
|
||||
|
||||
if enabled["onnx"]:
|
||||
os.environ['TORCH_DOCTEST_ONNX'] = '1'
|
||||
os.environ["TORCH_DOCTEST_ONNX"] = "1"
|
||||
|
||||
if 0:
|
||||
# TODO: could try to enable some of these
|
||||
os.environ['TORCH_DOCTEST_QUANTIZED_DYNAMIC'] = '1'
|
||||
os.environ['TORCH_DOCTEST_ANOMOLY'] = '1'
|
||||
os.environ['TORCH_DOCTEST_AUTOGRAD'] = '1'
|
||||
os.environ['TORCH_DOCTEST_HUB'] = '1'
|
||||
os.environ['TORCH_DOCTEST_DATALOADER'] = '1'
|
||||
os.environ['TORCH_DOCTEST_FUTURES'] = '1'
|
||||
os.environ["TORCH_DOCTEST_QUANTIZED_DYNAMIC"] = "1"
|
||||
os.environ["TORCH_DOCTEST_ANOMOLY"] = "1"
|
||||
os.environ["TORCH_DOCTEST_AUTOGRAD"] = "1"
|
||||
os.environ["TORCH_DOCTEST_HUB"] = "1"
|
||||
os.environ["TORCH_DOCTEST_DATALOADER"] = "1"
|
||||
os.environ["TORCH_DOCTEST_FUTURES"] = "1"
|
||||
|
||||
pkgpath = os.path.dirname(torch.__file__)
|
||||
|
||||
xdoctest_config = {
|
||||
'global_exec': r'\n'.join([
|
||||
'from torch import nn',
|
||||
'import torch.nn.functional as F',
|
||||
'import torch',
|
||||
]),
|
||||
'analysis': 'static', # set to "auto" to test doctests in compiled modules
|
||||
'style': 'google',
|
||||
'options': '+IGNORE_WHITESPACE',
|
||||
"global_exec": r"\n".join(
|
||||
[
|
||||
"from torch import nn",
|
||||
"import torch.nn.functional as F",
|
||||
"import torch",
|
||||
]
|
||||
),
|
||||
"analysis": "static", # set to "auto" to test doctests in compiled modules
|
||||
"style": "google",
|
||||
"options": "+IGNORE_WHITESPACE",
|
||||
}
|
||||
xdoctest_verbose = max(1, options.verbose)
|
||||
run_summary = xdoctest.runner.doctest_module(
|
||||
os.fspath(pkgpath), config=xdoctest_config, verbose=xdoctest_verbose,
|
||||
command=options.xdoctest_command, argv=[],
|
||||
exclude=exclude_module_list)
|
||||
result = 1 if run_summary.get('n_failed', 0) else 0
|
||||
os.fspath(pkgpath),
|
||||
config=xdoctest_config,
|
||||
verbose=xdoctest_verbose,
|
||||
command=options.xdoctest_command,
|
||||
argv=[],
|
||||
exclude=exclude_module_list,
|
||||
)
|
||||
result = 1 if run_summary.get("n_failed", 0) else 0
|
||||
return result
|
||||
|
||||
|
||||
def print_log_file(test: str, file_path: str, failed: bool) -> None:
|
||||
num_lines = sum(1 for _ in open(file_path, 'rb'))
|
||||
num_lines = sum(1 for _ in open(file_path, "rb"))
|
||||
n = 100
|
||||
with open(file_path, "r") as f:
|
||||
print_to_stderr("")
|
||||
if failed:
|
||||
if n < num_lines:
|
||||
print_to_stderr(f"Expand the folded group to see the beginning of the log file of {test}")
|
||||
print_to_stderr(f"##[group]PRINTING BEGINNING OF LOG FILE of {test} ({file_path})")
|
||||
print_to_stderr(
|
||||
f"Expand the folded group to see the beginning of the log file of {test}"
|
||||
)
|
||||
print_to_stderr(
|
||||
f"##[group]PRINTING BEGINNING OF LOG FILE of {test} ({file_path})"
|
||||
)
|
||||
for _ in range(num_lines - n):
|
||||
print_to_stderr(next(f).rstrip())
|
||||
print_to_stderr("##[endgroup]")
|
||||
@ -776,11 +821,13 @@ def get_pytest_args(options):
|
||||
"--use-pytest",
|
||||
"-vv",
|
||||
"-rfEX",
|
||||
"-p", "no:xdist",
|
||||
"-p",
|
||||
"no:xdist",
|
||||
]
|
||||
pytest_args.extend(rerun_options)
|
||||
return pytest_args
|
||||
|
||||
|
||||
def run_test_ops(test_module, test_directory, options):
|
||||
default_unittest_args = get_pytest_args(options)
|
||||
|
||||
@ -789,11 +836,13 @@ def run_test_ops(test_module, test_directory, options):
|
||||
pool = get_context("spawn").Pool(NUM_PROCS)
|
||||
for i in range(NUM_PROCS):
|
||||
extra_unittest_args = default_unittest_args.copy()
|
||||
extra_unittest_args.extend([
|
||||
f"--shard-id={i}",
|
||||
f"--num-shards={NUM_PROCS}",
|
||||
"-k=not _linalg_cholesky_",
|
||||
])
|
||||
extra_unittest_args.extend(
|
||||
[
|
||||
f"--shard-id={i}",
|
||||
f"--num-shards={NUM_PROCS}",
|
||||
"-k=not _linalg_cholesky_",
|
||||
]
|
||||
)
|
||||
|
||||
return_code = pool.apply_async(
|
||||
run_test,
|
||||
@ -813,9 +862,11 @@ def run_test_ops(test_module, test_directory, options):
|
||||
return return_code.get()
|
||||
|
||||
extra_unittest_args = default_unittest_args.copy()
|
||||
extra_unittest_args.extend([
|
||||
"-k=_linalg_cholesky_",
|
||||
])
|
||||
extra_unittest_args.extend(
|
||||
[
|
||||
"-k=_linalg_cholesky_",
|
||||
]
|
||||
)
|
||||
|
||||
return_code = run_test(
|
||||
test_module,
|
||||
@ -859,9 +910,7 @@ CUSTOM_HANDLERS = {
|
||||
}
|
||||
|
||||
|
||||
PYTEST_SKIP_RETRIES = {
|
||||
'test_public_bindings'
|
||||
}
|
||||
PYTEST_SKIP_RETRIES = {"test_public_bindings"}
|
||||
|
||||
|
||||
def parse_test_module(test):
|
||||
@ -881,7 +930,7 @@ def parse_args():
|
||||
description="Run the PyTorch unit test suite",
|
||||
epilog="where TESTS is any of: {}".format(", ".join(TESTS)),
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
parents=[common_parser]
|
||||
parents=[common_parser],
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
@ -905,22 +954,20 @@ def parse_args():
|
||||
"If this flag is present, we will only run functorch tests. "
|
||||
"If this flag is not present, we will run all tests "
|
||||
"(including functorch tests)."
|
||||
)
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mps",
|
||||
"--mps",
|
||||
action="store_true",
|
||||
help=(
|
||||
"If this flag is present, we will only run test_mps and test_metal"
|
||||
)
|
||||
help=("If this flag is present, we will only run test_mps and test_metal"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"-core",
|
||||
"--core",
|
||||
action="store_true",
|
||||
help="Only run core tests, or tests that validate PyTorch's ops, modules,"
|
||||
"and autograd. They are defined by CORE_TEST_LIST."
|
||||
"and autograd. They are defined by CORE_TEST_LIST.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-pt",
|
||||
@ -1028,12 +1075,13 @@ def parse_args():
|
||||
)
|
||||
parser.add_argument(
|
||||
"--xdoctest-command",
|
||||
default='all',
|
||||
default="all",
|
||||
help=(
|
||||
"Control the specific doctest action. "
|
||||
"Use 'list' to simply parse doctests and check syntax. "
|
||||
"Use 'all' to execute all doctests or specify a specific "
|
||||
"doctest to run")
|
||||
"doctest to run"
|
||||
),
|
||||
)
|
||||
|
||||
group = parser.add_mutually_exclusive_group()
|
||||
@ -1088,11 +1136,15 @@ def find_test_index(test, selected_tests, find_last_index=False):
|
||||
return found_idx
|
||||
|
||||
|
||||
def exclude_tests(exclude_list, selected_tests, exclude_message=None, exact_match=False):
|
||||
def exclude_tests(
|
||||
exclude_list, selected_tests, exclude_message=None, exact_match=False
|
||||
):
|
||||
for exclude_test in exclude_list:
|
||||
tests_copy = selected_tests[:]
|
||||
for test in tests_copy:
|
||||
if (not exact_match and test.startswith(exclude_test)) or test == exclude_test:
|
||||
if (
|
||||
not exact_match and test.startswith(exclude_test)
|
||||
) or test == exclude_test:
|
||||
if exclude_message is not None:
|
||||
print_to_stderr("Excluding {} {}".format(test, exclude_message))
|
||||
selected_tests.remove(test)
|
||||
@ -1101,19 +1153,19 @@ def exclude_tests(exclude_list, selected_tests, exclude_message=None, exact_matc
|
||||
|
||||
def must_serial(file: str) -> bool:
|
||||
return (
|
||||
os.getenv("PYTORCH_TEST_RUN_EVERYTHING_IN_SERIAL", "0") == "1" or
|
||||
"distributed" in os.getenv("TEST_CONFIG", "") or
|
||||
"dynamo" in os.getenv("TEST_CONFIG", "") or
|
||||
"distributed" in file or
|
||||
file in CUSTOM_HANDLERS or
|
||||
file in RUN_PARALLEL_BLOCKLIST or
|
||||
file in CI_SERIAL_LIST or
|
||||
file in JIT_EXECUTOR_TESTS
|
||||
os.getenv("PYTORCH_TEST_RUN_EVERYTHING_IN_SERIAL", "0") == "1"
|
||||
or "distributed" in os.getenv("TEST_CONFIG", "")
|
||||
or "dynamo" in os.getenv("TEST_CONFIG", "")
|
||||
or "distributed" in file
|
||||
or file in CUSTOM_HANDLERS
|
||||
or file in RUN_PARALLEL_BLOCKLIST
|
||||
or file in CI_SERIAL_LIST
|
||||
or file in JIT_EXECUTOR_TESTS
|
||||
)
|
||||
|
||||
|
||||
def can_run_in_pytest(test):
|
||||
return os.getenv('PYTORCH_TEST_DO_NOT_USE_PYTEST', '0') == '0'
|
||||
return os.getenv("PYTORCH_TEST_DO_NOT_USE_PYTEST", "0") == "0"
|
||||
|
||||
|
||||
def get_selected_tests(options):
|
||||
@ -1127,9 +1179,13 @@ def get_selected_tests(options):
|
||||
|
||||
if options.distributed_tests:
|
||||
selected_tests = list(
|
||||
filter(lambda test_name: (test_name in DISTRIBUTED_TESTS and
|
||||
test_name not in DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS),
|
||||
selected_tests)
|
||||
filter(
|
||||
lambda test_name: (
|
||||
test_name in DISTRIBUTED_TESTS
|
||||
and test_name not in DISTRIBUTED_TESTS_WITH_MULTIPLE_BACKENDS
|
||||
),
|
||||
selected_tests,
|
||||
)
|
||||
)
|
||||
|
||||
# Filter to only run core tests when --core option is specified
|
||||
@ -1143,10 +1199,10 @@ def get_selected_tests(options):
|
||||
selected_tests = [tname for tname in selected_tests if tname in FUNCTORCH_TESTS]
|
||||
|
||||
if options.mps:
|
||||
selected_tests = ['test_mps', 'test_metal']
|
||||
selected_tests = ["test_mps", "test_metal"]
|
||||
else:
|
||||
# Exclude all mps tests otherwise
|
||||
options.exclude.extend(['test_mps', 'test_metal'])
|
||||
options.exclude.extend(["test_mps", "test_metal"])
|
||||
|
||||
# process reordering
|
||||
if options.bring_to_front:
|
||||
@ -1221,25 +1277,39 @@ def get_selected_tests(options):
|
||||
else:
|
||||
print("Found test time stats from artifacts")
|
||||
test_file_times_config = test_file_times[test_config]
|
||||
shards = calculate_shards(num_shards, selected_tests, test_file_times_config,
|
||||
must_serial=must_serial)
|
||||
shards = calculate_shards(
|
||||
num_shards,
|
||||
selected_tests,
|
||||
test_file_times_config,
|
||||
must_serial=must_serial,
|
||||
)
|
||||
_, tests_from_shard = shards[which_shard - 1]
|
||||
selected_tests = tests_from_shard
|
||||
|
||||
# skip all distributed tests if distributed package is not available.
|
||||
if not dist.is_available():
|
||||
selected_tests = exclude_tests(DISTRIBUTED_TESTS, selected_tests,
|
||||
"PyTorch is built without distributed support.")
|
||||
selected_tests = exclude_tests(
|
||||
DISTRIBUTED_TESTS,
|
||||
selected_tests,
|
||||
"PyTorch is built without distributed support.",
|
||||
)
|
||||
|
||||
# skip tests that require LAPACK when it's not available
|
||||
if not torch._C.has_lapack:
|
||||
selected_tests = exclude_tests(TESTS_REQUIRING_LAPACK, selected_tests,
|
||||
"PyTorch is built without LAPACK support.")
|
||||
selected_tests = exclude_tests(
|
||||
TESTS_REQUIRING_LAPACK,
|
||||
selected_tests,
|
||||
"PyTorch is built without LAPACK support.",
|
||||
)
|
||||
|
||||
if is_slow_gradcheck_env():
|
||||
selected_tests = exclude_tests(TESTS_NOT_USING_GRADCHECK, selected_tests,
|
||||
"Running in slow gradcheck mode, skipping tests "
|
||||
"that don't use gradcheck.", exact_match=True)
|
||||
selected_tests = exclude_tests(
|
||||
TESTS_NOT_USING_GRADCHECK,
|
||||
selected_tests,
|
||||
"Running in slow gradcheck mode, skipping tests "
|
||||
"that don't use gradcheck.",
|
||||
exact_match=True,
|
||||
)
|
||||
|
||||
if options.distributed_tests:
|
||||
# Run distributed tests with multiple backends across all shards, one per backend
|
||||
@ -1303,12 +1373,24 @@ def main():
|
||||
# parallel = in parallel with other files
|
||||
# serial = this file on it's own. The file might still be run in parallel with itself (ex test_ops)
|
||||
selected_tests_parallel = [x for x in selected_tests if not must_serial(x)]
|
||||
selected_tests_serial = [x for x in selected_tests if x not in selected_tests_parallel]
|
||||
print_to_stderr("parallel (file granularity) tests:\n {}".format("\n ".join(selected_tests_parallel)))
|
||||
print_to_stderr("serial (file granularity) tests:\n {}".format("\n ".join(selected_tests_serial)))
|
||||
selected_tests_serial = [
|
||||
x for x in selected_tests if x not in selected_tests_parallel
|
||||
]
|
||||
print_to_stderr(
|
||||
"parallel (file granularity) tests:\n {}".format(
|
||||
"\n ".join(selected_tests_parallel)
|
||||
)
|
||||
)
|
||||
print_to_stderr(
|
||||
"serial (file granularity) tests:\n {}".format(
|
||||
"\n ".join(selected_tests_serial)
|
||||
)
|
||||
)
|
||||
|
||||
# See Note [ROCm parallel CI testing]
|
||||
pool = get_context("spawn").Pool(NUM_PROCS, maxtasksperchild=None if torch.version.hip else 1)
|
||||
pool = get_context("spawn").Pool(
|
||||
NUM_PROCS, maxtasksperchild=None if torch.version.hip else 1
|
||||
)
|
||||
os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)
|
||||
|
||||
def success_callback(err_message):
|
||||
@ -1321,20 +1403,24 @@ def main():
|
||||
return False
|
||||
|
||||
try:
|
||||
os.environ['PARALLEL_TESTING'] = '1'
|
||||
os.environ["PARALLEL_TESTING"] = "1"
|
||||
for test in selected_tests_parallel:
|
||||
options_clone = copy.deepcopy(options)
|
||||
if can_run_in_pytest(test):
|
||||
options_clone.pytest = True
|
||||
pool.apply_async(run_test_module, args=(test, test_directory, options_clone), callback=success_callback)
|
||||
pool.apply_async(
|
||||
run_test_module,
|
||||
args=(test, test_directory, options_clone),
|
||||
callback=success_callback,
|
||||
)
|
||||
pool.close()
|
||||
pool.join()
|
||||
del os.environ['PARALLEL_TESTING']
|
||||
del os.environ["PARALLEL_TESTING"]
|
||||
|
||||
if not options.continue_through_error and len(failure_messages) != 0:
|
||||
raise RuntimeError(
|
||||
"\n".join(failure_messages) +
|
||||
"\n\nTip: You can keep running tests even on failure by "
|
||||
"\n".join(failure_messages)
|
||||
+ "\n\nTip: You can keep running tests even on failure by "
|
||||
"passing --keep-going to run_test.py.\n"
|
||||
"If running on CI, add the 'keep-going' label to "
|
||||
"your PR and rerun your jobs."
|
||||
|
Reference in New Issue
Block a user