Files
pytorch/tools/onnx/update_default_opset_version.py
Xuehai Pan b6bdb67f82 [BE][Easy] use pathlib.Path instead of dirname / ".." / pardir (#129374)
Changes by apply order:

1. Replace all `".."` and `os.pardir` usage with `os.path.dirname(...)`.
2. Replace nested `os.path.dirname(os.path.dirname(...))` call with `str(Path(...).parent.parent)`.
3. Reorder `.absolute()` ~/ `.resolve()`~ and `.parent`: always resolve the path first.

    `.parent{...}.absolute()` -> `.absolute().parent{...}`

4. Replace chained `.parent x N` with `.parents[${N - 1}]`: the code is easier to read (see 5.)

    `.parent.parent.parent.parent` -> `.parents[3]`

5. ~Replace `.parents[${N - 1}]` with `.parents[${N} - 1]`: the code is easier to read and does not introduce any runtime overhead.~

    ~`.parents[3]` -> `.parents[4 - 1]`~

6. ~Replace `.parents[2 - 1]` with `.parent.parent`: because the code is shorter and easier to read.~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129374
Approved by: https://github.com/justinchuby, https://github.com/malfet
2024-12-29 17:23:13 +00:00

116 lines
3.3 KiB
Python
Executable File

#!/usr/bin/env python3
"""Updates the default value of opset_version.
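The current policy is that the default should be set to the
latest released version as of 18 months ago. In practice the script picks
the earliest ONNX release tag that contains the last third_party/onnx
commit made before a cutoff of 18 * 30 days ago.
Usage:
Run with no arguments to update the opset, rebuild PyTorch, and regenerate
the ONNX operator .expect files; pass --skip-build to skip the rebuild.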
"""
import argparse
import datetime
import os
import re
import subprocess
import sys
from pathlib import Path
from subprocess import DEVNULL
from typing import Any
def read_sub_write(path: str, prefix_pat: str, new_default: int) -> None:
    r"""Replace the number that follows ``prefix_pat``'s first group in ``path``.

    ``prefix_pat`` must be a regex whose first capture group is the text that
    precedes the number to replace, e.g. ``r"(ONNX_DEFAULT_OPSET = )\d+"``.
    Every match is replaced by that group followed by ``new_default``.

    Line endings are preserved exactly (the file is read and written with
    ``newline=""``). If the pattern matches nothing, a warning is printed to
    stderr and the file is left untouched, so a stale pattern is not reported
    as a successful modification. If the content is already up to date, the
    file is not rewritten.

    Args:
        path: File to edit in place (UTF-8).
        prefix_pat: Regex with at least one capture group (the kept prefix).
        new_default: Value written after the captured prefix.
    """
    with open(path, encoding="utf-8", newline="") as f:
        content_str = f.read()
    # \g<1> (instead of \1) keeps the group reference unambiguous when the
    # substituted value starts with a digit ("\117" would mean group 117).
    new_content, n_subs = re.subn(prefix_pat, rf"\g<1>{new_default}", content_str)
    if n_subs == 0:
        print(
            f"warning: pattern {prefix_pat!r} not found in {path}; not modified",
            file=sys.stderr,
        )
        return
    if new_content == content_str:
        print("unchanged", path)
        return
    with open(path, "w", encoding="utf-8", newline="") as f:
        f.write(new_content)
    print("modified", path)
def main(args: Any) -> None:
    """Set PyTorch's default ONNX opset_version from an ~18-month-old ONNX release.

    Steps:
      1. In the ``third_party/onnx`` checkout, take the last commit made before
         a cutoff 18 * 30 days ago and pick the earliest exact ``vX.Y.Z`` tag
         that contains it.
      2. Temporarily check out that tag and read its opset from
         ``onnx.helper.VERSION_TABLE``; the original HEAD commit is checked
         back out afterwards, even on failure.
      3. Rewrite the default in ``torch/onnx/_constants.py`` and the documented
         default in ``torch/onnx/utils.py``.
      4. Unless ``args.skip_build`` is set, rebuild PyTorch with
         ``setup.py develop``; then regenerate the ONNX operator ``.expect``
         files via ``test/onnx/test_operators.py --accept``.

    Args:
        args: Parsed command-line namespace; only ``args.skip_build`` is read.

    Exits via ``sys.exit`` with a message when no commit predates the cutoff,
    no release tag contains it, or the release is missing from
    ``VERSION_TABLE``. Changes the process working directory as a side effect.
    """
    # Resolve *before* walking up so a relative or symlinked script path still
    # lands on the repository root (this file lives at <root>/tools/onnx/).
    pytorch_dir = Path(__file__).resolve().parents[2]
    onnx_dir = pytorch_dir / "third_party" / "onnx"
    os.chdir(onnx_dir)

    # 18 * 30 days approximates the 18-month policy window.
    date = datetime.datetime.now() - datetime.timedelta(days=18 * 30)
    onnx_commit = subprocess.check_output(
        ("git", "log", f"--until={date}", "--max-count=1", "--format=%H"),
        encoding="utf-8",
    ).strip()
    if not onnx_commit:
        sys.exit(f"no commit in {onnx_dir} predates {date}")
    onnx_tags = subprocess.check_output(
        ("git", "tag", "--list", f"--contains={onnx_commit}"), encoding="utf-8"
    )

    # fullmatch (not match) keeps only exact release tags: a pre-release tag
    # such as "v1.15.0rc1" would otherwise parse as 1.15.0, and the
    # "v1.15.0" checkout below might not exist.
    semver_pat = re.compile(r"v(\d+)\.(\d+)\.(\d+)")
    tag_tups = [
        tuple(int(x) for x in match.groups())
        for match in map(semver_pat.fullmatch, onnx_tags.splitlines())
        if match
    ]
    if not tag_tups:
        sys.exit(f"no ONNX release tag contains commit {onnx_commit}")
    # Take the release 18 months ago: the earliest release containing the commit.
    version_str = ".".join(str(x) for x in min(tag_tups))
    print("Using ONNX release", version_str)

    head_commit = subprocess.check_output(
        ("git", "log", "--max-count=1", "--format=%H", "HEAD"), encoding="utf-8"
    ).strip()

    new_default = None
    subprocess.check_call(
        ("git", "checkout", f"v{version_str}"), stdout=DEVNULL, stderr=DEVNULL
    )
    try:
        # NOTE(review): this imports whichever `onnx` package is on sys.path
        # (the cwd is not on sys.path when run as a script); presumably an
        # editable install of third_party/onnx so the checkout above takes
        # effect — confirm.
        from onnx import helper  # type: ignore[import]

        for version in helper.VERSION_TABLE:
            # Index 0 is compared against the release string; index 2 is
            # taken as the opset version.
            if version[0] == version_str:
                new_default = version[2]
                print("found new default opset_version", new_default)
                break
        if new_default is None:
            sys.exit(
                f"failed to find version {version_str} in onnx.helper.VERSION_TABLE at commit {onnx_commit}"
            )
    finally:
        # Restore the submodule to the commit it was on before the tag checkout.
        subprocess.check_call(
            ("git", "checkout", head_commit), stdout=DEVNULL, stderr=DEVNULL
        )

    # Paths below are relative to the PyTorch root.
    os.chdir(pytorch_dir)
    read_sub_write(
        os.path.join("torch", "onnx", "_constants.py"),
        r"(ONNX_DEFAULT_OPSET = )\d+",
        new_default,
    )
    read_sub_write(
        os.path.join("torch", "onnx", "utils.py"),
        r"(opset_version \(int, default )\d+",
        new_default,
    )

    # Use the interpreter running this script rather than whatever "python"
    # resolves to on PATH (which may be absent or a different environment).
    python = sys.executable or "python"
    if not args.skip_build:
        print("Building PyTorch...")
        subprocess.check_call(
            (python, "setup.py", "develop"),
        )
    print("Updating operator .expect files")
    subprocess.check_call(
        (python, os.path.join("test", "onnx", "test_operators.py"), "--accept"),
    )
if __name__ == "__main__":
    # Command-line entry point. Both "--skip-build" and "--skip_build" are
    # accepted; argparse stores either as args.skip_build (default False).
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--skip-build", "--skip_build", action="store_true", help="Skip building pytorch"
    )
    parsed_args = cli.parse_args()
    main(parsed_args)