mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Enable triton build in CI docker image for ROCm (#98096)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98096 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
7117c87489
commit
ce4df4cc59
@ -184,6 +184,7 @@ case "$image" in
|
||||
ROCM_VERSION=5.3
|
||||
NINJA_VERSION=1.9.0
|
||||
CONDA_CMAKE=yes
|
||||
TRITON=yes
|
||||
;;
|
||||
pytorch-linux-focal-rocm-n-py3)
|
||||
ANACONDA_PYTHON_VERSION=3.8
|
||||
@ -194,6 +195,7 @@ case "$image" in
|
||||
ROCM_VERSION=5.4.2
|
||||
NINJA_VERSION=1.9.0
|
||||
CONDA_CMAKE=yes
|
||||
TRITON=yes
|
||||
;;
|
||||
pytorch-linux-focal-py3.8-gcc7)
|
||||
ANACONDA_PYTHON_VERSION=3.8
|
||||
@ -248,6 +250,7 @@ case "$image" in
|
||||
if [[ "$image" == *rocm* ]]; then
|
||||
extract_version_from_image_name rocm ROCM_VERSION
|
||||
NINJA_VERSION=1.9.0
|
||||
TRITON=yes
|
||||
fi
|
||||
if [[ "$image" == *centos7* ]]; then
|
||||
NINJA_VERSION=1.10.2
|
||||
|
||||
1
.ci/docker/ci_commit_pins/triton-rocm.txt
Normal file
1
.ci/docker/ci_commit_pins/triton-rocm.txt
Normal file
@ -0,0 +1 @@
|
||||
de3f5436247e391b062a7dd7fd42d2a55c2cd524
|
||||
@ -12,8 +12,16 @@ conda_reinstall() {
|
||||
as_jenkins conda install -q -n py_$ANACONDA_PYTHON_VERSION -y --force-reinstall $*
|
||||
}
|
||||
|
||||
if [ -n "${ROCM_VERSION}" ]; then
|
||||
TRITON_REPO="https://github.com/ROCmSoftwarePlatform/triton"
|
||||
TRITON_TEXT_FILE="triton-rocm"
|
||||
else
|
||||
TRITON_REPO="https://github.com/openai/triton"
|
||||
TRITON_TEXT_FILE="triton"
|
||||
fi
|
||||
|
||||
# The logic here is copied from .ci/pytorch/common_utils.sh
|
||||
TRITON_PINNED_COMMIT=$(get_pinned_commit triton)
|
||||
TRITON_PINNED_COMMIT=$(get_pinned_commit ${TRITON_TEXT_FILE})
|
||||
|
||||
apt update
|
||||
apt-get install -y gpg-agent
|
||||
@ -28,15 +36,15 @@ if [ -n "${GCC_VERSION}" ] && [[ "${GCC_VERSION}" == "7" ]]; then
|
||||
# Triton needs at least gcc-9 to build
|
||||
apt-get install -y g++-9
|
||||
|
||||
CXX=g++-9 pip_install "git+https://github.com/openai/triton@${TRITON_PINNED_COMMIT}#subdirectory=python"
|
||||
CXX=g++-9 pip_install "git+${TRITON_REPO}@${TRITON_PINNED_COMMIT}#subdirectory=python"
|
||||
elif [ -n "${CLANG_VERSION}" ]; then
|
||||
# Triton needs <filesystem> which surprisingly is not available with clang-9 toolchain
|
||||
add-apt-repository -y ppa:ubuntu-toolchain-r/test
|
||||
apt-get install -y g++-9
|
||||
|
||||
CXX=g++-9 pip_install "git+https://github.com/openai/triton@${TRITON_PINNED_COMMIT}#subdirectory=python"
|
||||
CXX=g++-9 pip_install "git+${TRITON_REPO}@${TRITON_PINNED_COMMIT}#subdirectory=python"
|
||||
else
|
||||
pip_install "git+https://github.com/openai/triton@${TRITON_PINNED_COMMIT}#subdirectory=python"
|
||||
pip_install "git+${TRITON_REPO}@${TRITON_PINNED_COMMIT}#subdirectory=python"
|
||||
fi
|
||||
|
||||
if [ -n "${CONDA_CMAKE}" ]; then
|
||||
|
||||
@ -68,6 +68,7 @@ RUN rm install_rocm.sh
|
||||
COPY ./common/install_rocm_magma.sh install_rocm_magma.sh
|
||||
RUN bash ./install_rocm_magma.sh
|
||||
RUN rm install_rocm_magma.sh
|
||||
ENV ROCM_PATH /opt/rocm
|
||||
ENV PATH /opt/rocm/bin:$PATH
|
||||
ENV PATH /opt/rocm/hcc/bin:$PATH
|
||||
ENV PATH /opt/rocm/hip/bin:$PATH
|
||||
@ -89,6 +90,16 @@ COPY ./common/install_ninja.sh install_ninja.sh
|
||||
RUN if [ -n "${NINJA_VERSION}" ]; then bash ./install_ninja.sh; fi
|
||||
RUN rm install_ninja.sh
|
||||
|
||||
ARG TRITON
|
||||
# Install triton, this needs to be done before sccache because the latter will
|
||||
# try to reach out to S3, which docker build runners don't have access
|
||||
COPY ./common/install_triton.sh install_triton.sh
|
||||
COPY ./common/common_utils.sh common_utils.sh
|
||||
COPY ci_commit_pins/triton-rocm.txt triton-rocm.txt
|
||||
COPY triton_version.txt triton_version.txt
|
||||
RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi
|
||||
RUN rm install_triton.sh common_utils.sh triton-rocm.txt triton_version.txt
|
||||
|
||||
# Install ccache/sccache (do this last, so we get priority in PATH)
|
||||
COPY ./common/install_cache.sh install_cache.sh
|
||||
ENV PATH /opt/cache/bin:$PATH
|
||||
|
||||
@ -1 +0,0 @@
|
||||
pytorch-triton-rocm>=2.0.0,<2.1
|
||||
18
setup.py
18
setup.py
@ -927,13 +927,6 @@ def configure_extension_build():
|
||||
|
||||
# These extensions are built by cmake and copied manually in build_extensions()
|
||||
# inside the build_ext implementation
|
||||
if cmake_cache_vars['USE_ROCM']:
|
||||
triton_req_file = os.path.join(cwd, ".github", "requirements", "triton-requirements-rocm.txt")
|
||||
if os.path.exists(triton_req_file):
|
||||
with open(triton_req_file) as f:
|
||||
triton_req = f.read().strip()
|
||||
extra_install_requires.append(triton_req)
|
||||
|
||||
if cmake_cache_vars['BUILD_CAFFE2']:
|
||||
extensions.append(
|
||||
Extension(
|
||||
@ -1019,14 +1012,21 @@ def main():
|
||||
'opt-einsum': ['opt-einsum>=3.3']
|
||||
}
|
||||
if platform.system() == 'Linux':
|
||||
triton_pin_file = os.path.join(cwd, ".ci", "docker", "ci_commit_pins", "triton.txt")
|
||||
cmake_cache_vars = get_cmake_cache_vars()
|
||||
if cmake_cache_vars['USE_ROCM']:
|
||||
triton_text_file = "triton-rocm.txt"
|
||||
triton_package_name = "pytorch-triton-rocm"
|
||||
else:
|
||||
triton_text_file = "triton.txt"
|
||||
triton_package_name = "pytorch-triton"
|
||||
triton_pin_file = os.path.join(cwd, ".ci", "docker", "ci_commit_pins", triton_text_file)
|
||||
triton_version_file = os.path.join(cwd, ".ci", "docker", "triton_version.txt")
|
||||
if os.path.exists(triton_pin_file) and os.path.exists(triton_version_file):
|
||||
with open(triton_pin_file) as f:
|
||||
triton_pin = f.read().strip()
|
||||
with open(triton_version_file) as f:
|
||||
triton_version = f.read().strip()
|
||||
extras_require['dynamo'] = ['pytorch-triton==' + triton_version + '+' + triton_pin[:10], 'jinja2']
|
||||
extras_require['dynamo'] = [triton_package_name + '==' + triton_version + '+' + triton_pin[:10], 'jinja2']
|
||||
|
||||
# Parse the command line and check the arguments before we proceed with
|
||||
# building deps and setup. We need to set values so `--help` works.
|
||||
|
||||
Reference in New Issue
Block a user