mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-31 04:04:57 +08:00 
			
		
		
		
	This PR removes unnecessary `pass` statements. This is semantically safe because the bytecode for the Python code does not change. Note that if a function has a docstring, an otherwise empty function does not need a `pass` statement as a placeholder. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133200 Approved by: https://github.com/malfet, https://github.com/eqy, https://github.com/kit1980
		
			
				
	
	
		
			453 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			453 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python3
 | |
| 
 | |
| import os
 | |
| import re
 | |
| import tempfile
 | |
| from collections import defaultdict
 | |
| from datetime import datetime
 | |
| from functools import wraps
 | |
| from typing import (
 | |
|     Any,
 | |
|     Callable,
 | |
|     cast,
 | |
|     Dict,
 | |
|     Iterator,
 | |
|     List,
 | |
|     Optional,
 | |
|     Tuple,
 | |
|     TypeVar,
 | |
|     Union,
 | |
| )
 | |
| 
 | |
| 
 | |
# Type variable for the return type of callables wrapped by retries_decorator.
T = TypeVar("T")

# Recognises HTTPS GitHub remote URLs (optionally with credentials before the host)
# and captures the two path components after the host. Note: the "." in
# "github.com" is an unescaped regex wildcard, and a trailing ".git" stays part of
# the second captured group.
RE_GITHUB_URL_MATCH = re.compile(r"^https://.*@?github.com/(.+)/(.+)$")
 | |
| 
 | |
| 
 | |
def get_git_remote_name() -> str:
    """Return the git remote name from ``GIT_REMOTE_NAME``, falling back to ``origin``."""
    configured = os.environ.get("GIT_REMOTE_NAME")
    if configured is None:
        return "origin"
    return configured
 | |
| 
 | |
| 
 | |
def get_git_repo_dir() -> str:
    """Return the repository directory these helpers operate on.

    ``GIT_REPO_DIR`` takes precedence when set. Otherwise the result is the directory
    reached by taking ``.parent`` three times from this file's resolved path
    (presumably the repository root — confirm if this file is moved).
    """
    from pathlib import Path

    override = os.environ.get("GIT_REPO_DIR")
    if override is not None:
        return override
    this_file = Path(__file__).resolve()
    return str(this_file.parent.parent.parent)
 | |
| 
 | |
| 
 | |
def fuzzy_list_to_dict(items: List[Tuple[str, str]]) -> Dict[str, List[str]]:
    """
    Group ``(key, value)`` pairs into a plain dict of key -> list of values.

    Duplicate keys are preserved by collecting every value, in input order.
    """
    grouped: Dict[str, List[str]] = {}
    for key, value in items:
        grouped.setdefault(key, []).append(value)
    return grouped
 | |
| 
 | |
| 
 | |
| def _check_output(items: List[str], encoding: str = "utf-8") -> str:
 | |
|     from subprocess import CalledProcessError, check_output, STDOUT
 | |
| 
 | |
|     try:
 | |
|         return check_output(items, stderr=STDOUT).decode(encoding)
 | |
|     except CalledProcessError as e:
 | |
|         msg = f"Command `{' '.join(e.cmd)}` returned non-zero exit code {e.returncode}"
 | |
|         stdout = e.stdout.decode(encoding) if e.stdout is not None else ""
 | |
|         stderr = e.stderr.decode(encoding) if e.stderr is not None else ""
 | |
|         # These get swallowed up, so print them here for debugging
 | |
|         print(f"stdout: \n{stdout}")
 | |
|         print(f"stderr: \n{stderr}")
 | |
|         if len(stderr) == 0:
 | |
|             msg += f"\n```\n{stdout}```"
 | |
|         else:
 | |
|             msg += f"\nstdout:\n```\n{stdout}```\nstderr:\n```\n{stderr}```"
 | |
|         raise RuntimeError(msg) from e
 | |
| 
 | |
| 
 | |
class GitCommit:
    """Metadata and message of a single git commit.

    ``item in commit`` performs a substring search over the body and the title.
    """

    commit_hash: str
    title: str
    body: str
    author: str
    author_date: datetime
    commit_date: Optional[datetime]

    def __init__(
        self,
        commit_hash: str,
        author: str,
        author_date: datetime,
        title: str,
        body: str,
        commit_date: Optional[datetime] = None,
    ) -> None:
        # Identity and authorship of the commit.
        self.commit_hash, self.author, self.author_date = commit_hash, author, author_date
        # Committer timestamp; stays None when the caller does not supply one.
        self.commit_date = commit_date
        # Commit message: its title line and the remaining text.
        self.title, self.body = title, body

    def __repr__(self) -> str:
        return "{} ({})".format(self.title, self.commit_hash)

    def __contains__(self, item: Any) -> bool:
        # The body is searched first; the title only if the body has no match.
        return any(item in text for text in (self.body, self.title))
 | |
