tools: Add ability to grab release versions

Adds the ability for generate_torch_version to grab release versions
based on the current tag. Also includes a regex to check if the tagged
version matches our release pattern (vX.Y.Z) so we don't collide with
ciflow tags

Signed-off-by: Eli Uriegas <eliuriegasfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78584

Signed-off-by: Eli Uriegas <eliuriegas@fb.com>

Approved by: https://github.com/janeyx99
This commit is contained in:
Eli Uriegas
2022-05-31 15:59:41 -07:00
committed by PyTorch MergeBot
parent 44aa4ad894
commit ffaee6619c

View File

@ -1,11 +1,16 @@
import argparse
import os
import re
import subprocess
from pathlib import Path
from setuptools import distutils # type: ignore[import]
from typing import Optional, Union
UNKNOWN = "Unknown"
RELEASE_PATTERN = re.compile(r"/v[0-9]+(\.[0-9]+)*(-rc[0-9]+)?/")
def get_sha(pytorch_root: Union[str, Path]) -> str:
try:
return (
@ -14,7 +19,24 @@ def get_sha(pytorch_root: Union[str, Path]) -> str:
.strip()
)
except Exception:
return "Unknown"
return UNKNOWN
def get_tag(pytorch_root: Union[str, Path]) -> str:
try:
tag = (
subprocess.check_output(
["git", "describe", "--tags", "--exact"], cwd=pytorch_root
)
.decode("ascii")
.strip()
)
if RELEASE_PATTERN.match(tag):
return tag
else:
return UNKNOWN
except Exception:
return UNKNOWN
def get_torch_version(sha: Optional[str] = None) -> str:
@ -27,7 +49,7 @@ def get_torch_version(sha: Optional[str] = None) -> str:
version = os.getenv("PYTORCH_BUILD_VERSION", "")
if build_number > 1:
version += ".post" + str(build_number)
elif sha != "Unknown":
elif sha != UNKNOWN:
if sha is None:
sha = get_sha(pytorch_root)
version += "+git" + sha[:7]
@ -54,8 +76,13 @@ if __name__ == "__main__":
pytorch_root = Path(__file__).parent.parent
version_path = pytorch_root / "torch" / "version.py"
# Attempt to get tag first, fall back to sha if a tag was not found
tagged_version = get_tag(pytorch_root)
sha = get_sha(pytorch_root)
version = get_torch_version(sha)
if tagged_version == UNKNOWN:
version = get_torch_version(sha)
else:
version = tagged_version
with open(version_path, "w") as f:
f.write("__version__ = '{}'\n".format(version))