mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-30 19:54:53 +08:00 
			
		
		
		
	Fixes https://github.com/pytorch/pytorch/issues/104713
### Testing
Tested manually and locally using #104121, confirming that the correct merge base commit, [803c14490b189f9b755ecb9f2a969876088ea243](1cb87771c1), is returned instead of the wrong value provided by `baseRefOid` (de7b6e55eb0f87f8d822f69bad6b4189a857b460). The JSON output of the GraphQL query for the PR info is here: https://paste.sh/TJ-QQWz4#fvE3Y6qoJ8vDkRBZ3vowkZ3m
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105098
Approved by: https://github.com/malfet
		
	
		
			
				
	
	
		
			170 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			170 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """GitHub Utilities"""
 | |
| 
 | |
| import json
 | |
| import os
 | |
| import warnings
 | |
| 
 | |
| from dataclasses import dataclass
 | |
| from typing import Any, Callable, cast, Dict, List, Optional, Tuple
 | |
| from urllib.error import HTTPError
 | |
| from urllib.parse import quote
 | |
| from urllib.request import Request, urlopen
 | |
| 
 | |
| 
 | |
@dataclass
class GitHubComment:
    """Plain record describing one comment on a GitHub issue or pull request.

    The class only stores values; it performs no validation and no network
    access. Field order is significant because it defines the positional
    constructor signature.
    """

    # Text of the comment body (presumably the plain-text rendering — confirm
    # against the query that populates it).
    body_text: str
    # Creation timestamp, stored as the string it was received as.
    created_at: str
    # Login of the comment's author.
    author_login: str
    # Author's association with the repository, kept as a raw string.
    author_association: str
    # Login of the last editor, or None when no editor is recorded.
    editor_login: Optional[str]
    # Numeric database id of the comment.
    database_id: int
    # URL of the comment.
    url: str
 | |
| 
 | |
| 
 | |
def gh_fetch_url_and_headers(
    url: str,
    *,
    headers: Optional[Dict[str, str]] = None,
    data: Optional[Dict[str, Any]] = None,
    method: Optional[str] = None,
    reader: Callable[[Any], Any] = lambda x: x.read(),
) -> Tuple[Any, Any]:
    """Perform an HTTP request and return ``(response_headers, reader(response))``.

    Args:
        url: URL to request.
        headers: Extra request headers. The caller's dict is never modified.
        data: Optional payload; serialized with ``json.dumps`` and sent as the
            request body.
        method: HTTP method. ``None`` lets ``urllib`` pick (GET without a body,
            POST with one).
        reader: Callable applied to the open response object; its result is
            the second element of the returned tuple. Defaults to reading the
            raw bytes.

    Returns:
        A ``(headers, body)`` tuple, where ``body`` is whatever ``reader``
        returned.

    Raises:
        urllib.error.HTTPError: re-raised unchanged. Rate-limit details are
            printed first when the error carries GitHub rate-limit headers.
    """
    # Work on a copy so the Authorization header added below never leaks into
    # a dict owned by the caller.
    request_headers: Dict[str, str] = dict(headers) if headers is not None else {}
    token = os.environ.get("GITHUB_TOKEN")
    # Only attach the token for requests to the GitHub API host.
    if token is not None and url.startswith("https://api.github.com/"):
        request_headers["Authorization"] = f"token {token}"
    data_ = json.dumps(data).encode() if data is not None else None
    try:
        with urlopen(
            Request(url, headers=request_headers, data=data_, method=method)
        ) as conn:
            return conn.headers, reader(conn)
    except HTTPError as err:
        err_headers = err.headers
        # GitHub signals an exhausted rate limit with 403 or 429. HTTPError may
        # carry no headers at all, in which case a membership test would raise
        # TypeError and hide the original error.
        if (
            err.code in (403, 429)
            and err_headers is not None
            and all(
                key in err_headers
                for key in ["X-RateLimit-Limit", "X-RateLimit-Used"]
            )
        ):
            print(
                f"""Rate limit exceeded:
                Used: {err_headers['X-RateLimit-Used']}
                Limit: {err_headers['X-RateLimit-Limit']}
                Remaining: {err_headers['X-RateLimit-Remaining']}
                Resets at: {err_headers['X-RateLimit-Reset']}"""
            )
        raise
 | |