| 
 | |
| 
 | |
def parse_fuller_format(lines: Union[str, List[str]]) -> GitCommit:
    """
    Parse a commit printed with `--format=fuller --date=unix`, i.e.:
        commit <sha1>
        Author:     <author>
        AuthorDate: <author date>
        Commit:     <committer>
        CommitDate: <committer date>

        <title line>

        <full commit message>

    ``lines`` is either the raw text or that text already split on newlines. A list
    argument is copied and never modified. A ``Merge:`` line directly after the
    ``commit`` line is dropped before parsing.

    The title is line 6 (whitespace-stripped) and the body is every following line
    joined with newlines. Author/commit dates are parsed as integer Unix timestamps
    via ``datetime.fromtimestamp`` (naive local time).

    Raises:
        AssertionError: if the header lines do not have the expected layout.
    """
    # Work on a private copy so dropping the "Merge:" line below never mutates a
    # list owned by the caller.
    lines = lines.split("\n") if isinstance(lines, str) else list(lines)
    # TODO: Handle merge commits correctly
    if len(lines) > 1 and lines[1].startswith("Merge:"):
        del lines[1]
    assert len(lines) > 7
    assert lines[0].startswith("commit")
    assert lines[1].startswith("Author: ")
    assert lines[2].startswith("AuthorDate: ")
    assert lines[3].startswith("Commit: ")
    assert lines[4].startswith("CommitDate: ")
    assert len(lines[5]) == 0
    return GitCommit(
        commit_hash=lines[0].split()[1].strip(),
        author=lines[1].split(":", 1)[1].strip(),
        author_date=datetime.fromtimestamp(int(lines[2].split(":", 1)[1].strip())),
        commit_date=datetime.fromtimestamp(int(lines[4].split(":", 1)[1].strip())),
        title=lines[6].strip(),
        body="\n".join(lines[7:]),
    )
 | |
| 
 | |
| 
 | |
