Files
pytorch/.github/scripts/generate_pytorch_version.py
Huy Do 4c0dce50fd [BE] Apply ufmt to run_test and GitHub Python util scripts (#97588)
This has been bugging me for a while: I'm working on these Python scripts, but they are not tracked by the ufmt linter. So I added these scripts to that linter.

```
[[linter]]
code = 'UFMT'
include_patterns = [
    '.github/**/*.py',
    'test/run_test.py',
```

This change should just work and not break anything, as the ufmt (black + usort) linter is very safe to use on standalone util scripts.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97588
Approved by: https://github.com/kit1980
2023-03-26 04:52:55 +00:00

124 lines
3.5 KiB
Python
Executable File

#!/usr/bin/env python3
import argparse
import os
import re
import subprocess
from datetime import datetime
from distutils.util import strtobool
from pathlib import Path
# Leading "v" on tag names, e.g. "v1.7.1" -> "1.7.1".
LEADING_V_PATTERN = re.compile(r"^v")
# Trailing release-candidate suffix, e.g. "1.7.1-rc1" -> "1.7.1".
TRAILING_RC_PATTERN = re.compile(r"-rc[0-9]*$")
# Trailing "a0" marker stripped from version.txt contents, e.g. "2.1.0a0" -> "2.1.0".
LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile(r"a0$")
class NoGitTagException(Exception):
    """Raised when no usable release tag is found for the current checkout."""
def get_pytorch_root() -> Path:
    """Return the top-level directory of the enclosing git work tree.

    Raises:
        subprocess.CalledProcessError: if ``git rev-parse --show-toplevel``
            exits non-zero.
    """
    output = subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
    # git prints a filesystem path: decode it with the filesystem encoding so
    # checkouts under non-ASCII directories do not raise UnicodeDecodeError.
    # Only the line terminator is removed; leading/trailing spaces are legal
    # in directory names and must be preserved.
    return Path(os.fsdecode(output).rstrip("\r\n"))
def get_tag() -> str:
    """Return the normalized tag that HEAD sits exactly on, or ``""``.

    A leading ``v`` and a trailing ``-rcN`` suffix are stripped
    (``v1.7.1-rc1`` -> ``1.7.1``). Returns ``""`` when ``git describe`` finds
    no exact tag, or when the normalized tag starts with ``ciflow/``.
    """
    root = get_pytorch_root()
    try:
        # Spell out --exact-match: git only accepts "--exact" as a unique-prefix
        # abbreviation, which gitcli warns scripts not to rely on.
        dirty_tag = (
            subprocess.check_output(
                ["git", "describe", "--tags", "--exact-match"], cwd=root
            )
            .decode("utf-8")
            .strip()
        )
    except subprocess.CalledProcessError:
        # git describe exits non-zero when no tag matches HEAD exactly.
        return ""
    # Strip the leading v that we typically add when tagging branches,
    # ie: v1.7.1 -> 1.7.1
    tag = LEADING_V_PATTERN.sub("", dirty_tag)
    # Strip the trailing rc suffix, ie: 1.7.1-rc1 -> 1.7.1
    tag = TRAILING_RC_PATTERN.sub("", tag)
    # Ignore ciflow/ tags; they are never reported as release versions.
    if tag.startswith("ciflow/"):
        return ""
    return tag
def get_base_version() -> str:
    """Return the contents of ``version.txt`` at the repo root, minus a trailing ``a0``."""
    root = get_pytorch_root()
    # read_text opens and closes the file deterministically instead of leaving
    # the handle for the garbage collector.
    dirty_version = (root / "version.txt").read_text(encoding="utf-8").strip()
    # Strip the trailing "a0" that version.txt carries, ie: 2.1.0a0 -> 2.1.0
    return LEGACY_BASE_VERSION_SUFFIX_PATTERN.sub("", dirty_version)
class PytorchVersion:
    """Build the PyTorch package version string for a binary build.

    Args:
        gpu_arch_type: accelerator family, typically ``"cpu"``, ``"cuda"`` or
            ``"rocm"``.
        gpu_arch_version: accelerator version such as ``"10.2"``; blank for CPU.
        no_build_suffix: when true, omit the ``+<arch><version>`` suffix.
    """

    def __init__(
        self,
        gpu_arch_type: str,
        gpu_arch_version: str,
        no_build_suffix: bool,
    ) -> None:
        self.gpu_arch_type = gpu_arch_type
        self.gpu_arch_version = gpu_arch_version
        self.no_build_suffix = no_build_suffix

    def get_post_build_suffix(self) -> str:
        """Return the ``+<arch><version>`` suffix, or ``""`` when disabled.

        CUDA drops the dots and abbreviates the arch (``cuda``/``10.2`` ->
        ``+cu102``); other arches are appended verbatim (``cpu``/``""`` ->
        ``+cpu``).
        """
        if self.no_build_suffix:
            return ""
        if self.gpu_arch_type == "cuda":
            return f"+cu{self.gpu_arch_version.replace('.', '')}"
        return f"+{self.gpu_arch_type}{self.gpu_arch_version}"

    def get_release_version(self) -> str:
        """Return ``<tag><suffix>`` when HEAD is on a usable release tag.

        Raises:
            NoGitTagException: if ``get_tag()`` returns an empty string.
        """
        # Query git once: every get_tag() call spawns git subprocesses, and a
        # single lookup guarantees the tag we check is the tag we return.
        tag = get_tag()
        if not tag:
            raise NoGitTagException(
                "Not on a git tag, are you sure you want a release version?"
            )
        return f"{tag}{self.get_post_build_suffix()}"

    def get_nightly_version(self) -> str:
        """Return ``<base>.dev<YYYYMMDD><suffix>`` using today's local date."""
        date_str = datetime.today().strftime("%Y%m%d")
        build_suffix = self.get_post_build_suffix()
        return f"{get_base_version()}.dev{date_str}{build_suffix}"
def main() -> None:
    """Print the version string for this binary build.

    The release version is printed when HEAD is on a usable tag; otherwise
    the nightly version is printed. Each option falls back to an environment
    variable (NO_BUILD_SUFFIX, GPU_ARCH_TYPE, GPU_ARCH_VERSION) when it is not
    given on the command line.
    """
    parser = argparse.ArgumentParser(
        description="Generate pytorch version for binary builds"
    )
    # NOTE(review): distutils (and strtobool) was removed in Python 3.12
    # (PEP 632); the module-level import will need a replacement there.
    parser.add_argument(
        "--no-build-suffix",
        action="store_true",
        default=strtobool(os.environ.get("NO_BUILD_SUFFIX", "False")),
        help="Whether or not to add a build suffix typically (+cpu)",
    )
    string_options = (
        (
            "--gpu-arch-type",
            "GPU_ARCH_TYPE",
            "cpu",
            "GPU arch you are building for, typically (cpu, cuda, rocm)",
        ),
        (
            "--gpu-arch-version",
            "GPU_ARCH_VERSION",
            "",
            "GPU arch version, typically (10.2, 4.0), leave blank for CPU",
        ),
    )
    for flag, env_var, fallback, help_text in string_options:
        parser.add_argument(
            flag,
            type=str,
            default=os.environ.get(env_var, fallback),
            help=help_text,
        )
    args = parser.parse_args()

    version = PytorchVersion(
        gpu_arch_type=args.gpu_arch_type,
        gpu_arch_version=args.gpu_arch_version,
        no_build_suffix=args.no_build_suffix,
    )
    try:
        result = version.get_release_version()
    except NoGitTagException:
        # Not on a release tag: fall back to a dated nightly version.
        result = version.get_nightly_version()
    print(result)


if __name__ == "__main__":
    main()