mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68292 - noqa was mistyped so that it was identical to type: ignore - generalize clang-tidy initialization and use it for clang_format as well - Add a script that lets you update the binaries in s3 relatively easily Test Plan: Imported from OSS Reviewed By: malfet Differential Revision: D32403934 Pulled By: suo fbshipit-source-id: 4e21b22605216f013d87d636a205707ca8e0af36
207 lines
5.5 KiB
Python
207 lines
5.5 KiB
Python
import argparse
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import os
|
|
import platform
|
|
import stat
|
|
import subprocess
|
|
import sys
|
|
import textwrap
|
|
import urllib.error
|
|
import urllib.request
|
|
from pathlib import Path
|
|
|
|
# Operating system of the machine running this script, as reported by
# platform.system() -- e.g. "Linux" or "Darwin". Used to pick the
# platform-specific entry out of the linter config.
HOST_PLATFORM = platform.system()

# Absolute path of the top-level directory of the current git checkout (the
# PyTorch repository root). `check=True` makes a failing `git` invocation raise
# CalledProcessError here rather than silently producing an empty path.
result = subprocess.run(
    ("git", "rev-parse", "--show-toplevel"),
    check=True,
    stdout=subprocess.PIPE,
)
PYTORCH_ROOT = str(result.stdout, encoding="utf-8").strip()

# When True, describe what would happen (downloads, deletions) without doing it.
# Reassigned from the command-line arguments when run as a script.
DRY_RUN = False
|
|
|
|
|
|
def compute_file_sha256(path: str) -> str:
    """Return the SHA256 digest of the file at ``path`` as a lowercase hex string.

    The file is read in fixed-size binary chunks, so memory use stays bounded
    regardless of the file's size or how its bytes are laid out.

    Args:
        path: Filesystem path of the file to hash.

    Returns:
        The hex digest, or ``""`` if ``path`` does not exist. The empty string
        never equals a real digest, so a missing file compares as a mismatch.
    """
    # If the file doesn't exist, return an empty string.
    if not os.path.exists(path):
        return ""

    digest = hashlib.sha256()
    chunk_size = 1 << 16  # 64 KiB

    # Iterating a binary file object yields b"\n"-delimited "lines", which for
    # a compiled binary can be arbitrarily large; read fixed-size chunks instead.
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)

    # Return the hash as a hexadecimal string.
    return digest.hexdigest()
|
|
|
|
|
|
def report_download_progress(
    chunk_number: int, chunk_size: int, file_size: int
) -> None:
    """
    Pretty printer for file download progress, usable as a ``urlretrieve`` reporthook.

    Draws a 64-character bar on the current line of stdout, overwriting the
    previous bar with a carriage return.

    Args:
        chunk_number: Number of chunks transferred so far.
        chunk_size: Size of each chunk in bytes.
        file_size: Total size in bytes, or -1 when it is unknown.
    """
    # Nothing sensible to draw when the total size is unknown (-1); a size of 0
    # would also divide by zero below.
    if file_size <= 0:
        return

    percent = min(1, (chunk_number * chunk_size) / file_size)
    bar = "#" * int(64 * percent)
    sys.stdout.write("\r0% |{:<64}| {}%".format(bar, int(percent * 100)))
    # The bar is written without a newline, so a line-buffered terminal would
    # not show it until much later; flush so the bar updates live.
    sys.stdout.flush()
|
|
|
|
|
|
def check(binary_path: Path, reference_hash: str) -> bool:
    """Verify that ``binary_path`` exists and has the expected SHA256 digest.

    A file whose digest differs from ``reference_hash`` is treated as
    untrustworthy and deleted, unless dry-run mode is active, in which case the
    mismatch is only reported.

    Args:
        binary_path: Location of the binary to verify.
        reference_hash: Expected SHA256 digest, as a hex string.

    Returns:
        True only when the file exists and its digest equals ``reference_hash``.
        False otherwise, whether or not a mismatching file was removed.
    """
    if not binary_path.exists():
        logging.info(f"{binary_path} does not exist.")
        return False

    actual_hash = compute_file_sha256(str(binary_path))
    if actual_hash == reference_hash:
        return True

    logging.warning(
        textwrap.dedent(
            f"""\
            Found binary hash does not match reference!

            Found hash: {actual_hash}
            Reference hash: {reference_hash}

            Deleting {binary_path} just to be safe.
            """
        )
    )

    if DRY_RUN:
        logging.critical(
            "In dry run mode, so not actually deleting the binary. But consider deleting it ASAP!"
        )
    else:
        # Best effort: if removal fails, warn loudly instead of raising.
        try:
            binary_path.unlink()
        except OSError as unlink_error:
            logging.critical(f"Failed to delete binary: {unlink_error}")
            logging.critical(
                "Delete this binary as soon as possible and do not execute it!"
            )

    return False
