mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] Dedup hardcoded triton versions (#96580)
Define it once in `.ci/docker/trition_version.txt` and use everywhere. Also, patch version defined in `triton/__init__.py` as currently it always returns `2.0.0` even if package name is `2.1.0` Followup after https://github.com/pytorch/pytorch/pull/95896 where version needed to be updated in 4+ places Pull Request resolved: https://github.com/pytorch/pytorch/pull/96580 Approved by: https://github.com/huydhn
This commit is contained in:
committed by
PyTorch MergeBot
parent
30b968f60d
commit
1ab883797a
1
.ci/docker/triton_version.txt
Normal file
1
.ci/docker/triton_version.txt
Normal file
@ -0,0 +1 @@
|
|||||||
|
2.1.0
|
||||||
@ -91,8 +91,9 @@ ARG TRITON
|
|||||||
COPY ./common/install_triton.sh install_triton.sh
|
COPY ./common/install_triton.sh install_triton.sh
|
||||||
COPY ./common/common_utils.sh common_utils.sh
|
COPY ./common/common_utils.sh common_utils.sh
|
||||||
COPY ci_commit_pins/triton.txt triton.txt
|
COPY ci_commit_pins/triton.txt triton.txt
|
||||||
|
COPY triton_version.txt triton_version.txt
|
||||||
RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi
|
RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi
|
||||||
RUN rm install_triton.sh common_utils.sh triton.txt
|
RUN rm install_triton.sh common_utils.sh triton.txt triton_version.txt
|
||||||
|
|
||||||
# Install ccache/sccache (do this last, so we get priority in PATH)
|
# Install ccache/sccache (do this last, so we get priority in PATH)
|
||||||
COPY ./common/install_cache.sh install_cache.sh
|
COPY ./common/install_cache.sh install_cache.sh
|
||||||
|
|||||||
37
.github/scripts/build_triton_wheel.py
vendored
37
.github/scripts/build_triton_wheel.py
vendored
@ -6,9 +6,15 @@ from typing import Optional
|
|||||||
import sys
|
import sys
|
||||||
import shutil
|
import shutil
|
||||||
SCRIPT_DIR = Path(__file__).parent
|
SCRIPT_DIR = Path(__file__).parent
|
||||||
|
REPO_DIR = SCRIPT_DIR.parent.parent
|
||||||
|
|
||||||
def read_triton_pin() -> str:
|
def read_triton_pin() -> str:
|
||||||
with open(SCRIPT_DIR.parent / "ci_commit_pins" / "triton.txt") as f:
|
with open(REPO_DIR / ".ci" / "docker" / "ci_commit_pins" / "triton.txt") as f:
|
||||||
|
return f.read().strip()
|
||||||
|
|
||||||
|
|
||||||
|
def read_triton_version() -> str:
|
||||||
|
with open(REPO_DIR / ".ci" / "docker" / "triton_version.txt") as f:
|
||||||
return f.read().strip()
|
return f.read().strip()
|
||||||
|
|
||||||
|
|
||||||
@ -19,18 +25,27 @@ def check_and_replace(inp: str, src: str, dst: str) -> str:
|
|||||||
return inp.replace(src, dst)
|
return inp.replace(src, dst)
|
||||||
|
|
||||||
|
|
||||||
def patch_setup_py(path: Path, *, version: str = "2.1.0", name: str = "triton") -> None:
|
def patch_setup_py(path: Path, *, version: str, name: str = "triton") -> None:
|
||||||
with open(path) as f:
|
with open(path) as f:
|
||||||
orig = f.read()
|
orig = f.read()
|
||||||
# Replace name
|
# Replace name
|
||||||
orig = check_and_replace(orig, "name=\"triton\",", f"name=\"{name}\",")
|
orig = check_and_replace(orig, "name=\"triton\",", f"name=\"{name}\",")
|
||||||
# Replace version
|
# Replace version
|
||||||
orig = check_and_replace(orig, "version=\"2.1.0\",", f"version=\"{version}\",")
|
orig = check_and_replace(orig, f"version=\"{read_triton_version()}\",", f"version=\"{version}\",")
|
||||||
with open(path, "w") as f:
|
with open(path, "w") as f:
|
||||||
f.write(orig)
|
f.write(orig)
|
||||||
|
|
||||||
|
|
||||||
def build_triton(commit_hash: str, build_conda: bool = False, py_version : Optional[str] = None) -> Path:
|
def patch_init_py(path: Path, *, version: str) -> None:
|
||||||
|
with open(path) as f:
|
||||||
|
orig = f.read()
|
||||||
|
# Replace version
|
||||||
|
orig = check_and_replace(orig, "__version__ = '2.0.0'", f"__version__ = \"{version}\"")
|
||||||
|
with open(path, "w") as f:
|
||||||
|
f.write(orig)
|
||||||
|
|
||||||
|
|
||||||
|
def build_triton(*, version: str, commit_hash: str, build_conda: bool = False, py_version : Optional[str] = None) -> Path:
|
||||||
with TemporaryDirectory() as tmpdir:
|
with TemporaryDirectory() as tmpdir:
|
||||||
triton_basedir = Path(tmpdir) / "triton"
|
triton_basedir = Path(tmpdir) / "triton"
|
||||||
triton_pythondir = triton_basedir / "python"
|
triton_pythondir = triton_basedir / "python"
|
||||||
@ -38,7 +53,7 @@ def build_triton(commit_hash: str, build_conda: bool = False, py_version : Optio
|
|||||||
check_call(["git", "checkout", commit_hash], cwd=triton_basedir)
|
check_call(["git", "checkout", commit_hash], cwd=triton_basedir)
|
||||||
if build_conda:
|
if build_conda:
|
||||||
with open(triton_basedir / "meta.yaml", "w") as meta:
|
with open(triton_basedir / "meta.yaml", "w") as meta:
|
||||||
print(f"package:\n name: torchtriton\n version: 2.1.0+{commit_hash[:10]}\n", file=meta)
|
print(f"package:\n name: torchtriton\n version: {version}+{commit_hash[:10]}\n", file=meta)
|
||||||
print("source:\n path: .\n", file=meta)
|
print("source:\n path: .\n", file=meta)
|
||||||
print("build:\n string: py{{py}}\n number: 1\n script: cd python; "
|
print("build:\n string: py{{py}}\n number: 1\n script: cd python; "
|
||||||
"python setup.py install --single-version-externally-managed --record=record.txt\n", file=meta)
|
"python setup.py install --single-version-externally-managed --record=record.txt\n", file=meta)
|
||||||
@ -47,6 +62,7 @@ def build_triton(commit_hash: str, build_conda: bool = False, py_version : Optio
|
|||||||
print("about:\n home: https://github.com/openai/triton\n license: MIT\n summary:"
|
print("about:\n home: https://github.com/openai/triton\n license: MIT\n summary:"
|
||||||
" 'A language and compiler for custom Deep Learning operation'", file=meta)
|
" 'A language and compiler for custom Deep Learning operation'", file=meta)
|
||||||
|
|
||||||
|
patch_init_py(triton_pythondir / "triton" / "__init__.py", version=f"{version}+{commit_hash[:10]}")
|
||||||
if py_version is None:
|
if py_version is None:
|
||||||
py_version = f"{sys.version_info.major}.{sys.version_info.minor}"
|
py_version = f"{sys.version_info.major}.{sys.version_info.minor}"
|
||||||
check_call(["conda", "build", "--python", py_version,
|
check_call(["conda", "build", "--python", py_version,
|
||||||
@ -55,7 +71,8 @@ def build_triton(commit_hash: str, build_conda: bool = False, py_version : Optio
|
|||||||
shutil.copy(conda_path, Path.cwd())
|
shutil.copy(conda_path, Path.cwd())
|
||||||
return Path.cwd() / conda_path.name
|
return Path.cwd() / conda_path.name
|
||||||
|
|
||||||
patch_setup_py(triton_pythondir / "setup.py", name="pytorch-triton", version=f"2.1.0+{commit_hash[:10]}")
|
patch_setup_py(triton_pythondir / "setup.py", name="pytorch-triton", version=f"{version}+{commit_hash[:10]}")
|
||||||
|
patch_init_py(triton_pythondir / "triton" / "__init__.py", version=f"{version}+{commit_hash[:10]}")
|
||||||
check_call([sys.executable, "setup.py", "bdist_wheel"], cwd=triton_pythondir)
|
check_call([sys.executable, "setup.py", "bdist_wheel"], cwd=triton_pythondir)
|
||||||
whl_path = list((triton_pythondir / "dist").glob("*.whl"))[0]
|
whl_path = list((triton_pythondir / "dist").glob("*.whl"))[0]
|
||||||
shutil.copy(whl_path, Path.cwd())
|
shutil.copy(whl_path, Path.cwd())
|
||||||
@ -67,9 +84,13 @@ def main() -> None:
|
|||||||
parser = ArgumentParser("Build Triton binaries")
|
parser = ArgumentParser("Build Triton binaries")
|
||||||
parser.add_argument("--build-conda", action="store_true")
|
parser.add_argument("--build-conda", action="store_true")
|
||||||
parser.add_argument("--py-version", type=str)
|
parser.add_argument("--py-version", type=str)
|
||||||
|
parser.add_argument("--commit-hash", type=str, default=read_triton_pin())
|
||||||
|
parser.add_argument("--triton-version", type=str, default=read_triton_version())
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
pin = read_triton_pin()
|
build_triton(commit_hash=args.commit_hash,
|
||||||
build_triton(pin, build_conda=args.build_conda, py_version=args.py_version)
|
version=args.triton_version,
|
||||||
|
build_conda=args.build_conda,
|
||||||
|
py_version=args.py_version)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
2
.github/workflows/docker-release.yml
vendored
2
.github/workflows/docker-release.yml
vendored
@ -87,7 +87,7 @@ jobs:
|
|||||||
{
|
{
|
||||||
echo "DOCKER_IMAGE=pytorch-nightly";
|
echo "DOCKER_IMAGE=pytorch-nightly";
|
||||||
echo "INSTALL_CHANNEL=pytorch-nightly";
|
echo "INSTALL_CHANNEL=pytorch-nightly";
|
||||||
echo "TRITON_VERSION=2.1.0+$(cut -c -10 .github/ci_commit_pins/triton.txt)";
|
echo "TRITON_VERSION=$(cut -f 1 .ci/docker/triton_version.txt)+$(cut -c -10 .ci/docker/ci_commit_pins/triton.txt)";
|
||||||
} >> "${GITHUB_ENV}"
|
} >> "${GITHUB_ENV}"
|
||||||
- name: Run docker build / push
|
- name: Run docker build / push
|
||||||
# WITH_PUSH is used here to determine whether or not to add the --push flag
|
# WITH_PUSH is used here to determine whether or not to add the --push flag
|
||||||
|
|||||||
9
setup.py
9
setup.py
@ -1028,11 +1028,14 @@ def main():
|
|||||||
'opt-einsum': ['opt-einsum>=3.3']
|
'opt-einsum': ['opt-einsum>=3.3']
|
||||||
}
|
}
|
||||||
if platform.system() == 'Linux':
|
if platform.system() == 'Linux':
|
||||||
triton_pin_file = os.path.join(cwd, ".github", "ci_commit_pins", "triton.txt")
|
triton_pin_file = os.path.join(cwd, ".ci", "docker", "ci_commit_pins", "triton.txt")
|
||||||
if os.path.exists(triton_pin_file):
|
triton_version_file = os.path.join(cwd, ".ci", "docker", "triton_version.txt")
|
||||||
|
if os.path.exists(triton_pin_file) and os.path.exists(triton_version_file):
|
||||||
with open(triton_pin_file) as f:
|
with open(triton_pin_file) as f:
|
||||||
triton_pin = f.read().strip()
|
triton_pin = f.read().strip()
|
||||||
extras_require['dynamo'] = ['pytorch-triton==2.1.0+' + triton_pin[:10], 'jinja2']
|
with open(triton_version_file) as f:
|
||||||
|
triton_version = f.read().strip()
|
||||||
|
extras_require['dynamo'] = ['pytorch-triton==' + triton_version + '+' + triton_pin[:10], 'jinja2']
|
||||||
|
|
||||||
# Parse the command line and check the arguments before we proceed with
|
# Parse the command line and check the arguments before we proceed with
|
||||||
# building deps and setup. We need to set values so `--help` works.
|
# building deps and setup. We need to set values so `--help` works.
|
||||||
|
|||||||
Reference in New Issue
Block a user