Enable triton build in CI docker image for ROCm (#98096)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98096
Approved by: https://github.com/malfet
This commit is contained in:
Jithun Nair
2023-04-11 09:02:19 +00:00
committed by PyTorch MergeBot
parent 7117c87489
commit ce4df4cc59
6 changed files with 36 additions and 14 deletions

View File

@ -184,6 +184,7 @@ case "$image" in
ROCM_VERSION=5.3
NINJA_VERSION=1.9.0
CONDA_CMAKE=yes
TRITON=yes
;;
pytorch-linux-focal-rocm-n-py3)
ANACONDA_PYTHON_VERSION=3.8
@ -194,6 +195,7 @@ case "$image" in
ROCM_VERSION=5.4.2
NINJA_VERSION=1.9.0
CONDA_CMAKE=yes
TRITON=yes
;;
pytorch-linux-focal-py3.8-gcc7)
ANACONDA_PYTHON_VERSION=3.8
@ -248,6 +250,7 @@ case "$image" in
if [[ "$image" == *rocm* ]]; then
extract_version_from_image_name rocm ROCM_VERSION
NINJA_VERSION=1.9.0
TRITON=yes
fi
if [[ "$image" == *centos7* ]]; then
NINJA_VERSION=1.10.2

View File

@ -0,0 +1 @@
de3f5436247e391b062a7dd7fd42d2a55c2cd524

View File

@ -12,8 +12,16 @@ conda_reinstall() {
as_jenkins conda install -q -n py_$ANACONDA_PYTHON_VERSION -y --force-reinstall $*
}
if [ -n "${ROCM_VERSION}" ]; then
TRITON_REPO="https://github.com/ROCmSoftwarePlatform/triton"
TRITON_TEXT_FILE="triton-rocm"
else
TRITON_REPO="https://github.com/openai/triton"
TRITON_TEXT_FILE="triton"
fi
# The logic here is copied from .ci/pytorch/common_utils.sh
TRITON_PINNED_COMMIT=$(get_pinned_commit triton)
TRITON_PINNED_COMMIT=$(get_pinned_commit ${TRITON_TEXT_FILE})
apt update
apt-get install -y gpg-agent
@ -28,15 +36,15 @@ if [ -n "${GCC_VERSION}" ] && [[ "${GCC_VERSION}" == "7" ]]; then
# Triton needs at least gcc-9 to build
apt-get install -y g++-9
CXX=g++-9 pip_install "git+https://github.com/openai/triton@${TRITON_PINNED_COMMIT}#subdirectory=python"
CXX=g++-9 pip_install "git+${TRITON_REPO}@${TRITON_PINNED_COMMIT}#subdirectory=python"
elif [ -n "${CLANG_VERSION}" ]; then
# Triton needs <filesystem> which surprisingly is not available with clang-9 toolchain
add-apt-repository -y ppa:ubuntu-toolchain-r/test
apt-get install -y g++-9
CXX=g++-9 pip_install "git+https://github.com/openai/triton@${TRITON_PINNED_COMMIT}#subdirectory=python"
CXX=g++-9 pip_install "git+${TRITON_REPO}@${TRITON_PINNED_COMMIT}#subdirectory=python"
else
pip_install "git+https://github.com/openai/triton@${TRITON_PINNED_COMMIT}#subdirectory=python"
pip_install "git+${TRITON_REPO}@${TRITON_PINNED_COMMIT}#subdirectory=python"
fi
if [ -n "${CONDA_CMAKE}" ]; then

View File

@ -68,6 +68,7 @@ RUN rm install_rocm.sh
COPY ./common/install_rocm_magma.sh install_rocm_magma.sh
RUN bash ./install_rocm_magma.sh
RUN rm install_rocm_magma.sh
ENV ROCM_PATH /opt/rocm
ENV PATH /opt/rocm/bin:$PATH
ENV PATH /opt/rocm/hcc/bin:$PATH
ENV PATH /opt/rocm/hip/bin:$PATH
@ -89,6 +90,16 @@ COPY ./common/install_ninja.sh install_ninja.sh
RUN if [ -n "${NINJA_VERSION}" ]; then bash ./install_ninja.sh; fi
RUN rm install_ninja.sh
ARG TRITON
# Install triton, this needs to be done before sccache because the latter will
# try to reach out to S3, which docker build runners don't have access
COPY ./common/install_triton.sh install_triton.sh
COPY ./common/common_utils.sh common_utils.sh
COPY ci_commit_pins/triton-rocm.txt triton-rocm.txt
COPY triton_version.txt triton_version.txt
RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi
RUN rm install_triton.sh common_utils.sh triton-rocm.txt triton_version.txt
# Install ccache/sccache (do this last, so we get priority in PATH)
COPY ./common/install_cache.sh install_cache.sh
ENV PATH /opt/cache/bin:$PATH

View File

@ -1 +0,0 @@
pytorch-triton-rocm>=2.0.0,<2.1

View File

@ -927,13 +927,6 @@ def configure_extension_build():
# These extensions are built by cmake and copied manually in build_extensions()
# inside the build_ext implementation
if cmake_cache_vars['USE_ROCM']:
triton_req_file = os.path.join(cwd, ".github", "requirements", "triton-requirements-rocm.txt")
if os.path.exists(triton_req_file):
with open(triton_req_file) as f:
triton_req = f.read().strip()
extra_install_requires.append(triton_req)
if cmake_cache_vars['BUILD_CAFFE2']:
extensions.append(
Extension(
@ -1019,14 +1012,21 @@ def main():
'opt-einsum': ['opt-einsum>=3.3']
}
if platform.system() == 'Linux':
triton_pin_file = os.path.join(cwd, ".ci", "docker", "ci_commit_pins", "triton.txt")
cmake_cache_vars = get_cmake_cache_vars()
if cmake_cache_vars['USE_ROCM']:
triton_text_file = "triton-rocm.txt"
triton_package_name = "pytorch-triton-rocm"
else:
triton_text_file = "triton.txt"
triton_package_name = "pytorch-triton"
triton_pin_file = os.path.join(cwd, ".ci", "docker", "ci_commit_pins", triton_text_file)
triton_version_file = os.path.join(cwd, ".ci", "docker", "triton_version.txt")
if os.path.exists(triton_pin_file) and os.path.exists(triton_version_file):
with open(triton_pin_file) as f:
triton_pin = f.read().strip()
with open(triton_version_file) as f:
triton_version = f.read().strip()
extras_require['dynamo'] = ['pytorch-triton==' + triton_version + '+' + triton_pin[:10], 'jinja2']
extras_require['dynamo'] = [triton_package_name + '==' + triton_version + '+' + triton_pin[:10], 'jinja2']
# Parse the command line and check the arguments before we proceed with
# building deps and setup. We need to set values so `--help` works.