|
|
|
|
|
|
def download(
    name: str,
    output_dir: str,
    url: str,
    reference_bin_hash: str,
) -> bool:
    """
    Download a platform-appropriate binary if one doesn't already exist at the expected location and verifies
    that it is the right binary by checking its SHA256 hash against the expected hash.

    Args:
        name: File name of the binary inside ``output_dir``.
        output_dir: Directory to place the binary in; created if missing.
        url: Where to download the binary from.
        reference_bin_hash: Expected SHA256 digest (hex) of the binary.

    Returns:
        True if a verified binary is in place (or, in dry-run mode, the download
        would be attempted). False if the download failed or the downloaded file
        did not match ``reference_bin_hash``.
    """
    # First check if we need to do anything
    binary_path = Path(output_dir, name)
    if check(binary_path, reference_bin_hash):
        logging.info(f"Correct binary already exists at {binary_path}. Exiting.")
        return True

    # Download the binary
    logging.info(f"Downloading {url} to {binary_path}")

    if DRY_RUN:
        # Dry run must not touch the filesystem, so stop before creating dirs.
        logging.info("Exiting as there is nothing left to do in dry run mode")
        return True

    # Create the output folder
    binary_path.parent.mkdir(parents=True, exist_ok=True)

    show_progress = sys.stdout.isatty()
    try:
        urllib.request.urlretrieve(
            url,
            binary_path,
            reporthook=report_download_progress if show_progress else None,
        )
    except urllib.error.URLError as e:
        # Report network/HTTP failures through the bool return value like the
        # other failure paths, instead of escaping as an unhandled exception.
        # A truncated transfer may leave a partial file behind; check() rejects
        # and deletes it on the next run because its hash will not match.
        logging.critical(f"Failed to download {name} from {url}: {e}")
        return False
    finally:
        if show_progress:
            # The progress bar never ends its line; terminate it so later
            # output starts on a fresh line.
            sys.stdout.write("\n")
            sys.stdout.flush()

    logging.info(f"Downloaded {name} successfully.")

    # Check the downloaded binary
    if not check(binary_path, reference_bin_hash):
        logging.critical(f"Downloaded binary {name} failed its hash check")
        return False

    # Ensure that executable bits are set
    mode = os.stat(binary_path).st_mode
    mode |= stat.S_IXUSR
    os.chmod(binary_path, mode)

    logging.info(f"Using {name} located at {binary_path}")
    return True
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="downloads and checks binaries from s3",
    )
    parser.add_argument(
        "--config-json",
        required=True,
        help="Path to config json that describes where to find binaries and hashes",
    )
    parser.add_argument(
        "--linter",
        required=True,
        help="Which linter to initialize from the config json",
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="place to put the binary",
    )
    parser.add_argument(
        "--output-name",
        required=True,
        help="name of binary",
    )
    parser.add_argument(
        "--dry-run",
        # The value is compared against the string "0" below: "0" means a real
        # run and any other value enables dry-run mode. The default must
        # therefore be "0". A default of False is not equal to "0", so omitting
        # the flag would silently enable dry-run mode.
        default="0",
        help="do not download, just print what would be done",
    )

    args = parser.parse_args()
    # Module-level assignment (not inside a function), so this rebinds the
    # global DRY_RUN read by check() and download().
    DRY_RUN = args.dry_run != "0"

    logging.basicConfig(
        format="[DRY_RUN] %(levelname)s: %(message)s"
        if DRY_RUN
        else "%(levelname)s: %(message)s",
        level=logging.INFO,
        stream=sys.stderr,
    )

    with open(args.config_json, encoding="utf-8") as config_file:
        config = json.load(config_file)

    if args.linter not in config:
        logging.error(f"Linter {args.linter!r} not found in {args.config_json}")
        sys.exit(1)
    linter_config = config[args.linter]

    # The linter entry is keyed by platform.system() names; a host platform
    # without an entry is unsupported.
    if HOST_PLATFORM not in linter_config:
        logging.error(f"Unsupported platform: {HOST_PLATFORM}")
        sys.exit(1)

    url = linter_config[HOST_PLATFORM]["download_url"]
    reference_hash = linter_config[HOST_PLATFORM]["hash"]

    ok = download(args.output_name, args.output_dir, url, reference_hash)
    if not ok:
        logging.critical(f"Unable to initialize {args.linter}")
        sys.exit(1)
|