class GitRepo:
    """Small wrapper around the ``git`` command line for one local checkout.

    Every git invocation runs as ``git -C <repo_dir> ...`` through ``_check_output``,
    so a failing git command surfaces as ``RuntimeError``.
    """

    def __init__(self, path: str, remote: str = "origin", debug: bool = False) -> None:
        # Working-tree directory handed to ``git -C``.
        self.repo_dir = path
        # Remote name used by fetch/push/remote_url and for remote-relative refs.
        self.remote = remote
        # When True, every git command line is printed before it runs.
        self.debug = debug

    def _run_git(self, *args: Any) -> str:
        """Run ``git -C <repo_dir> *args`` and return its output (stderr merged in)."""
        if self.debug:
            print(f"+ git -C {self.repo_dir} {' '.join(args)}")
        return _check_output(["git", "-C", self.repo_dir] + list(args))

    def revlist(self, revision_range: str) -> List[str]:
        """Return the hashes ``git rev-list <revision_range> -- .`` prints.

        An empty list is returned when the range yields no commits.
        """
        rc = self._run_git("rev-list", revision_range, "--", ".").strip()
        return rc.split("\n") if len(rc) > 0 else []

    def branches_containing_ref(
        self, ref: str, *, include_remote: bool = True
    ) -> List[str]:
        """Return the names of branches that contain ``ref``.

        Remote-tracking branches are listed when ``include_remote`` is True (the
        default), local branches otherwise. Names are whitespace-stripped and blank
        lines are dropped.
        """
        if include_remote:
            rc = self._run_git("branch", "--remote", "--contains", ref)
        else:
            rc = self._run_git("branch", "--contains", ref)
        # Filtering blank lines also handles empty output, so no length check is needed.
        return [x.strip() for x in rc.split("\n") if x.strip()]

    def current_branch(self) -> Optional[str]:
        """Return the short name of the checked-out branch, or None on a detached HEAD."""
        try:
            return self._run_git("symbolic-ref", "--short", "HEAD").strip()
        except RuntimeError:
            # `symbolic-ref` fails when HEAD is not a symbolic ref: detached HEAD state
            return None

    def checkout(self, branch: str) -> None:
        """Check out ``branch``."""
        self._run_git("checkout", branch)

    def create_branch_and_checkout(self, branch: str) -> None:
        """Create ``branch`` at the current HEAD and check it out."""
        self._run_git("checkout", "-b", branch)

    def fetch(self, ref: Optional[str] = None, branch: Optional[str] = None) -> None:
        """Fetch from ``self.remote``.

        With no arguments the remote's default refspecs are fetched. With only
        ``ref``, that ref is fetched. With ``branch``, the refspec ``<ref>:<branch>``
        is used so the fetched ref is stored in ``branch``.
        """
        if branch is None and ref is None:
            self._run_git("fetch", self.remote)
        elif branch is None:
            self._run_git("fetch", self.remote, ref)
        else:
            # NOTE(review): with ref=None this builds the refspec "None:<branch>".
            self._run_git("fetch", self.remote, f"{ref}:{branch}")

    def show_ref(self, name: str) -> str:
        """Return the hash that ``git show-ref -s <name>`` reports.

        Raises RuntimeError when several refs match ``name`` and they do not all
        point at the same object.
        """
        refs = self._run_git("show-ref", "-s", name).strip().split("\n")
        if any(other != refs[0] for other in refs[1:]):
            raise RuntimeError(f"reference {name} is ambiguous")
        return refs[0]

    def rev_parse(self, name: str) -> str:
        """Resolve ``name`` to an object hash with ``git rev-parse --verify``."""
        return self._run_git("rev-parse", "--verify", name).strip()

    def get_merge_base(self, from_ref: str, to_ref: str) -> str:
        """Return the merge base of ``from_ref`` and ``to_ref``."""
        return self._run_git("merge-base", from_ref, to_ref).strip()

    def patch_id(self, ref: Union[str, List[str]]) -> List[Tuple[str, str]]:
        """Return ``[patch_id, commit]`` pairs from ``git show <refs> | git patch-id --stable``.

        ``ref`` is a list of revisions, or a whitespace-separated string of them.
        An empty list short-circuits to ``[]``. Commits for which ``git patch-id``
        prints nothing are simply absent from the result.
        """
        import shlex

        if isinstance(ref, list):
            if len(ref) == 0:
                return []
            refs = ref
        else:
            # A string ref has always been word-split by the shell; keep that.
            refs = ref.split()
        # Quote every interpolated value so paths with spaces or shell
        # metacharacters cannot break (or inject into) the pipeline.
        quoted_refs = " ".join(shlex.quote(r) for r in refs)
        rc = _check_output(
            [
                "sh",
                "-c",
                f"git -C {shlex.quote(self.repo_dir)} show {quoted_refs}"
                "|git patch-id --stable",
            ]
        ).strip()
        # Skip blank lines: empty output used to produce a bogus [""] entry that
        # broke callers unpacking (patch_id, commit) pairs.
        return [cast(Tuple[str, str], x.split(" ", 1)) for x in rc.split("\n") if x]

    def commits_resolving_gh_pr(self, pr_num: int) -> List[str]:
        """Return hashes of commits whose message contains the
        "Pull Request resolved: <PR url>" line for ``pr_num``.

        NOTE(review): ``--grep`` is an unanchored regex, so PR 12 also matches
        messages referencing PR 123, 1200, ...; confirm callers tolerate this.
        """
        owner, name = self.gh_owner_and_name()
        msg = f"Pull Request resolved: https://github.com/{owner}/{name}/pull/{pr_num}"
        rc = self._run_git("log", "--format=%H", "--grep", msg).strip()
        return rc.split("\n") if len(rc) > 0 else []

    def get_commit(self, ref: str) -> "GitCommit":
        """Return the parsed ``git show --format=fuller --date=unix`` output for ``ref``."""
        return parse_fuller_format(
            self._run_git("show", "--format=fuller", "--date=unix", "--shortstat", ref)
        )

    def cherry_pick(self, ref: str) -> None:
        """Cherry-pick ``ref`` onto the current branch, recording its origin (``-x``)."""
        self._run_git("cherry-pick", "-x", ref)

    def revert(self, ref: str) -> None:
        """Revert ``ref`` with git's default revert message (``--no-edit``)."""
        self._run_git("revert", "--no-edit", ref)

    def compute_branch_diffs(
        self, from_branch: str, to_branch: str
    ) -> Tuple[List[str], List[str]]:
        """
        Return ``(from_commits, to_commits)``: the commits each branch has since the
        merge base, minus those also present on the other branch (matched by
        patch-id).

        Might be slow if the merge base of the two branches is far back in history.

        Raises RuntimeError when commits sharing a patch-id cannot be paired by
        title and author date (outside the hard-coded pytorch allow-list).
        """
        from_ref = self.rev_parse(from_branch)
        to_ref = self.rev_parse(to_branch)
        merge_base = self.get_merge_base(from_ref, to_ref)
        from_commits = self.revlist(f"{merge_base}..{from_ref}")
        to_commits = self.revlist(f"{merge_base}..{to_ref}")
        from_ids = fuzzy_list_to_dict(self.patch_id(from_commits))
        to_ids = fuzzy_list_to_dict(self.patch_id(to_commits))
        # Query the remote URL once instead of once per mismatching commit pair
        # plus once more after the loop (each query is a git subprocess).
        is_pytorch_repo = "pytorch/pytorch" in self.remote_url()
        for pid in set(from_ids).intersection(set(to_ids)):
            from_values = from_ids[pid]
            to_values = to_ids[pid]
            if len(from_values) != len(to_values):
                # Eliminate duplicate commits+reverts from the list
                while len(from_values) > 0 and len(to_values) > 0:
                    frc = self.get_commit(from_values.pop())
                    toc = self.get_commit(to_values.pop())
                    # FRC branch might have PR number added to the title
                    if frc.title != toc.title or frc.author_date != toc.author_date:
                        # HACK: The same commit was merged, reverted and landed again,
                        # which creates a tracking problem
                        if (
                            not is_pytorch_repo
                            or frc.commit_hash
                            not in {
                                "0a6a1b27a464ba5be5f587cce2ee12ab8c504dbf",
                                "6d0f4a1d545a8f161df459e8d4ccafd4b9017dbe",
                                "edf909e58f06150f7be41da2f98a3b9de3167bca",
                                "a58c6aea5a0c9f8759a4154e46f544c8b03b8db1",
                                "7106d216c29ca16a3504aa2bedad948ebcf4abc2",
                            }
                        ):
                            raise RuntimeError(
                                f"Unexpected differences between {frc} and {toc}"
                            )
                    from_commits.remove(frc.commit_hash)
                    to_commits.remove(toc.commit_hash)
                continue
            for commit in from_values:
                from_commits.remove(commit)
            for commit in to_values:
                to_commits.remove(commit)
        # Another HACK: Patch-id is not stable for commits with binary files or for big changes across commits
        # I.e. cherry-picking those from one branch into another will change patchid
        if is_pytorch_repo:
            for excluded_commit in {
                "8e09e20c1dafcdbdb45c2d1574da68a32e54a3a5",
                "5f37e5c2a39c3acb776756a17730b865f0953432",
                "b5222584e6d6990c6585981a936defd1af14c0ba",
                "84d9a2e42d5ed30ec3b8b4140c38dd83abbce88d",
                "f211ec90a6cdc8a2a5795478b5b5c8d7d7896f7e",
            }:
                if excluded_commit in from_commits:
                    from_commits.remove(excluded_commit)

        return (from_commits, to_commits)

    def cherry_pick_commits(self, from_branch: str, to_branch: str) -> None:
        """Cherry-pick onto ``to_branch`` every commit found only on ``from_branch``.

        Must be called with a branch checked out (asserted). That branch is checked
        out again afterwards, unless a cherry-pick fails.
        """
        orig_branch = self.current_branch()
        assert orig_branch is not None, "Must be on a branch"
        self.checkout(to_branch)
        from_commits, _ = self.compute_branch_diffs(from_branch, to_branch)
        if len(from_commits) == 0:
            print("Nothing to do")
            self.checkout(orig_branch)
            return
        # rev-list output is newest-first; reverse it to replay oldest commits first.
        for commit in reversed(from_commits):
            print(f"Cherry picking commit {commit}")
            self.cherry_pick(commit)
        self.checkout(orig_branch)

    def push(self, branch: str, dry_run: bool, retry: int = 3) -> None:
        """Push ``branch`` to ``self.remote``, making up to ``retry`` attempts.

        Returns after the first successful push. After a failed attempt the remote
        is fetched and the local branch is rebased onto ``<remote>/<branch>`` before
        the next try. If every attempt fails, the last ``RuntimeError`` is re-raised
        instead of being silently dropped.
        """
        for cnt in range(retry):
            try:
                if dry_run:
                    self._run_git("push", "--dry-run", self.remote, branch)
                else:
                    self._run_git("push", self.remote, branch)
                return
            except RuntimeError as e:
                print(f"{cnt} push attempt failed with {e}")
                if cnt == retry - 1:
                    raise
                self.fetch()
                self._run_git("rebase", f"{self.remote}/{branch}")

    def head_hash(self) -> str:
        """Return the stripped output of ``git show-ref --hash HEAD``.

        NOTE(review): ``show-ref`` matches refs whose name ends in ``/HEAD`` (such
        as ``refs/remotes/origin/HEAD``) and omits the local HEAD unless ``--head``
        is given. Confirm this yields the intended hash; ``git rev-parse HEAD`` may
        have been meant.
        """
        return self._run_git("show-ref", "--hash", "HEAD").strip()

    def remote_url(self) -> str:
        """Return ``git remote get-url <remote>`` output, unstripped (may end in a newline)."""
        return self._run_git("remote", "get-url", self.remote)

    def gh_owner_and_name(self) -> Tuple[str, str]:
        """Return the ``(owner, name)`` of the GitHub repository.

        The URL comes from the ``GIT_REMOTE_URL`` environment variable when set,
        otherwise from ``remote_url()``. Only ``https://`` URLs are recognised; any
        other format raises RuntimeError. A ``.git`` suffix is kept in ``name``.
        """
        url = os.getenv("GIT_REMOTE_URL", None)
        if url is None:
            url = self.remote_url()
        rc = RE_GITHUB_URL_MATCH.match(url)
        if rc is None:
            raise RuntimeError(f"Unexpected url format {url}")
        return cast(Tuple[str, str], rc.groups())

    def commit_message(self, ref: str) -> str:
        """Return the raw message (``%B``) of commit ``ref``."""
        return self._run_git("log", "-1", "--format=%B", ref)

    def amend_commit_message(self, msg: str) -> None:
        """Replace the message of the HEAD commit with ``msg``."""
        self._run_git("commit", "--amend", "-m", msg)

    def diff(self, from_ref: str, to_ref: Optional[str] = None) -> str:
        """Return ``git diff`` output.

        With ``to_ref`` the range ``from_ref..to_ref`` is diffed. Without it,
        ``<from_ref>^!`` is used: the changes ``from_ref`` introduces on top of its
        parent(s).
        """
        if to_ref is None:
            return self._run_git("diff", f"{from_ref}^!")
        return self._run_git("diff", f"{from_ref}..{to_ref}")
 | |
