From 1ab883797a2b3b54677574ce98e897b19fbbecec Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Sun, 12 Mar 2023 20:00:48 +0000 Subject: [PATCH] [BE] Dedup hardcoded triton versions (#96580) Define it once in `.ci/docker/trition_version.txt` and use everywhere. Also, patch version defined in `triton/__init__.py` as currently it always returns `2.0.0` even if package name is `2.1.0` Followup after https://github.com/pytorch/pytorch/pull/95896 where version needed to be updated in 4+ places Pull Request resolved: https://github.com/pytorch/pytorch/pull/96580 Approved by: https://github.com/huydhn --- .ci/docker/triton_version.txt | 1 + .ci/docker/ubuntu-cuda/Dockerfile | 3 ++- .github/scripts/build_triton_wheel.py | 37 +++++++++++++++++++++------ .github/workflows/docker-release.yml | 2 +- setup.py | 9 ++++--- 5 files changed, 39 insertions(+), 13 deletions(-) create mode 100644 .ci/docker/triton_version.txt diff --git a/.ci/docker/triton_version.txt b/.ci/docker/triton_version.txt new file mode 100644 index 000000000000..7ec1d6db4087 --- /dev/null +++ b/.ci/docker/triton_version.txt @@ -0,0 +1 @@ +2.1.0 diff --git a/.ci/docker/ubuntu-cuda/Dockerfile b/.ci/docker/ubuntu-cuda/Dockerfile index 0e294838f90f..57ed3a9c439f 100644 --- a/.ci/docker/ubuntu-cuda/Dockerfile +++ b/.ci/docker/ubuntu-cuda/Dockerfile @@ -91,8 +91,9 @@ ARG TRITON COPY ./common/install_triton.sh install_triton.sh COPY ./common/common_utils.sh common_utils.sh COPY ci_commit_pins/triton.txt triton.txt +COPY triton_version.txt triton_version.txt RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi -RUN rm install_triton.sh common_utils.sh triton.txt +RUN rm install_triton.sh common_utils.sh triton.txt triton_version.txt # Install ccache/sccache (do this last, so we get priority in PATH) COPY ./common/install_cache.sh install_cache.sh diff --git a/.github/scripts/build_triton_wheel.py b/.github/scripts/build_triton_wheel.py index 36c8914f213e..3015296be53f 100644 --- a/.github/scripts/build_triton_wheel.py +++ b/.github/scripts/build_triton_wheel.py @@ -6,9 +6,15 @@ from typing import Optional import sys import shutil SCRIPT_DIR = Path(__file__).parent +REPO_DIR = SCRIPT_DIR.parent.parent def read_triton_pin() -> str: - with open(SCRIPT_DIR.parent / "ci_commit_pins" / "triton.txt") as f: + with open(REPO_DIR / ".ci" / "docker" / "ci_commit_pins" / "triton.txt") as f: + return f.read().strip() + + +def read_triton_version() -> str: + with open(REPO_DIR / ".ci" / "docker" / "triton_version.txt") as f: return f.read().strip() @@ -19,18 +25,27 @@ def check_and_replace(inp: str, src: str, dst: str) -> str: return inp.replace(src, dst) -def patch_setup_py(path: Path, *, version: str = "2.1.0", name: str = "triton") -> None: +def patch_setup_py(path: Path, *, version: str, name: str = "triton") -> None: with open(path) as f: orig = f.read() # Replace name orig = check_and_replace(orig, "name=\"triton\",", f"name=\"{name}\",") # Replace version - orig = check_and_replace(orig, "version=\"2.1.0\",", f"version=\"{version}\",") + orig = check_and_replace(orig, f"version=\"{read_triton_version()}\",", f"version=\"{version}\",") with open(path, "w") as f: f.write(orig) -def build_triton(commit_hash: str, build_conda: bool = False, py_version : Optional[str] = None) -> Path: +def patch_init_py(path: Path, *, version: str) -> None: + with open(path) as f: + orig = f.read() + # Replace version + orig = check_and_replace(orig, "__version__ = '2.0.0'", f"__version__ = \"{version}\"") + with open(path, "w") as f: + f.write(orig) + + +def build_triton(*, version: str, commit_hash: str, build_conda: bool = False, py_version : Optional[str] = None) -> Path: with TemporaryDirectory() as tmpdir: triton_basedir = Path(tmpdir) / "triton" triton_pythondir = triton_basedir / "python" @@ -38,7 +53,7 @@ def build_triton(commit_hash: str, build_conda: bool = False, py_version : Optio check_call(["git", "checkout", commit_hash], cwd=triton_basedir) if build_conda: with open(triton_basedir / "meta.yaml", "w") as meta: - print(f"package:\n name: torchtriton\n version: 2.1.0+{commit_hash[:10]}\n", file=meta) + print(f"package:\n name: torchtriton\n version: {version}+{commit_hash[:10]}\n", file=meta) print("source:\n path: .\n", file=meta) print("build:\n string: py{{py}}\n number: 1\n script: cd python; " "python setup.py install --single-version-externally-managed --record=record.txt\n", file=meta) @@ -47,6 +62,7 @@ def build_triton(commit_hash: str, build_conda: bool = False, py_version : Optio print("about:\n home: https://github.com/openai/triton\n license: MIT\n summary:" " 'A language and compiler for custom Deep Learning operation'", file=meta) + patch_init_py(triton_pythondir / "triton" / "__init__.py", version=f"{version}+{commit_hash[:10]}") if py_version is None: py_version = f"{sys.version_info.major}.{sys.version_info.minor}" check_call(["conda", "build", "--python", py_version, @@ -55,7 +71,8 @@ def build_triton(commit_hash: str, build_conda: bool = False, py_version : Optio shutil.copy(conda_path, Path.cwd()) return Path.cwd() / conda_path.name - patch_setup_py(triton_pythondir / "setup.py", name="pytorch-triton", version=f"2.1.0+{commit_hash[:10]}") + patch_setup_py(triton_pythondir / "setup.py", name="pytorch-triton", version=f"{version}+{commit_hash[:10]}") + patch_init_py(triton_pythondir / "triton" / "__init__.py", version=f"{version}+{commit_hash[:10]}") check_call([sys.executable, "setup.py", "bdist_wheel"], cwd=triton_pythondir) whl_path = list((triton_pythondir / "dist").glob("*.whl"))[0] shutil.copy(whl_path, Path.cwd()) @@ -67,9 +84,13 @@ def main() -> None: parser = ArgumentParser("Build Triton binaries") parser.add_argument("--build-conda", action="store_true") parser.add_argument("--py-version", type=str) + parser.add_argument("--commit-hash", type=str, default=read_triton_pin()) + parser.add_argument("--triton-version", type=str, default=read_triton_version()) args = parser.parse_args() - pin = read_triton_pin() - build_triton(pin, build_conda=args.build_conda, py_version=args.py_version) + build_triton(commit_hash=args.commit_hash, + version=args.triton_version, + build_conda=args.build_conda, + py_version=args.py_version) if __name__ == "__main__": diff --git a/.github/workflows/docker-release.yml b/.github/workflows/docker-release.yml index b054f5482c1c..a3936246f3be 100644 --- a/.github/workflows/docker-release.yml +++ b/.github/workflows/docker-release.yml @@ -87,7 +87,7 @@ jobs: { echo "DOCKER_IMAGE=pytorch-nightly"; echo "INSTALL_CHANNEL=pytorch-nightly"; - echo "TRITON_VERSION=2.1.0+$(cut -c -10 .github/ci_commit_pins/triton.txt)"; + echo "TRITON_VERSION=$(cut -f 1 .ci/docker/triton_version.txt)+$(cut -c -10 .ci/docker/ci_commit_pins/triton.txt)"; } >> "${GITHUB_ENV}" - name: Run docker build / push # WITH_PUSH is used here to determine whether or not to add the --push flag diff --git a/setup.py b/setup.py index c48b83917954..9a258caba1ee 100644 --- a/setup.py +++ b/setup.py @@ -1028,11 +1028,14 @@ def main(): 'opt-einsum': ['opt-einsum>=3.3'] } if platform.system() == 'Linux': - triton_pin_file = os.path.join(cwd, ".github", "ci_commit_pins", "triton.txt") - if os.path.exists(triton_pin_file): + triton_pin_file = os.path.join(cwd, ".ci", "docker", "ci_commit_pins", "triton.txt") + triton_version_file = os.path.join(cwd, ".ci", "docker", "triton_version.txt") + if os.path.exists(triton_pin_file) and os.path.exists(triton_version_file): with open(triton_pin_file) as f: triton_pin = f.read().strip() - extras_require['dynamo'] = ['pytorch-triton==2.1.0+' + triton_pin[:10], 'jinja2'] + with open(triton_version_file) as f: + triton_version = f.read().strip() + extras_require['dynamo'] = ['pytorch-triton==' + triton_version + '+' + triton_pin[:10], 'jinja2'] # Parse the command line and check the arguments before we proceed with # building deps and setup. We need to set values so `--help` works.