"""Download a platform-appropriate binary and verify it against a SHA256 hash.

The binary's download URL and reference hash are read from a JSON config that
is keyed first by linter name and then by host platform (``platform.system()``).
"""

import argparse
import hashlib
import json
import logging
import os
import platform
import stat
import subprocess
import sys
import textwrap
import urllib.error
import urllib.request
from pathlib import Path

# String representing the host platform (e.g. Linux, Darwin).
HOST_PLATFORM = platform.system()

# PyTorch directory root.
# The lookup is best-effort: when git is missing or the script is not run from
# inside a git checkout, fall back to an empty string instead of failing at
# import time.
try:
    result = subprocess.run(
        ["git", "rev-parse", "--show-toplevel"],
        stdout=subprocess.PIPE,
        check=True,
    )
    PYTORCH_ROOT = result.stdout.decode("utf-8").strip()
except (OSError, subprocess.CalledProcessError):
    PYTORCH_ROOT = ""

# When True, nothing is downloaded or deleted; actions are only logged.
DRY_RUN = False

# Number of bytes read per iteration while hashing a file.
_HASH_CHUNK_SIZE = 64 * 1024


def compute_file_sha256(path: str) -> str:
    """Compute the SHA256 hash of a file and return it as a hex string.

    Returns an empty string if ``path`` does not exist.
    """
    # If the file doesn't exist, return an empty string.
    if not os.path.exists(path):
        return ""

    digest = hashlib.sha256()

    # Read fixed-size chunks: iterating a binary file object splits on newline
    # bytes, which yields arbitrarily small or large pieces for binary data.
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(_HASH_CHUNK_SIZE), b""):
            digest.update(chunk)

    # Return the hash as a hexadecimal string.
    return digest.hexdigest()


def report_download_progress(
    chunk_number: int, chunk_size: int, file_size: int
) -> None:
    """Pretty printer for file download progress.

    Has the signature of a ``urllib.request.urlretrieve`` report hook. Nothing
    is printed when the total size is unknown (-1) or zero.
    """
    # ``file_size > 0`` also guards against a zero-length download, which
    # would otherwise divide by zero.
    if file_size > 0:
        percent = min(1, (chunk_number * chunk_size) / file_size)
        bar = "#" * int(64 * percent)
        sys.stdout.write("\r0% |{:<64}| {}%".format(bar, int(percent * 100)))
        # The line has no trailing newline, so flush explicitly to keep the
        # bar updating on line-buffered terminals.
        sys.stdout.flush()


def check(binary_path: Path, reference_hash: str) -> bool:
    """Check whether the binary exists and is the right one.

    If there is hash difference, delete the actual binary (unless running in
    dry-run mode).

    Returns True only when the file exists and its SHA256 matches
    ``reference_hash``.
    """
    if not binary_path.exists():
        logging.info(f"{binary_path} does not exist.")
        return False

    existing_binary_hash = compute_file_sha256(str(binary_path))
    if existing_binary_hash == reference_hash:
        return True

    logging.warning(
        textwrap.dedent(
            f"""\
            Found binary hash does not match reference!

            Found hash: {existing_binary_hash}
            Reference hash: {reference_hash}

            Deleting {binary_path} just to be safe.
            """
        )
    )
    if DRY_RUN:
        logging.critical(
            "In dry run mode, so not actually deleting the binary. But consider deleting it ASAP!"
        )
        return False

    try:
        binary_path.unlink()
    except OSError as e:
        logging.critical(f"Failed to delete binary: {e}")
        logging.critical(
            "Delete this binary as soon as possible and do not execute it!"
        )

    return False


def download(
    name: str,
    output_dir: str,
    url: str,
    reference_bin_hash: str,
) -> bool:
    """Download a binary unless a verified copy is already present.

    Downloads ``url`` to ``output_dir/name`` if a binary with the expected
    SHA256 hash doesn't already exist there, verifies the download against
    ``reference_bin_hash`` and marks it executable for the owner.

    Returns True when a verified binary is in place (or, in dry-run mode, when
    a download would have been attempted) and False when the download or the
    hash check failed.
    """
    # First check if we need to do anything
    binary_path = Path(output_dir, name)
    if check(binary_path, reference_bin_hash):
        logging.info(f"Correct binary already exists at {binary_path}. Exiting.")
        return True

    # Create the output folder
    binary_path.parent.mkdir(parents=True, exist_ok=True)

    # Download the binary
    logging.info(f"Downloading {url} to {binary_path}")

    if DRY_RUN:
        logging.info("Exiting as there is nothing left to do in dry run mode")
        return True

    try:
        urllib.request.urlretrieve(
            url,
            binary_path,
            reporthook=report_download_progress if sys.stdout.isatty() else None,
        )
    except urllib.error.URLError as e:
        # Covers HTTPError and ContentTooShortError as well. Report the
        # failure through the return value like every other failure path.
        logging.critical(f"Failed to download {name} from {url}: {e}")
        return False

    logging.info(f"Downloaded {name} successfully.")

    # Check the downloaded binary
    if not check(binary_path, reference_bin_hash):
        logging.critical(f"Downloaded binary {name} failed its hash check")
        return False

    # Ensure that executable bits are set
    mode = os.stat(binary_path).st_mode
    mode |= stat.S_IXUSR
    os.chmod(binary_path, mode)

    logging.info(f"Using {name} located at {binary_path}")
    return True


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="downloads and checks binaries from s3",
    )
    parser.add_argument(
        "--config-json",
        required=True,
        help="Path to config json that describes where to find binaries and hashes",
    )
    parser.add_argument(
        "--linter",
        required=True,
        help="Which linter to initialize from the config json",
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="place to put the binary",
    )
    parser.add_argument(
        "--output-name",
        required=True,
        help="name of binary",
    )
    parser.add_argument(
        "--dry-run",
        # The value is compared as a string below, so the default must be the
        # string "0"; a boolean False default would never equal "0" and would
        # silently enable dry-run mode whenever the flag is omitted.
        default="0",
        help="do not download, just print what would be done",
    )

    args = parser.parse_args()
    # Any value other than the literal "0" enables dry-run mode.
    DRY_RUN = args.dry_run != "0"

    logging.basicConfig(
        format="[DRY_RUN] %(levelname)s: %(message)s"
        if DRY_RUN
        else "%(levelname)s: %(message)s",
        level=logging.INFO,
        stream=sys.stderr,
    )

    with open(args.config_json) as config_file:
        config = json.load(config_file)
    config = config[args.linter]

    # If the host platform is not in the linter's config, it is unsupported.
    if HOST_PLATFORM not in config:
        logging.error(f"Unsupported platform: {HOST_PLATFORM}")
        sys.exit(1)

    url = config[HOST_PLATFORM]["download_url"]
    reference_hash = config[HOST_PLATFORM]["hash"]

    ok = download(args.output_name, args.output_dir, url, reference_hash)
    if not ok:
        logging.critical(f"Unable to initialize {args.linter}")
        sys.exit(1)