| 
 | |
| 
 | |
def clone_repo(username: str, password: str, org: str, project: str) -> GitRepo:
    """Clone ``https://github.com/{org}/{project}`` into a new temporary directory.

    The credentials are embedded in the clone URL. NOTE(review): git normally keeps
    that URL as the clone's remote, so the password remains in the clone's config on
    disk — confirm that is acceptable.

    Returns:
        A GitRepo for the clone, using GitRepo's default remote name.

    Raises:
        RuntimeError: if ``git clone`` fails. The password is masked in the message,
            and the temporary directory is removed.
    """
    import shutil

    path = tempfile.mkdtemp()
    url = f"https://{username}:{password}@github.com/{org}/{project}"
    try:
        _check_output(["git", "clone", url, path])
    except RuntimeError as e:
        # _check_output puts the full command line (URL included) into its message.
        msg = str(e)
        if password:
            msg = msg.replace(password, "***")
        shutil.rmtree(path, ignore_errors=True)
        # `from None` hides the chained CalledProcessError, whose text repeats the
        # unmasked command line.
        raise RuntimeError(msg) from None
    return GitRepo(path=path)
 | |
| 
 | |
| 
 | |
class PeekableIterator(Iterator[str]):
    """Iterator over the characters of a string with one element of lookahead."""

    def __init__(self, val: str) -> None:
        self._val = val
        # Index of the element most recently returned; -1 before iteration starts.
        self._idx = -1

    def peek(self) -> Optional[str]:
        """Return the upcoming element without consuming it, or None when exhausted."""
        upcoming = self._idx + 1
        return self._val[upcoming] if upcoming < len(self._val) else None

    def __iter__(self) -> "PeekableIterator":
        return self

    def __next__(self) -> str:
        nxt = self.peek()
        if nxt is None:
            raise StopIteration
        self._idx += 1
        return nxt
 | |
