#!/usr/bin/env python3

"""Generate the PyTorch version string used for binary builds.

If HEAD is exactly on a git tag, the release version is printed: the tag with
a leading ``v`` and a trailing ``-rc<N>`` stripped. Otherwise a nightly
version ``<version.txt base>.dev<YYYYMMDD>`` is printed. Both forms get a
``+<arch>`` build suffix unless ``--no-build-suffix`` / ``NO_BUILD_SUFFIX``
is set.
"""

import argparse
import os
import re
import subprocess
from datetime import datetime
from pathlib import Path

LEADING_V_PATTERN = re.compile(r"^v")
TRAILING_RC_PATTERN = re.compile(r"-rc[0-9]*$")
LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile(r"a0$")


def strtobool(value: str) -> bool:
    """Convert a yes/no style string to a bool.

    Drop-in replacement for ``distutils.util.strtobool``. distutils was
    removed in Python 3.12 (PEP 632), so importing it breaks this script on
    current interpreters. Matching is case-insensitive and accepts the same
    spellings as the distutils version.

    Args:
        value: String such as "true", "Yes", "0" or "off".

    Returns:
        True for y/yes/t/true/on/1 and False for n/no/f/false/off/0.

    Raises:
        ValueError: If ``value`` is not one of the recognised spellings.
    """
    lowered = value.lower()
    if lowered in {"y", "yes", "t", "true", "on", "1"}:
        return True
    if lowered in {"n", "no", "f", "false", "off", "0"}:
        return False
    raise ValueError(f"invalid truth value {lowered!r}")


class NoGitTagException(Exception):
    """Raised when a release version is requested but HEAD is not tagged."""


def get_pytorch_root() -> Path:
    """Return the top-level directory of the git checkout containing the CWD."""
    return Path(
        subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
        .decode("ascii")
        .strip()
    )


def get_tag() -> str:
    """Return the normalised git tag pointing exactly at HEAD.

    A leading ``v`` and a trailing ``-rc<N>`` are stripped, so for example
    ``v1.7.1-rc1`` becomes ``1.7.1``.

    Returns:
        The cleaned tag. Returns "" if HEAD has no exact tag (git exits
        non-zero) or if the tag is a ``ciflow/`` CI trigger tag.
    """
    root = get_pytorch_root()
    try:
        dirty_tag = (
            subprocess.check_output(
                # --exact-match is the documented spelling; the previous
                # "--exact" only worked through git's option-abbreviation rules.
                ["git", "describe", "--tags", "--exact-match"],
                cwd=root,
            )
            .decode("ascii")
            .strip()
        )
    except subprocess.CalledProcessError:
        return ""
    # Strip the leading v we typically use when tagging, e.g. v1.7.1 -> 1.7.1.
    tag = LEADING_V_PATTERN.sub("", dirty_tag)
    # Strip a trailing rc marker, e.g. 1.7.1-rc1 -> 1.7.1.
    tag = TRAILING_RC_PATTERN.sub("", tag)
    # ciflow/ tags trigger CI workflows and are not release tags.
    if tag.startswith("ciflow/"):
        return ""
    return tag


def get_base_version() -> str:
    """Return the contents of ``<repo root>/version.txt`` without a trailing ``a0``."""
    root = get_pytorch_root()
    # read_text opens and closes the file itself; the previous bare open()
    # leaked the handle.
    dirty_version = (root / "version.txt").read_text(encoding="utf-8").strip()
    # version.txt may carry a legacy "a0" suffix; the build version omits it.
    return LEGACY_BASE_VERSION_SUFFIX_PATTERN.sub("", dirty_version)


class PytorchVersion:
    """Build release or nightly version strings for one GPU arch configuration."""

    def __init__(
        self,
        gpu_arch_type: str,
        gpu_arch_version: str,
        no_build_suffix: bool,
    ) -> None:
        """Store the build configuration.

        Args:
            gpu_arch_type: Arch family, e.g. "cpu", "cuda" or "rocm".
            gpu_arch_version: Arch version, e.g. "11.8"; empty for CPU.
            no_build_suffix: If truthy, omit the "+<arch>" local suffix.
        """
        self.gpu_arch_type = gpu_arch_type
        self.gpu_arch_version = gpu_arch_version
        self.no_build_suffix = no_build_suffix

    def get_post_build_suffix(self) -> str:
        """Return the local-version suffix, e.g. "+cu118", "+rocm5.6" or "+cpu".

        Returns "" when ``no_build_suffix`` is set. CUDA versions have their
        dots removed ("11.8" -> "cu118"). Other arch versions are appended
        verbatim.
        """
        if self.no_build_suffix:
            return ""
        if self.gpu_arch_type == "cuda":
            return f"+cu{self.gpu_arch_version.replace('.', '')}"
        return f"+{self.gpu_arch_type}{self.gpu_arch_version}"

    def get_release_version(self) -> str:
        """Return "<tag><suffix>" for the tag at HEAD.

        Raises:
            NoGitTagException: If HEAD is not exactly on a usable tag.
        """
        # Query git once so the emptiness check and the returned value agree.
        tag = get_tag()
        if not tag:
            raise NoGitTagException(
                "Not on a git tag, are you sure you want a release version?"
            )
        return f"{tag}{self.get_post_build_suffix()}"

    def get_nightly_version(self) -> str:
        """Return "<base>.dev<YYYYMMDD><suffix>" using today's local date."""
        date_str = datetime.today().strftime("%Y%m%d")
        build_suffix = self.get_post_build_suffix()
        return f"{get_base_version()}.dev{date_str}{build_suffix}"


def main() -> None:
    """Parse CLI/env options and print the release or nightly version."""
    parser = argparse.ArgumentParser(
        description="Generate pytorch version for binary builds"
    )
    parser.add_argument(
        "--no-build-suffix",
        action="store_true",
        help="Whether or not to add a build suffix typically (+cpu)",
        default=strtobool(os.environ.get("NO_BUILD_SUFFIX", "False")),
    )
    parser.add_argument(
        "--gpu-arch-type",
        type=str,
        help="GPU arch you are building for, typically (cpu, cuda, rocm)",
        default=os.environ.get("GPU_ARCH_TYPE", "cpu"),
    )
    parser.add_argument(
        "--gpu-arch-version",
        type=str,
        help="GPU arch version, typically (10.2, 4.0), leave blank for CPU",
        default=os.environ.get("GPU_ARCH_VERSION", ""),
    )
    args = parser.parse_args()
    version_obj = PytorchVersion(
        args.gpu_arch_type, args.gpu_arch_version, args.no_build_suffix
    )
    # Prefer the release version; fall back to a nightly one off-tag.
    try:
        print(version_obj.get_release_version())
    except NoGitTagException:
        print(version_obj.get_nightly_version())


if __name__ == "__main__":
    main()