mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129753
Approved by: https://github.com/malfet
ghstack dependencies: #129752
124 lines
3.5 KiB
Python
Executable File
124 lines
3.5 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
import argparse
|
|
import os
|
|
import re
|
|
import subprocess
|
|
from datetime import datetime
|
|
from distutils.util import strtobool
|
|
from pathlib import Path
|
|
|
|
|
|
# Leading "v" carried by tags, ie: v1.7.1
LEADING_V_PATTERN = re.compile(r"^v")
# Trailing release-candidate marker, ie: the "-rc1" in 1.7.1-rc1
TRAILING_RC_PATTERN = re.compile(r"-rc[0-9]*$")
# Trailing "a0" suffix on the version string stored in version.txt
LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile(r"a0$")
|
|
|
|
|
|
class NoGitTagException(Exception):
    """Raised when a release version is requested but no usable git tag exists."""
|
|
|
|
|
|
def get_pytorch_root() -> Path:
    """Return the top-level directory of the enclosing git checkout.

    The path is obtained from ``git rev-parse --show-toplevel``, run in the
    current working directory.

    Raises:
        subprocess.CalledProcessError: if the git command exits non-zero.
    """
    raw_output = subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
    toplevel = raw_output.decode("ascii").strip()
    return Path(toplevel)
|
|
|
|
|
|
def get_tag() -> str:
    """Return the normalized git tag that HEAD is exactly on, or "" if none.

    The tag reported by ``git describe --tags --exact`` has its leading ``v``
    and trailing ``-rc<N>`` removed. An empty string is returned when the git
    command fails or when the resulting tag starts with ``ciflow/``.
    """
    repo_root = get_pytorch_root()
    describe_cmd = ["git", "describe", "--tags", "--exact"]
    try:
        describe_output = subprocess.check_output(describe_cmd, cwd=repo_root)
    except subprocess.CalledProcessError:
        # git describe failed (e.g. HEAD is not exactly on a tag)
        return ""
    raw_tag = describe_output.decode("ascii").strip()
    # Tags are typically created with a leading v, which is dropped here
    # ie: v1.7.1 -> 1.7.1
    without_v = LEADING_V_PATTERN.sub("", raw_tag)
    # Release candidates share the version of the final release
    # ie: 1.7.1-rc1 -> 1.7.1
    cleaned = TRAILING_RC_PATTERN.sub("", without_v)
    # ciflow tags are ignored
    return "" if cleaned.startswith("ciflow/") else cleaned
|
|
|
|
|
|
def get_base_version() -> str:
    """Return the base version read from ``version.txt`` at the repo root.

    The file content is stripped of surrounding whitespace and of a trailing
    legacy ``a0`` suffix.

    Raises:
        OSError: if ``version.txt`` cannot be opened or read.
    """
    root = get_pytorch_root()
    # Use a context manager so the file handle is closed deterministically
    # instead of being left to the garbage collector.
    with open(root / "version.txt") as version_file:
        dirty_version = version_file.read().strip()
    # Strips trailing a0 from version.txt, not too sure why it's there in the
    # first place
    return re.sub(LEGACY_BASE_VERSION_SUFFIX_PATTERN, "", dirty_version)
|
|
|
|
|
|
class PytorchVersion:
    """Compute release and nightly version strings for a binary build.

    Args:
        gpu_arch_type: GPU arch being built for, typically cpu, cuda or rocm.
        gpu_arch_version: GPU arch version, blank for CPU.
        no_build_suffix: when truthy, no ``+<arch>`` build suffix is appended.
    """

    def __init__(
        self,
        gpu_arch_type: str,
        gpu_arch_version: str,
        no_build_suffix: bool,
    ) -> None:
        self.gpu_arch_type = gpu_arch_type
        self.gpu_arch_version = gpu_arch_version
        self.no_build_suffix = no_build_suffix

    def get_post_build_suffix(self) -> str:
        """Return the ``+<arch>`` build suffix, or "" when it is disabled.

        For cuda the dots are removed from the version (``11.8`` -> ``+cu118``);
        any other arch type is concatenated with its version unchanged.
        """
        if self.no_build_suffix:
            return ""
        if self.gpu_arch_type == "cuda":
            return f"+cu{self.gpu_arch_version.replace('.', '')}"
        return f"+{self.gpu_arch_type}{self.gpu_arch_version}"

    def get_release_version(self) -> str:
        """Return the current git tag followed by the build suffix.

        Raises:
            NoGitTagException: if ``get_tag()`` yields no tag.
        """
        # Resolve the tag exactly once: each get_tag() call shells out to git,
        # and reusing the value guarantees the tag that was checked is the tag
        # that gets returned.
        tag = get_tag()
        if not tag:
            raise NoGitTagException(
                "Not on a git tag, are you sure you want a release version?"
            )
        return f"{tag}{self.get_post_build_suffix()}"

    def get_nightly_version(self) -> str:
        """Return ``<base version>.dev<YYYYMMDD><build suffix>`` for today."""
        date_str = datetime.today().strftime("%Y%m%d")
        build_suffix = self.get_post_build_suffix()
        return f"{get_base_version()}.dev{date_str}{build_suffix}"
|
|
|
|
|
|
def main() -> None:
    """Parse the command line and print the version for this build.

    The release version is printed when one is available; if
    ``NoGitTagException`` is raised, the nightly version is printed instead.
    Each option defaults to its environment variable (``NO_BUILD_SUFFIX``,
    ``GPU_ARCH_TYPE``, ``GPU_ARCH_VERSION``).
    """
    env = os.environ
    cli = argparse.ArgumentParser(
        description="Generate pytorch version for binary builds"
    )
    # NOTE(review): strtobool comes from distutils, which was removed from the
    # standard library in Python 3.12 (PEP 632).
    cli.add_argument(
        "--no-build-suffix",
        default=strtobool(env.get("NO_BUILD_SUFFIX", "False")),
        action="store_true",
        help="Whether or not to add a build suffix typically (+cpu)",
    )
    cli.add_argument(
        "--gpu-arch-type",
        default=env.get("GPU_ARCH_TYPE", "cpu"),
        type=str,
        help="GPU arch you are building for, typically (cpu, cuda, rocm)",
    )
    cli.add_argument(
        "--gpu-arch-version",
        default=env.get("GPU_ARCH_VERSION", ""),
        type=str,
        help="GPU arch version, typically (10.2, 4.0), leave blank for CPU",
    )
    options = cli.parse_args()

    version = PytorchVersion(
        options.gpu_arch_type,
        options.gpu_arch_version,
        options.no_build_suffix,
    )
    try:
        resolved = version.get_release_version()
    except NoGitTagException:
        # No release version available, fall back to a nightly version.
        resolved = version.get_nightly_version()
    print(resolved)
|
|
|
|
|
|
# Only run the generator when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
|