| 
 | |
| 
 | |
def patterns_to_regex(allowed_patterns: List[str]) -> Any:
    """
    Compile glob-like path patterns into one compiled alternation regex.

    The only special sequences are:
      - ? - matches a single character
      - * - matches any run of characters other than the folder separator "/"
      - ** - matches any run of characters, including "/"
    Every other character matches itself literally (it goes through re.escape).
    Patterns must not contain braces, parentheses, square brackets or backslashes;
    this is asserted.

    The regex is not anchored at the end, so ``.match()`` also accepts strings that
    merely start with a matching path. NOTE(review): an empty ``allowed_patterns``
    compiles to ``()``, which matches every string.
    """
    alternatives: List[str] = []
    for pattern in allowed_patterns:
        assert not any(c in pattern for c in "{}()[]\\")
        pieces: List[str] = []
        idx = 0
        while idx < len(pattern):
            c = pattern[idx]
            if c == "*":
                if idx + 1 < len(pattern) and pattern[idx + 1] == "*":
                    pieces.append(".*")
                    idx += 1  # consume the second "*"
                else:
                    pieces.append("[^/]*")
            elif c == "?":
                # Previously emitted verbatim, where regex reads it as "previous
                # token optional" instead of "any single character".
                pieces.append(".")
            else:
                # Escape everything else so characters like ^ $ | are literal.
                pieces.append(re.escape(c))
            idx += 1
        alternatives.append("".join(pieces))
    return re.compile("(" + "|".join(alternatives) + ")")
 | |
