mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-05 08:24:57 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53133 In light of some issues where users were having trouble installing CUDA specific versions of pytorch we should no longer have special privileges for CUDA 10.2. Recently I added scripts/release/promote/prep_binary_for_pypi.sh (https://github.com/pytorch/pytorch/pull/53056) to make it so that we could theoretically promote any wheel we publish to download.pytorch.org to pypi Signed-off-by: Eli Uriegas <eliuriegas@fb.com> Test Plan: Imported from OSS Reviewed By: walterddr Differential Revision: D26759823 Pulled By: seemethere fbshipit-source-id: 2d2b29e7fef0f48c23f3c853bdca6144b7c61f22 (cherry picked from commit b8546bde09c7c00581fe4ceb061e5942c7b78b20) Signed-off-by: Eli Uriegas <eliuriegas@fb.com>
114 lines
3.4 KiB
Python
Executable File
114 lines
3.4 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
import argparse
|
|
import os
|
|
import subprocess
|
|
import re
|
|
|
|
from datetime import datetime
|
|
from distutils.util import strtobool
|
|
from pathlib import Path
|
|
|
|
# Leading "v" on release tags, e.g. "v1.7.1".
LEADING_V_PATTERN = re.compile(r"^v")
# Trailing release-candidate marker, e.g. "-rc1" (the digits are optional).
TRAILING_RC_PATTERN = re.compile(r"-rc[0-9]*$")
# Trailing "a0" stripped from the version read out of version.txt,
# e.g. "1.9.0a0" -> "1.9.0".
LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile(r"a0$")
|
|
|
|
class NoGitTagException(Exception):
    """Raised when a release version is requested but HEAD is not on a git tag."""
|
|
|
|
def get_pytorch_root():
    """Return the top-level directory of the enclosing git checkout.

    Runs ``git rev-parse --show-toplevel`` from the current working directory,
    decodes its output as ASCII, strips surrounding whitespace and wraps the
    result in a ``Path``.

    Raises:
        subprocess.CalledProcessError: if git exits with a non-zero status.
    """
    toplevel_cmd = ['git', 'rev-parse', '--show-toplevel']
    raw_output = subprocess.check_output(toplevel_cmd)
    toplevel = raw_output.decode('ascii').strip()
    return Path(toplevel)
|
|
|
|
def get_tag():
    """Return the normalized git tag pointing exactly at HEAD, or "".

    The tag is normalized by stripping a leading ``v`` (``v1.7.1`` -> ``1.7.1``)
    and a trailing ``-rc<N>`` (``1.7.1-rc1`` -> ``1.7.1``).

    Returns:
        The normalized tag string, or an empty string when HEAD is not
        exactly on a tag (lightweight or annotated).
    """
    root = get_pytorch_root()
    # Use one `--tags --exact-match` query both to detect the tag and to read
    # its name. Plain `git describe` considers only annotated tags, so on a
    # lightweight tag it would return a "-N-g<sha>"-suffixed description of
    # an older tag (or fail) even though HEAD is tagged.
    result = subprocess.run(
        ['git', 'describe', '--tags', '--exact-match'],
        cwd=root,
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
    )
    if result.returncode != 0:
        # HEAD is not on a tag.
        return ""
    dirty_tag = result.stdout.decode('ascii').strip()
    # Strip leading v that we typically do when we tag branches
    # ie: v1.7.1 -> 1.7.1
    tag = re.sub(LEADING_V_PATTERN, "", dirty_tag)
    # Strip trailing rc pattern
    # ie: 1.7.1-rc1 -> 1.7.1
    return re.sub(TRAILING_RC_PATTERN, "", tag)
|
|
|
|
def get_base_version():
    """Return the base version from ``version.txt`` at the repository root.

    A trailing ``a0`` suffix is removed, e.g. ``1.9.0a0`` -> ``1.9.0``.
    """
    root = get_pytorch_root()
    # read_text() opens and closes the file deterministically instead of
    # leaving an unclosed handle for the garbage collector.
    dirty_version = (root / 'version.txt').read_text().strip()
    # NOTE(review): the trailing "a0" looks like a PEP 440 alpha pre-release
    # marker; it is stripped so callers get a bare X.Y.Z — confirm it never
    # carries meaning here.
    return re.sub(LEGACY_BASE_VERSION_SUFFIX_PATTERN, "", dirty_version)
|
|
|
|
class PytorchVersion:
    """Compute PyTorch binary version strings for a given GPU architecture.

    Args:
        gpu_arch_type: Architecture family, e.g. "cpu", "cuda" or "rocm".
        gpu_arch_version: Architecture version, e.g. "10.2" or "4.0"; empty
            for CPU builds.
        no_build_suffix: When truthy, omit the ``+<arch>`` suffix from every
            generated version string.
    """

    def __init__(self, gpu_arch_type, gpu_arch_version, no_build_suffix):
        self.gpu_arch_type = gpu_arch_type
        self.gpu_arch_version = gpu_arch_version
        self.no_build_suffix = no_build_suffix

    def get_post_build_suffix(self):
        """Return the build suffix, e.g. "+cu102", "+rocm4.0", "+cpu", or "".

        CUDA versions have their dots removed ("10.2" -> "+cu102"); other
        architectures concatenate type and version verbatim ("+rocm4.0").
        """
        # Honor --no-build-suffix; without this check the stored flag was
        # never consulted and the option had no effect.
        if self.no_build_suffix:
            return ""
        if self.gpu_arch_type == "cuda":
            return f"+cu{self.gpu_arch_version.replace('.', '')}"
        return f"+{self.gpu_arch_type}{self.gpu_arch_version}"

    def get_release_version(self):
        """Return "<tag><suffix>" for a build made from a git tag.

        Raises:
            NoGitTagException: if HEAD is not on a git tag.
        """
        # get_tag() shells out to git; query it once and reuse the result.
        tag = get_tag()
        if not tag:
            raise NoGitTagException(
                "Not on a git tag, are you sure you want a release version?"
            )
        return f"{tag}{self.get_post_build_suffix()}"

    def get_nightly_version(self):
        """Return "<base>.dev<YYYYMMDD><suffix>" using today's local date."""
        date_str = datetime.today().strftime('%Y%m%d')
        build_suffix = self.get_post_build_suffix()
        return f"{get_base_version()}.dev{date_str}{build_suffix}"
|
|
|
|
def main():
    """Parse options and print the version string for a binary build.

    Each option falls back to an environment variable (NO_BUILD_SUFFIX,
    GPU_ARCH_TYPE, GPU_ARCH_VERSION). The release version is printed when
    one can be produced; if that raises NoGitTagException, the nightly
    version is printed instead.
    """
    env = os.environ
    parser = argparse.ArgumentParser(
        description="Generate pytorch version for binary builds"
    )
    parser.add_argument(
        "--no-build-suffix",
        type=strtobool,
        help="Whether or not to add a build suffix typically (+cpu)",
        default=env.get("NO_BUILD_SUFFIX", False),
    )
    parser.add_argument(
        "--gpu-arch-type",
        type=str,
        help="GPU arch you are building for, typically (cpu, cuda, rocm)",
        default=env.get("GPU_ARCH_TYPE", "cpu"),
    )
    parser.add_argument(
        "--gpu-arch-version",
        type=str,
        help="GPU arch version, typically (10.2, 4.0), leave blank for CPU",
        default=env.get("GPU_ARCH_VERSION", ""),
    )
    options = parser.parse_args()

    version = PytorchVersion(
        options.gpu_arch_type,
        options.gpu_arch_version,
        options.no_build_suffix,
    )
    try:
        version_str = version.get_release_version()
    except NoGitTagException:
        # No tag on HEAD: fall back to a dated nightly version.
        version_str = version.get_nightly_version()
    print(version_str)


if __name__ == "__main__":
    main()
|