diff --git a/.github/workflows/docker-release.yml b/.github/workflows/docker-release.yml index 351497bee753..e9c8db722432 100644 --- a/.github/workflows/docker-release.yml +++ b/.github/workflows/docker-release.yml @@ -110,7 +110,7 @@ jobs: if [[ ${{ github.event.ref }} =~ ^refs/tags/v[0-9]+\.[0-9]+\.[0-9]+-rc[0-9]+$ ]]; then { echo "DOCKER_IMAGE=pytorch-test"; - echo "INSTALL_CHANNEL=pytorch-test"; + echo "INSTALL_CHANNEL=whl/test"; echo "TRITON_VERSION=$(cut -f 1 .ci/docker/triton_version.txt)"; } >> "${GITHUB_ENV}" fi @@ -119,7 +119,7 @@ jobs: run: | { echo "DOCKER_IMAGE=pytorch-nightly"; - echo "INSTALL_CHANNEL=pytorch-nightly"; + echo "INSTALL_CHANNEL=whl/nightly"; echo "TRITON_VERSION=$(cut -f 1 .ci/docker/triton_version.txt)+$(cut -c -10 .ci/docker/ci_commit_pins/triton.txt)"; } >> "${GITHUB_ENV}" - name: Run docker build / push diff --git a/Dockerfile b/Dockerfile index b751c64a8439..5cec2173063b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -32,7 +32,7 @@ RUN case ${TARGETPLATFORM} in \ "linux/arm64") MINICONDA_ARCH=aarch64 ;; \ *) MINICONDA_ARCH=x86_64 ;; \ esac && \ - curl -fsSL -v -o ~/miniconda.sh -O "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-${MINICONDA_ARCH}.sh" + curl -fsSL -v -o ~/miniconda.sh -O "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-${MINICONDA_ARCH}.sh" COPY requirements.txt . # Manually invoke bash on miniconda script per https://github.com/conda/conda/issues/10431 RUN chmod +x ~/miniconda.sh && \ @@ -61,19 +61,19 @@ RUN --mount=type=cache,target=/opt/ccache \ FROM conda as conda-installs ARG PYTHON_VERSION=3.11 -ARG CUDA_VERSION=12.1 +ARG CUDA_PATH=cu121 ARG CUDA_CHANNEL=nvidia -ARG INSTALL_CHANNEL=pytorch-nightly +ARG INSTALL_CHANNEL=whl/nightly # Automatically set by buildx RUN /opt/conda/bin/conda update -y -n base -c defaults conda -RUN /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -y python=${PYTHON_VERSION} +RUN /opt/conda/bin/conda install -y python=${PYTHON_VERSION} ARG TARGETPLATFORM -# On arm64 we can only install wheel packages. +# INSTALL_CHANNEL whl - release, whl/nightly - nightly, whle/test - test channels RUN case ${TARGETPLATFORM} in \ "linux/arm64") pip install --extra-index-url https://download.pytorch.org/whl/cpu/ torch torchvision torchaudio ;; \ - *) /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" pytorch torchvision torchaudio "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \ + *) pip install --index-url https://download.pytorch.org/${INSTALL_CHANNEL}/${CUDA_PATH#.}/ torch torchvision torchaudio ;; \ esac && \ /opt/conda/bin/conda clean -ya RUN /opt/conda/bin/pip install torchelastic diff --git a/docker.Makefile b/docker.Makefile index 7f131707e7ab..f510d1618a21 100644 --- a/docker.Makefile +++ b/docker.Makefile @@ -18,7 +18,12 @@ CMAKE_VARS ?= # The conda channel to use to install cudatoolkit CUDA_CHANNEL = nvidia # The conda channel to use to install pytorch / torchvision -INSTALL_CHANNEL ?= pytorch +INSTALL_CHANNEL ?= whl + +CUDA_PATH ?= cpu +ifneq ("$(CUDA_VERSION_SHORT)","cpu") +CUDA_PATH = cu$(subst .,,$(CUDA_VERSION_SHORT)) +endif PYTHON_VERSION ?= 3.11 # Match versions that start with v followed by a number, to avoid matching with tags like ciflow @@ -31,7 +36,7 @@ TRITON_VERSION ?= BUILD_ARGS = --build-arg BASE_IMAGE=$(BASE_IMAGE) \ --build-arg PYTHON_VERSION=$(PYTHON_VERSION) \ --build-arg CUDA_VERSION=$(CUDA_VERSION) \ - --build-arg CUDA_CHANNEL=$(CUDA_CHANNEL) \ + --build-arg CUDA_PATH=$(CUDA_PATH) \ --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION) \ --build-arg INSTALL_CHANNEL=$(INSTALL_CHANNEL) \ --build-arg TRITON_VERSION=$(TRITON_VERSION) \