| 
 | |
| 
 | |
| def _shasum(value: str) -> str:
 | |
|     import hashlib
 | |
| 
 | |
|     m = hashlib.sha256()
 | |
|     m.update(value.encode("utf-8"))
 | |
|     return m.hexdigest()
 | |
| 
 | |
| 
 | |
def is_commit_hash(ref: str) -> bool:
    """Return True if ``ref`` is non-empty and made only of hex digits (0-9, a-f, A-F).

    ``int(ref, 16)`` is deliberately not used: it also accepts a ``0x`` prefix, a
    sign, surrounding whitespace and ``_`` digit separators (e.g. ``"dead_beef"``),
    none of which can appear in a git object name. Note that a branch name that
    happens to be all hex digits (e.g. ``"cafe"``) still returns True.
    """
    return re.fullmatch(r"[0-9a-fA-F]+", ref) is not None
 | |
| 
 | |
| 
 | |
def are_ghstack_branches_in_sync(
    repo: GitRepo, head_ref: str, base_ref: Optional[str] = None
) -> bool:
    """Compare, by SHA-256 digest, the base..head diff with the diff of ``orig``.

    Returns True when ``git diff base..head`` equals the changes ``orig`` introduces
    on top of its parent. ``orig`` (and ``base`` when ``base_ref`` is omitted) is
    derived from ``head_ref`` by replacing a trailing ``/head`` with ``/orig``
    (``/base``). Branch names are resolved under ``repo.remote``. A ``base_ref`` for
    which ``is_commit_hash`` is true is used as-is, without the remote prefix.
    """
    remote = repo.remote
    orig_ref = re.sub(r"/head$", "/orig", head_ref)
    if base_ref is None:
        base_ref = re.sub(r"/head$", "/base", head_ref)

    orig_digest = _shasum(repo.diff(f"{remote}/{orig_ref}"))

    if is_commit_hash(base_ref):
        base_spec = base_ref
    else:
        base_spec = f"{remote}/{base_ref}"
    head_digest = _shasum(repo.diff(base_spec, f"{remote}/{head_ref}"))

    return head_digest == orig_digest
 | |
| 
 | |
| 
 | |
def retries_decorator(
    rc: Any = None, num_retries: int = 3
) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """Build a decorator that retries the wrapped callable on any ``Exception``.

    The decorated callable is invoked up to ``num_retries`` times with the same
    arguments, and the first successful result is returned. Each failure is printed
    (1-based attempt number). If every attempt fails, ``rc`` is returned instead of
    raising; it is the same object each time.
    """

    def decorator(f: Callable[..., T]) -> Callable[..., T]:
        @wraps(f)
        def wrapper(*args: Any, **kwargs: Any) -> T:
            # `*args: X` annotates each positional argument as X, so `Any` is the
            # correct annotation here (the old List[Any] / Dict[...] were wrong).
            for idx in range(num_retries):
                try:
                    return f(*args, **kwargs)
                except Exception as e:
                    print(
                        f'Attempt {idx + 1} of {num_retries} to call {f.__name__} failed with "{e}"'
                    )
            return cast(T, rc)

        return wrapper

    return decorator
 |