| 
 | |
| 
 | |
def gh_fetch_url(
    url: str,
    *,
    headers: Optional[Dict[str, str]] = None,
    data: Optional[Dict[str, Any]] = None,
    method: Optional[str] = None,
    reader: Callable[[Any], Any] = json.load,
) -> Any:
    """Perform an HTTP request and return only the decoded response body.

    Thin wrapper over ``gh_fetch_url_and_headers`` that discards the response
    headers.

    Args:
        url: URL to request.
        headers: Extra request headers.
        data: Optional payload, sent JSON-encoded as the request body.
        method: HTTP method, or ``None`` to let ``urllib`` choose.
        reader: Callable applied to the open response. Defaults to
            ``json.load``, so the body is parsed as JSON unless the caller
            supplies something else (e.g. ``lambda conn: conn.read()`` for
            raw bytes).

    Returns:
        Whatever ``reader`` produced from the response.
    """
    # The caller's reader is forwarded. It was previously ignored in favour of
    # a hard-coded json.load, which is still the default.
    _, body = gh_fetch_url_and_headers(
        url, headers=headers, data=data, method=method, reader=reader
    )
    return body
 | |
| 
 | |
| 
 | |
def gh_fetch_json(
    url: str,
    params: Optional[Dict[str, Any]] = None,
    data: Optional[Dict[str, Any]] = None,
    method: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Request ``url`` from the GitHub REST API and return the parsed JSON.

    Args:
        url: Endpoint URL. It may already contain a query string.
        params: Optional query parameters. Values are converted with ``str``
            and percent-encoded with ``urllib.parse.quote``.
        data: Optional JSON payload for the request body.
        method: HTTP method, or ``None`` to let ``urllib`` choose.

    Returns:
        The decoded JSON, typed as a list of objects. ``cast`` does not check
        this at runtime.
    """
    headers = {"Accept": "application/vnd.github.v3+json"}
    if params:
        query = "&".join(f"{name}={quote(str(val))}" for name, val in params.items())
        # Extend an existing query string rather than starting a second "?".
        url += ("&" if "?" in url else "?") + query
    return cast(
        List[Dict[str, Any]],
        gh_fetch_url(url, headers=headers, data=data, reader=json.load, method=method),
    )
 | |
| 
 | |
| 
 | |
def _gh_fetch_json_any(
    url: str,
    params: Optional[Dict[str, Any]] = None,
    data: Optional[Dict[str, Any]] = None,
) -> Any:
    """Request ``url`` from the GitHub REST API and return the parsed JSON as-is.

    Args:
        url: Endpoint URL. It may already contain a query string.
        params: Optional query parameters. Values are converted with ``str``
            and percent-encoded with ``urllib.parse.quote``.
        data: Optional JSON payload. When given, ``urllib`` sends a POST.

    Returns:
        The decoded JSON value, untyped.
    """
    headers = {"Accept": "application/vnd.github.v3+json"}
    if params:
        query = "&".join(f"{name}={quote(str(val))}" for name, val in params.items())
        # Extend an existing query string rather than starting a second "?".
        url += ("&" if "?" in url else "?") + query
    return gh_fetch_url(url, headers=headers, data=data, reader=json.load)
 | |
| 
 | |
| 
 | |
def gh_fetch_json_list(
    url: str,
    params: Optional[Dict[str, Any]] = None,
    data: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """Fetch ``url`` as JSON and type the result as a list of objects.

    The ``cast`` only informs the type checker. Nothing verifies at runtime
    that the response really is a list.
    """
    payload = _gh_fetch_json_any(url, params, data)
    return cast(List[Dict[str, Any]], payload)
 | |
| 
 | |
| 
 | |
def gh_fetch_json_dict(
    url: str,
    params: Optional[Dict[str, Any]] = None,
    data: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Fetch ``url`` as JSON and type the result as a single object.

    The ``cast`` only informs the type checker. Nothing verifies at runtime
    that the response really is a dict.
    """
    payload = _gh_fetch_json_any(url, params, data)
    return cast(Dict[str, Any], payload)
 | |
| 
 | |
| 
 | |
def _gh_post_comment(
    url: str, comment: str, dry_run: bool = False
) -> List[Dict[str, Any]]:
    """Send ``comment`` as ``{"body": comment}`` to the comments endpoint ``url``.

    When ``dry_run`` is true, nothing is sent. The comment is printed to
    stdout and an empty list is returned.
    """
    if not dry_run:
        return gh_fetch_json_list(url, data={"body": comment})
    print(comment)
    return []
 | |
| 
 | |
| 
 | |
def gh_post_pr_comment(
    org: str, repo: str, pr_num: int, comment: str, dry_run: bool = False
) -> List[Dict[str, Any]]:
    """Post ``comment`` on pull request ``pr_num`` of ``org/repo``.

    Pull-request conversation comments are addressed through the issues
    comments endpoint. ``dry_run`` is passed through to ``_gh_post_comment``.
    """
    comments_url = f"https://api.github.com/repos/{org}/{repo}/issues/{pr_num}/comments"
    return _gh_post_comment(comments_url, comment, dry_run=dry_run)
 | |
| 
 | |
| 
 | |
def gh_post_commit_comment(
    org: str, repo: str, sha: str, comment: str, dry_run: bool = False
) -> List[Dict[str, Any]]:
    """Post ``comment`` on commit ``sha`` of ``org/repo``.

    ``dry_run`` is passed through to ``_gh_post_comment``.
    """
    comments_url = f"https://api.github.com/repos/{org}/{repo}/commits/{sha}/comments"
    return _gh_post_comment(comments_url, comment, dry_run=dry_run)
 | |
| 
 | |
| 
 | |
def gh_delete_comment(org: str, repo: str, comment_id: int) -> None:
    """Delete issue/PR comment ``comment_id`` from ``org/repo``.

    Raises:
        urllib.error.HTTPError: if the API rejects the request.
    """
    url = f"https://api.github.com/repos/{org}/{repo}/issues/comments/{comment_id}"
    # A successful DELETE answers "204 No Content". Parsing that empty body as
    # JSON (what gh_fetch_url does) raises, so read the raw bytes and discard
    # them instead.
    gh_fetch_url_and_headers(url, method="DELETE", reader=lambda conn: conn.read())
 | |
| 
 | |
| 
 | |
def gh_fetch_merge_base(org: str, repo: str, base: str, head: str) -> str:
    """Return the merge-base commit SHA of ``base`` and ``head`` in ``org/repo``.

    Uses the GitHub REST "compare two commits" endpoint, which gives the same
    answer as ``git merge-base`` without needing a local git checkout:
    https://docs.github.com/en/rest/commits/commits?apiVersion=2022-11-28#compare-two-commits

    This is best-effort. Any failure (request error, empty response, or a
    response without the expected fields) emits a warning and returns ``""``.
    """
    try:
        compare_url = f"https://api.github.com/repos/{org}/{repo}/compare/{base}...{head}"
        response = gh_fetch_url(
            compare_url,
            headers={"Accept": "application/vnd.github.v3+json"},
            reader=json.load,
        )
        if not response:
            warnings.warn(
                f"Failed to get merge base for {base}...{head}: Empty response"
            )
            return ""
        return response.get("merge_base_commit", {}).get("sha", "")
    except Exception as error:
        warnings.warn(f"Failed to get merge base for {base}...{head}: {error}")
        return ""
 |