Files
pytorch/tools/generate_torch_version.py
Klaus Zimmermann 42928876eb Add sdist handling to version finding (#160315)
The version finding logic triggered from `setup.py` generally tries to take the git information into account.
This is fine for most situations where we are building from a checkout, but it creates a problem in the case of sdists, as here the version is determined at the time of sdist creation, taking the git information into account, but then later recalculated when building wheels or installing from the sdist, now with the git information missing.

The solution is to take the version information directly from the sdist, which this PR adds by means of parsing the `PKG-INFO` which marks an unpacked sdist.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160315
Approved by: https://github.com/atalman
ghstack dependencies: #157814
2025-09-25 07:15:51 +00:00

155 lines
5.7 KiB
Python

from __future__ import annotations
import argparse
import email
import os
import re
import subprocess
from pathlib import Path
from packaging.version import Version
from setuptools import distutils # type: ignore[import,attr-defined]
# Sentinel returned by get_sha() and get_tag() when no commit id or release
# tag could be determined.
UNKNOWN = "Unknown"
# Shape a tag must have for get_tag() to report it as a release tag:
# "v<number>[.<number>...]" with an optional "-rc<number>" suffix.
# NOTE(review): the pattern also requires a literal "/" immediately before and
# after the tag, while get_tag() applies it with .match() to the stripped
# output of `git describe`, so a bare tag such as "v2.1.0" does not match.
# Presumably the slashes are leftover regex delimiters -- confirm before
# changing them, since the __main__ block writes a matching tag verbatim as
# __version__.
RELEASE_PATTERN = re.compile(r"/v[0-9]+(\.[0-9]+)*(-rc[0-9]+)?/")
def get_sha(pytorch_root: str | Path) -> str:
    """Return the current revision id of the checkout at ``pytorch_root``.

    A ``.git`` entry is looked for first and a ``.hg`` entry second; the
    matching tool is run inside ``pytorch_root`` and its stripped output is
    returned.  This is strictly best effort: if neither marker exists, the
    command prints nothing, or anything at all goes wrong, ``UNKNOWN`` is
    returned instead of raising.
    """
    # Ordered by priority: only the first marker found is ever queried.
    vcs_queries = (
        (".git", ["git", "rev-parse", "HEAD"]),
        (".hg", ["hg", "identify", "-r", "."]),
    )
    try:
        for marker, command in vcs_queries:
            if not os.path.exists(os.path.join(pytorch_root, marker)):
                continue
            raw = subprocess.check_output(command, cwd=pytorch_root)
            if raw:
                return raw.decode("ascii").strip()
            # Empty output from the first matching VCS: do not try the next one.
            break
    except Exception:
        # Missing executable, non-zero exit status, undecodable output, ...
        pass
    return UNKNOWN
def get_tag(pytorch_root: str | Path) -> str:
    """Return the release tag pointing exactly at the current commit.

    Runs ``git describe --tags --exact`` inside ``pytorch_root``.  The
    stripped output is returned only when it matches the module-level
    ``RELEASE_PATTERN``; in every other case, including any failure to run
    git, ``UNKNOWN`` is returned.
    """
    try:
        described = subprocess.run(
            ["git", "describe", "--tags", "--exact"],
            cwd=pytorch_root,
            encoding="ascii",
            capture_output=True,
        )
        candidate = described.stdout.strip()
        return candidate if RELEASE_PATTERN.match(candidate) else UNKNOWN
    except Exception:
        return UNKNOWN
def get_torch_version(sha: str | None = None) -> str:
    """Determine the torch version string.

    The version is determined from one of the following sources, in order of
    precedence:

    1. The PYTORCH_BUILD_VERSION and PYTORCH_BUILD_NUMBER environment variables.
       These are set by the PyTorch build system when building official
       releases. If built from an sdist, it is checked that the version matches
       the sdist version.
    2. The PKG-INFO file, if it exists. This file is included in source
       distributions (sdist) and contains the version of the sdist.
    3. The version.txt file, which contains the base version string. If the git
       commit SHA is available, it is appended to the version string to
       indicate that this is a development build.

    Args:
        sha: Commit id to append in case 3. ``None`` means "look it up with
            get_sha()"; the ``UNKNOWN`` sentinel means no commit information
            is available, in which case nothing is appended.

    Returns:
        The version string, validated to be PEP 440 compliant.

    Raises:
        AssertionError: If PYTORCH_BUILD_VERSION is set without
            PYTORCH_BUILD_NUMBER, or if the computed version disagrees with
            the version recorded in PKG-INFO.
        ValueError: If PYTORCH_BUILD_NUMBER is not an integer.
        packaging.version.InvalidVersion: If the computed version is not
            PEP 440 compliant.
    """
    pytorch_root = Path(__file__).absolute().parent.parent

    # A PKG-INFO file at the root marks an unpacked sdist; its Version header
    # is the version the sdist was created with.
    pkg_info_path = pytorch_root / "PKG-INFO"
    sdist_version = None
    if pkg_info_path.exists():
        # Core metadata files are UTF-8; do not depend on the locale encoding.
        with open(pkg_info_path, encoding="utf-8") as f:
            sdist_version = email.message_from_file(f)["Version"]

    build_version = os.getenv("PYTORCH_BUILD_VERSION")
    if build_version:
        build_number_env = os.getenv("PYTORCH_BUILD_NUMBER")
        assert build_number_env is not None, (
            "PYTORCH_BUILD_NUMBER must be set when PYTORCH_BUILD_VERSION is set"
        )
        build_number = int(build_number_env)
        version = build_version
        if build_number > 1:
            version += ".post" + str(build_number)
        origin = "PYTORCH_BUILD_{VERSION,NUMBER} env variables"
    elif sdist_version:
        version = sdist_version
        origin = "PKG-INFO"
    else:
        # read_text() closes the file, unlike a bare open(...).read().
        version = (pytorch_root / "version.txt").read_text(encoding="utf-8").strip()
        origin = "version.txt"
        if sdist_version is None:
            # Resolve the commit id *before* testing for the sentinel so that a
            # failed lookup never yields a "+gitUnknown" suffix.
            if sha is None:
                sha = get_sha(pytorch_root)
            if sha != UNKNOWN:
                version += "+git" + sha[:7]
                origin += " and git commit"

    # Validate that the version is PEP 440 compliant
    parsed_version = Version(version)
    if sdist_version:
        local = parsed_version.local
        if local and local.startswith("git"):
            # Assume local version is git<sha> and
            # hence whole version is source version
            source_version = version
        else:
            # local version is absent or platform tag
            source_version = version.partition("+")[0]
        assert sdist_version == source_version, (
            f"Source part '{source_version}' of version '{version}' from "
            f"{origin} does not match version '{sdist_version}' from PKG-INFO"
        )
    return version
if __name__ == "__main__":
    # Script entry point: gather build metadata from the command line and the
    # checkout, then write it out as torch/version.py.
    cli = argparse.ArgumentParser(
        description="Generate torch/version.py from build and environment metadata."
    )
    cli.add_argument(
        "--is-debug",
        "--is_debug",
        type=distutils.util.strtobool,
        help="Whether this build is debug mode or not.",
    )
    # Every accelerator flag is accepted in dashed and underscored spelling.
    for backend in ("cuda", "hip", "xpu"):
        cli.add_argument(f"--{backend}-version", f"--{backend}_version", type=str)
    args = cli.parse_args()

    assert args.is_debug is not None
    # An empty string on the command line means "no version given".
    for attr in ("cuda_version", "hip_version", "xpu_version"):
        if getattr(args, attr) == "":
            setattr(args, attr, None)

    pytorch_root = Path(__file__).parent.parent
    version_path = pytorch_root / "torch" / "version.py"

    # Attempt to get tag first, fall back to sha if a tag was not found
    tagged_version = get_tag(pytorch_root)
    sha = get_sha(pytorch_root)
    version = get_torch_version(sha) if tagged_version == UNKNOWN else tagged_version

    generated_source = "".join(
        [
            "from typing import Optional\n\n",
            "__all__ = ['__version__', 'debug', 'cuda', 'git_version', 'hip', 'xpu']\n",
            f"__version__ = '{version}'\n",
            # NB: This is not 100% accurate, because you could have built the
            # library code with DEBUG, but csrc without DEBUG (in which case
            # this would claim to be a release build when it's not.)
            f"debug = {bool(args.is_debug)!r}\n",
            f"cuda: Optional[str] = {args.cuda_version!r}\n",
            f"git_version = {sha!r}\n",
            f"hip: Optional[str] = {args.hip_version!r}\n",
            f"xpu: Optional[str] = {args.xpu_version!r}\n",
        ]
    )
    with open(version_path, "w") as f:
        f.write(generated_source)