[ROCm][Build][Bugfix] Fix ROCm base docker whls installation order (#25415)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Gregory Shtrasberg
2025-09-24 11:25:10 -04:00
committed by yewentao256
parent 3331ced61b
commit b1f9a1f46a

View File

@ -65,8 +65,6 @@ ARG PYTORCH_BRANCH
ARG PYTORCH_VISION_BRANCH
ARG PYTORCH_REPO
ARG PYTORCH_VISION_REPO
ARG FA_BRANCH
ARG FA_REPO
RUN git clone ${PYTORCH_REPO} pytorch
RUN cd pytorch && git checkout ${PYTORCH_BRANCH} && \
pip install -r requirements.txt && git submodule update --init --recursive \
@ -77,14 +75,20 @@ RUN git clone ${PYTORCH_VISION_REPO} vision
RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \
&& python3 setup.py bdist_wheel --dist-dir=dist \
&& pip install dist/*.whl
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
&& cp /app/vision/dist/*.whl /app/install
FROM base AS build_fa
ARG FA_BRANCH
ARG FA_REPO
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
pip install /install/*.whl
RUN git clone ${FA_REPO}
RUN cd flash-attention \
&& git checkout ${FA_BRANCH} \
&& git submodule update --init \
&& GPU_ARCHS=$(echo ${PYTORCH_ROCM_ARCH} | sed -e 's/;gfx1[0-9]\{3\}//g') python3 setup.py bdist_wheel --dist-dir=dist
RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \
&& cp /app/vision/dist/*.whl /app/install \
&& cp /app/flash-attention/dist/*.whl /app/install
RUN mkdir -p /app/install && cp /app/flash-attention/dist/*.whl /app/install
FROM base AS build_aiter
ARG AITER_BRANCH
@ -103,6 +107,8 @@ FROM base AS debs
RUN mkdir /app/debs
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_fa,src=/app/install/,target=/install \
cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
cp /install/*.whl /app/debs
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
@ -111,13 +117,7 @@ RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \
cp /install/*.whl /app/debs
FROM base AS final
RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \
pip install /install/*.whl
RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
pip install /install/*.whl
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
pip install /install/*.whl
RUN --mount=type=bind,from=build_aiter,src=/app/install/,target=/install \
RUN --mount=type=bind,from=debs,src=/app/debs,target=/install \
pip install /install/*.whl
ARG BASE_IMAGE