Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-04 08:00:58 +08:00)

Compare commits: replace-py ... update_sub (190 commits)
SHA1 list (190 commits; author and date metadata not shown in this view):

638dcd6039 b78968b4d1 e5621b4d8b 2542e71f3f 0242d40fa5 17de899709 25d0d8b0a3 c6d697ff52
30d2f98daa 8780d28c65 da8f48d88f cb9e2092a8 f6bf1573fc 82a18423be 3fe3c23d4e 052c441cf4
b26d2a9464 6382302990 80dd05e31e 9df07ecfbe 846963fa9b 663da17b62 e299926f72 bbd11c4f23
eaa5d9d3d3 a7c75ae976 f7ad69f59c 4cae9cf2df 7710800865 aa99e0958f 3fc7a95176 858fb80b9b
55061c9602 214d04833a 9c5601ecc3 5b9ad951f8 4d5f92aa39 39ca0ce0c8 d52bb67ac3 05b9b63fb6
453cfa5153 9faca5f260 6fe6dd9fdc f82c7eed84 25ccc4716e d387a48c38 831e85104a 211c98859a
dae7710bf2 dc194a3096 4051b42c29 eb0eaa67e1 98373e5ad2 371eacb2ae 3650989e6e 3be70dc30e
47a1db823d eac2d9d695 3fe19a7a0a 4a90dc0c1f 1fc683cf17 b9d7de3a09 1028c5e2d5 19b4283884
8d6d324631 fdfd69bb05 0d3461bac0 65053c03a3 96f9fbe21a 1c25871191 6c05ea6475 5665dc9ab7
2ff7c1c774 3028fa6ce9 077cb38974 cd8d8c18f5 63654ba4c5 7e27347fd3 1d80d697a2 b6b74aed60
4a773e1e86 198b5fd2d4 20bdabbb3c d556586448 781e9a7724 e4de93f6a3 a5652407e4 6f0f4e0c3e
db763b1717 089c4a1ba0 97c8c98f8d 39aa3d1471 639778b3ee 00d7d6f123 c6d78d4dbd 2898d3f965
34358f335d fe3f5fe4ea 45ba7ecda8 194fcfcfbd 29d20d49f0 3faee0a631 8cfaf51d4e af3cabc55d
74bbe7b4a3 7bfc424a61 5ace061254 15e49f6164 dd21c8a578 a06ec54d40 50a8c11875 e4e4dbd2f8
d670304001 c5efc5c8a6 6da11d9aaf 182efe31db 1ea688f9a2 53e3949495 33d9401866 d1950d4bb5
1196bb1c2e e9eb2096a5 3ef2e1ef76 4cde0acc0e 70ccdec44b db0b7f1cc9 1c26c53851 adcca7d9a1
01584d2a7d 87e6c4079d 7d87e358ac 6ee175195a db32b60662 56c828bef9 a2fd106d67 c656334120
31c9ac4319 deea71a90e 114a6c4043 ee1b0412b9 42e51cd4b3 96bd33b2de ecde76c764 34ec5ed275
641ee74781 6e8865fbc1 9a06e6d031 6ea8376f84 8eee08d227 e497620260 199e9abb6a 685f15dbea
85db508af5 27156ec804 6746bc59df 3008d985a8 652a6f5954 118bc97b14 305fa22393 1151b40cbf
d0f9785af3 ba47821f52 2c5e10a5fc 355462e127 41673110cd 6be6d06295 f15ada5c6f 69a0a9aa7f
32099961d5 8d1cf52922 b1f43548ca 0d71ca2c46 5737372862 2e4e5ab4be 16d15445f8 101276f81b
cbffde7745 78a2fe1d42 f8f0414a59 4d419a7461 655137b678 f27232a213 c24ca7f4bf b4596895b9
5a9c4cfce4 a354fa91e2 f95b58c284 8e6a313858 7e91394955 89654db1ab
@ -208,7 +208,9 @@ if __name__ == "__main__":
     build_vars = "CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000 "
     # MAX_JOB=5 is not required for CPU backend (see commit 465d98b)
     if enable_cuda:
-        build_vars = "MAX_JOBS=5 " + build_vars
+        build_vars += "MAX_JOBS=5 "
+        # nvshmem is broken for aarch64 see https://github.com/pytorch/pytorch/issues/160425
+        build_vars += "USE_NVSHMEM=OFF "

     override_package_version = os.getenv("OVERRIDE_PACKAGE_VERSION")
     desired_cuda = os.getenv("DESIRED_CUDA")
@ -76,7 +76,6 @@ ADD ./common/install_mnist.sh install_mnist.sh
RUN bash ./install_mnist.sh

FROM base as all_cuda
COPY --from=cuda11.8  /usr/local/cuda-11.8 /usr/local/cuda-11.8
COPY --from=cuda12.6  /usr/local/cuda-12.6 /usr/local/cuda-12.6
COPY --from=cuda12.8  /usr/local/cuda-12.8 /usr/local/cuda-12.8
COPY --from=cuda12.9  /usr/local/cuda-12.9 /usr/local/cuda-12.9
@ -76,6 +76,9 @@ elif [[ "$image" == *cuda*linter* ]]; then
 elif [[ "$image" == *linter* ]]; then
   # Use a separate Dockerfile for linter to keep a small image size
   DOCKERFILE="linter/Dockerfile"
+elif [[ "$image" == *riscv* ]]; then
+  # Use RISC-V specific Dockerfile
+  DOCKERFILE="ubuntu-cross-riscv/Dockerfile"
 fi

_UCX_COMMIT=7bb2722ff2187a0cad557ae4a6afa090569f83fb
@ -303,6 +306,9 @@ case "$tag" in
     SKIP_LLVM_SRC_BUILD_INSTALL=yes
     INDUCTOR_BENCHMARKS=yes
     ;;
+  pytorch-linux-noble-riscv64-py3.12-gcc14)
+    GCC_VERSION=14
+    ;;
   *)
     # Catch-all for builds that are not hardcoded.
     VISION=yes
@ -423,7 +429,14 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then
 fi

 if [ -n "$GCC_VERSION" ]; then
-  if !(drun gcc --version 2>&1 | grep -q " $GCC_VERSION\\W"); then
+  if [[ "$image" == *riscv* ]]; then
+    # Check RISC-V cross-compilation toolchain version
+    if !(drun riscv64-linux-gnu-gcc-${GCC_VERSION} --version 2>&1 | grep -q " $GCC_VERSION\\W"); then
+      echo "RISC-V GCC_VERSION=$GCC_VERSION, but:"
+      drun riscv64-linux-gnu-gcc-${GCC_VERSION} --version
+      exit 1
+    fi
+  elif !(drun gcc --version 2>&1 | grep -q " $GCC_VERSION\\W"); then
     echo "GCC_VERSION=$GCC_VERSION, but:"
     drun gcc --version
     exit 1
@ -1 +1 @@
-ae324eeac8e102a2b40370e341460f3791353398
+0958dc9b2bb815e428f721f9da599dab0dc1c5d7
							
								
								
									
.ci/docker/ubuntu-cross-riscv/Dockerfile (new file, 155 lines) @ -0,0 +1,155 @@

# Cross-compilation Docker container for RISC-V architecture
ARG UBUNTU_VERSION
FROM --platform=linux/amd64 ubuntu:${UBUNTU_VERSION} as base

ARG UBUNTU_VERSION

ENV GCC_VERSION=14
ENV PYTHON_VERSION=3.12.3
ENV DEBIAN_FRONTEND=noninteractive
ENV CC=riscv64-linux-gnu-gcc-${GCC_VERSION}
ENV CXX=riscv64-linux-gnu-g++-${GCC_VERSION}
ENV QEMU_LD_PREFIX=/usr/riscv64-linux-gnu/
ENV SYSROOT=/opt/sysroot

# Install basic dependencies
RUN apt-get update && apt-get install -y \
    ninja-build \
    autoconf \
    automake \
    libtool \
    patchelf \
    ccache \
    git \
    wget \
    python3-pip \
    python3-venv \
    python-is-python3 \
    cmake \
    sudo \
    lsb-release \
    gcc-${GCC_VERSION}-riscv64-linux-gnu \
    g++-${GCC_VERSION}-riscv64-linux-gnu \
    pkg-config \
    && rm -rf /var/lib/apt/lists/*

# Install user
COPY ./common/install_user.sh install_user.sh
RUN bash ./install_user.sh && rm install_user.sh

FROM base as python
ARG ZLIB_VERSION=1.3.1
ARG FFI_VERSION=3.4.6
ARG BZ2_VERSION=1.0.8
ARG XZ_VERSION=5.4.6
ARG OPENSSL_VERSION=3.2.1

# Set up sysroot directory for dependencies
ENV PKG_CONFIG_PATH=${SYSROOT}/lib/pkgconfig
ENV PKG_CONFIG_SYSROOT_DIR=${SYSROOT}

WORKDIR /opt

# Build zlib (for compression)
RUN echo "--- Building zlib ---" \
    && wget -c https://www.zlib.net/zlib-${ZLIB_VERSION}.tar.gz \
    && tar -xf zlib-${ZLIB_VERSION}.tar.gz --no-same-permissions --no-same-owner \
    && cd zlib-${ZLIB_VERSION}/ \
    && mkdir build && cd build \
    && ../configure --prefix=${SYSROOT} \
    && make -j$(nproc) && make install \
    && cd ../..

# Build libffi (for ctypes module)
RUN echo "--- Building libffi ---" \
    && wget -c https://github.com/libffi/libffi/releases/download/v${FFI_VERSION}/libffi-${FFI_VERSION}.tar.gz \
    && tar -xf libffi-${FFI_VERSION}.tar.gz --no-same-permissions --no-same-owner \
    && cd libffi-${FFI_VERSION}/ \
    && mkdir build && cd build \
    && ../configure --prefix=${SYSROOT} --host=riscv64-linux-gnu --build=x86_64-linux-gnu \
    && make -j$(nproc) && make install \
    && cd ../..

# Build bzip2 (for bz2 module)
RUN echo "--- Building bzip2 ---" \
    && wget -c https://sourceware.org/pub/bzip2/bzip2-${BZ2_VERSION}.tar.gz \
    && tar -xf bzip2-${BZ2_VERSION}.tar.gz --no-same-permissions --no-same-owner \
    && cd bzip2-${BZ2_VERSION}/ \
    && make CC=riscv64-linux-gnu-gcc-${GCC_VERSION} bzip2 bzip2recover libbz2.a \
    && make CC=riscv64-linux-gnu-gcc-${GCC_VERSION} -f Makefile-libbz2_so \
    && make install PREFIX=${SYSROOT} \
    && cp libbz2.so.${BZ2_VERSION} ${SYSROOT}/lib/ \
    && cd ${SYSROOT}/lib/ \
    && ln -sf libbz2.so.${BZ2_VERSION} libbz2.so.1.0 \
    && ln -sf libbz2.so.1.0 libbz2.so \
    && cd /opt/

# Build xz (for lzma module)
RUN echo "--- Building xz ---" \
    && wget -c https://github.com/tukaani-project/xz/releases/download/v${XZ_VERSION}/xz-${XZ_VERSION}.tar.gz \
    && tar -xf xz-${XZ_VERSION}.tar.gz --no-same-permissions --no-same-owner \
    && cd xz-${XZ_VERSION} \
    && mkdir build && cd build \
    && ../configure --prefix=${SYSROOT} --host=riscv64-linux-gnu --build=x86_64-linux-gnu \
    && make -j$(nproc) && make install \
    && cd ../..

# Build OpenSSL (for ssl module)
RUN echo "--- Building OpenSSL ---" \
    && wget -c https://www.openssl.org/source/openssl-${OPENSSL_VERSION}.tar.gz \
    && tar -xf openssl-${OPENSSL_VERSION}.tar.gz --no-same-permissions --no-same-owner \
    && cd openssl-${OPENSSL_VERSION}/ \
    && mkdir build && cd build \
    && ../Configure linux64-riscv64 --prefix=${SYSROOT} \
    && make -j$(nproc) && make install_sw \
    && cd ../..

# Build SQLite3 (for sqlite3 module)
RUN echo "--- Building SQLite3 ---" \
    && wget -c https://www.sqlite.org/2024/sqlite-autoconf-3450200.tar.gz \
    && tar -xf sqlite-autoconf-3450200.tar.gz --no-same-permissions --no-same-owner \
    && cd sqlite-autoconf-3450200 \
    && mkdir build && cd build \
    && ../configure --prefix=${SYSROOT} --host=riscv64-linux-gnu --build=x86_64-linux-gnu \
    && make -j$(nproc) && make install \
    && cd ../..

# Build and install RISC-V Python with all modules
RUN wget -c https://www.python.org/ftp/python/${PYTHON_VERSION}/Python-${PYTHON_VERSION}.tgz \
    && tar -xf Python-${PYTHON_VERSION}.tgz --no-same-permissions --no-same-owner \
    && cd Python-${PYTHON_VERSION} \
    && mkdir build && cd build \
    && ../configure \
        --host=riscv64-linux-gnu \
        --build=x86_64-linux-gnu \
        --prefix=${SYSROOT} \
        --enable-shared \
        --disable-ipv6 \
        --with-build-python=/usr/bin/python3 \
        --with-ensurepip=no \
        ac_cv_file__dev_ptmx=yes \
        ac_cv_file__dev_ptc=no \
    && make -j$(nproc) \
    && make install

FROM base as final
COPY --from=python             /opt/sysroot                       /opt/sysroot

# Install crossenv and cmake
RUN pip install crossenv cmake==4.0.0 --break-system-packages \
    && /usr/bin/python3 -m crossenv ${SYSROOT}/bin/python3 /opt/riscv-cross-env

# Add pip-installed cmake binaries to PATH
ENV PATH="/usr/local/bin:${PATH}"

# Set up cross Python environment
SHELL ["/bin/bash", "-c"]
RUN source /opt/riscv-cross-env/bin/activate \
    && pip install setuptools pyyaml typing_extensions wheel

# Set default environment variables for PyTorch build
ENV Python_ROOT_DIR=${SYSROOT}
ENV OPENSSL_ROOT_DIR=${SYSROOT}

USER jenkins
CMD ["bash"]
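For local experimentation, a build of this image might look like the sketch below. The tag is hypothetical, and the assumptions are that `UBUNTU_VERSION` is 24.04 (matching the `noble` tag in the build script above) and that the build context is `.ci/docker`, so that `./common/install_user.sh` resolves:

```bash
# Illustrative local build; tag and arguments are examples, not CI defaults.
docker build \
  --build-arg UBUNTU_VERSION=24.04 \
  -f .ci/docker/ubuntu-cross-riscv/Dockerfile \
  -t pytorch-riscv-cross:dev \
  .ci/docker
```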
							
								
								
									
.ci/lumen_cli/README.md (new file, 31 lines) @ -0,0 +1,31 @@

# 🔧 Lumen_cli
A Python CLI tool for building and testing PyTorch-based components, using a YAML configuration file for structured, repeatable workflows.

## Features
- **Build**
    - external projects (e.g. vLLM)

## 📦 Installation
At the root of the PyTorch repo:
```bash
pip install -e .ci/lumen_cli
```

## Run the CLI tool
The CLI tool must be run from the root of the PyTorch repo. For example, to build external vLLM:
```bash
python -m cli.run build external vllm
```
This runs the build steps with the default behaviour for the vLLM project.

To see help messages, run:
```bash
python3 -m cli.run --help
```

## Add customized external build logic
To add a new external build target:
1. Create the build function in the `cli/lib` folder.
2. Register your target and its build runner in `_TARGETS` in `cli/build_cli/register_build.py` (a sketch follows below).
3. [optional] Create your CI config file in `.github/ci_configs/${EXTERNAL_PACKAGE_NAME}.yaml`.
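As an illustration of step 2, a new `_TARGETS` entry might look like the following sketch. `ExampleBuildRunner` and the `example` target are hypothetical names, not part of this change; `BaseRunner` and `TargetSpec` are real helpers from `cli/lib/common/cli_helper.py`:

```python
# Hypothetical sketch: registering an extra target in
# cli/build_cli/register_build.py. "example" and ExampleBuildRunner
# are illustrative names only.
from cli.lib.common.cli_helper import BaseRunner, TargetSpec


class ExampleBuildRunner(BaseRunner):
    """Build the (hypothetical) example project."""

    def run(self) -> None:
        # Real build logic would live in a cli/lib module.
        print("building example project")


_TARGETS: dict[str, TargetSpec] = {
    # "vllm": {...},  # existing entry from this PR
    "example": {
        "runner": ExampleBuildRunner,
        "help": "Build the example project.",
    },
}
```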
							
								
								
									
.ci/lumen_cli/cli/build_cli/register_build.py (new file, 37 lines) @ -0,0 +1,37 @@

import argparse
import logging

from cli.lib.common.cli_helper import register_targets, RichHelp, TargetSpec
from cli.lib.core.vllm import VllmBuildRunner


logger = logging.getLogger(__name__)

# Maps targets to their argparse configuration and runner.
# Each entry adds a new target under `python -m cli.run build external {target}`
# backed by its build runner.
_TARGETS: dict[str, TargetSpec] = {
    "vllm": {
        "runner": VllmBuildRunner,
        "help": "Build vLLM using docker buildx.",
    }
    # add yours ...
}


def register_build_commands(subparsers: argparse._SubParsersAction) -> None:
    build_parser = subparsers.add_parser(
        "build",
        help="Build related commands",
        formatter_class=RichHelp,
    )
    build_subparsers = build_parser.add_subparsers(dest="build_command", required=True)
    overview = "\n".join(
        f"  {name:12} {spec.get('help', '')}" for name, spec in _TARGETS.items()
    )
    external_parser = build_subparsers.add_parser(
        "external",
        help="Build external targets",
        description="Build third-party targets.\n\nAvailable targets:\n" + overview,
        formatter_class=RichHelp,
    )
    register_targets(external_parser, _TARGETS)
							
								
								
									
.ci/lumen_cli/cli/lib/common/cli_helper.py (new file, 71 lines) @ -0,0 +1,71 @@

"""
Argparse utility helpers for CLI tasks.
"""

import argparse
from abc import ABC, abstractmethod


try:
    from typing import Any, Callable, Required, TypedDict  # Python 3.11+
except ImportError:
    from typing import Any, Callable, TypedDict

    from typing_extensions import Required  # Fallback for Python <3.11


class BaseRunner(ABC):
    def __init__(self, args: Any) -> None:
        self.args = args

    @abstractmethod
    def run(self) -> None:
        """Run the main logic (required)."""


# Pretty help: keep newlines + show defaults
class RichHelp(
    argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter
):
    pass


class TargetSpec(TypedDict, total=False):
    """CLI subcommand specification."""

    runner: Required[type[BaseRunner]]
    help: str
    description: str
    add_arguments: Callable[[argparse.ArgumentParser], None]


def register_targets(
    parser: argparse.ArgumentParser,
    target_specs: dict[str, TargetSpec],
    common_args: Callable[[argparse.ArgumentParser], None] = lambda _: None,
) -> None:
    """Register target subcommands."""
    targets = parser.add_subparsers(
        dest="target",
        required=True,
        metavar="{" + ",".join(target_specs.keys()) + "}",
    )

    for name, spec in target_specs.items():
        desc = spec.get("description") or spec["runner"].__doc__ or ""

        p = targets.add_parser(
            name,
            help=spec.get("help", ""),
            description=desc.strip(),
            formatter_class=RichHelp,
        )
        p.set_defaults(
            func=lambda args, cls=spec["runner"]: cls(args).run(),
            _runner_class=spec["runner"],
        )
        if "add_arguments" in spec and callable(spec["add_arguments"]):
            spec["add_arguments"](p)
        if common_args:
            common_args(p)
							
								
								
									
.ci/lumen_cli/cli/lib/common/docker_helper.py (new file, 42 lines) @ -0,0 +1,42 @@

"""
Docker Utility helpers for CLI tasks.
"""

import logging
from typing import Optional

import docker
from docker.errors import APIError, NotFound


logger = logging.getLogger(__name__)

# lazy singleton so we don't reconnect every call
_docker_client: Optional[docker.DockerClient] = None


def _get_client() -> docker.DockerClient:
    global _docker_client
    if _docker_client is None:
        _docker_client = docker.from_env()
    return _docker_client


def local_image_exists(
    image_name: str, client: Optional[docker.DockerClient] = None
) -> bool:
    """Return True if a local Docker image exists."""
    if not image_name:
        return False

    client = client or _get_client()
    try:
        client.images.get(image_name)
        return True
    except (NotFound, APIError) as e:
        logger.error(
            "Error when checking Docker image '%s': %s",
            image_name,
            e.explanation if hasattr(e, "explanation") else str(e),
        )
        return False
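A minimal usage sketch of `local_image_exists` (the image tag below is illustrative); `vllm.py` later in this PR uses the same check to decide between `--pull=false` and pulling from a registry:

```python
# Minimal sketch: check for a local image before building against it.
# The image tag is an example, not a value this PR depends on.
from cli.lib.common.docker_helper import local_image_exists

if local_image_exists("pytorch/pytorch:latest"):
    print("base image is available locally; a build can pass --pull=false")
else:
    print("image missing locally; a build would pull it from the registry")
```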
							
								
								
									
.ci/lumen_cli/cli/lib/common/envs_helper.py (new file, 110 lines) @ -0,0 +1,110 @@

"""
Environment Variables and Dataclasses Utility helpers for CLI tasks.
"""

import os
from dataclasses import field, fields, is_dataclass, MISSING
from pathlib import Path
from textwrap import indent
from typing import Optional, Union

from cli.lib.common.utils import str2bool


def get_env(name: str, default: str = "") -> str:
    """Get environment variable with default fallback."""
    return os.environ.get(name) or default


def env_path_optional(
    name: str,
    default: Optional[Union[str, Path]] = None,
    resolve: bool = True,
) -> Optional[Path]:
    """Get environment variable as optional Path."""
    val = get_env(name) or default
    if not val:
        return None

    path = Path(val)
    return path.resolve() if resolve else path


def env_path(
    name: str,
    default: Optional[Union[str, Path]] = None,
    resolve: bool = True,
) -> Path:
    """Get environment variable as Path, raise if missing."""
    path = env_path_optional(name, default, resolve)
    if not path:
        raise ValueError(f"Missing path value for {name}")
    return path


def env_bool(
    name: str,
    default: bool = False,
) -> bool:
    val = get_env(name)
    if not val:
        return default
    return str2bool(val)


def env_bool_field(
    name: str,
    default: bool = False,
):
    return field(default_factory=lambda: env_bool(name, default))


def env_path_field(
    name: str,
    default: Union[str, Path] = "",
    *,
    resolve: bool = True,
) -> Path:
    return field(default_factory=lambda: env_path(name, default, resolve=resolve))


def env_str_field(
    name: str,
    default: str = "",
) -> str:
    return field(default_factory=lambda: get_env(name, default))


def generate_dataclass_help(cls) -> str:
    """Auto-generate help text for dataclass fields."""
    if not is_dataclass(cls):
        raise TypeError(f"{cls} is not a dataclass")

    def get_value(f):
        if f.default is not MISSING:
            return f.default
        if f.default_factory is not MISSING:
            try:
                return f.default_factory()
            except Exception as e:
                return f"<error: {e}>"
        return "<required>"

    lines = [f"{f.name:<22} = {repr(get_value(f))}" for f in fields(cls)]
    return indent("\n".join(lines), "    ")


def with_params_help(params_cls: type, title: str = "Parameter defaults"):
    """
    Class decorator that appends a help table generated from another dataclass
    (e.g., VllmParameters) to the decorated class's docstring.
    """
    if not is_dataclass(params_cls):
        raise TypeError(f"{params_cls} must be a dataclass")

    def _decorator(cls: type) -> type:
        block = generate_dataclass_help(params_cls)
        cls.__doc__ = (cls.__doc__ or "") + f"\n\n{title}:\n{block}"
        return cls

    return _decorator
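A sketch of how these `env_*_field` helpers compose in a config dataclass, mirroring the pattern `VllmBuildParameters` uses later in this PR. `MY_FLAG`, `MY_OUTPUT`, and `MY_TAG` are hypothetical environment variables:

```python
# Sketch only: the env var names are illustrative, not defined by this PR.
from dataclasses import dataclass
from pathlib import Path

from cli.lib.common.envs_helper import env_bool_field, env_path_field, env_str_field


@dataclass
class ExampleParams:
    enabled: bool = env_bool_field("MY_FLAG", False)
    output_dir: Path = env_path_field("MY_OUTPUT", "./out")
    tag: str = env_str_field("MY_TAG", "dev")


# Defaults resolve from the environment when the dataclass is constructed,
# because each field uses a default_factory that reads os.environ.
params = ExampleParams()
print(params)
```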
							
								
								
									
.ci/lumen_cli/cli/lib/common/git_helper.py (new file, 69 lines) @ -0,0 +1,69 @@

"""
Git Utility helpers for CLI tasks.
"""

import logging
from pathlib import Path

from cli.lib.common.path_helper import remove_dir
from git import GitCommandError, RemoteProgress, Repo


logger = logging.getLogger(__name__)


class PrintProgress(RemoteProgress):
    """Simple progress logger for git operations."""

    def __init__(self, interval: int = 5):
        super().__init__()
        self._last_percent = -1
        self._interval = interval

    def update(self, op_code, cur, max=None, message=""):
        msg = self._cur_line or message
        if max and cur:
            percent = int(cur / max * 100)
            if percent != self._last_percent and percent % self._interval == 0:
                self._last_percent = percent
                logger.info("Progress: %d%% - %s", percent, msg)
        elif msg:
            logger.info(msg)


def clone_external_repo(target: str, repo: str, dst: str = "", update_submodules=False):
    """Clone repository with pinned commit and optional submodules."""
    dst = dst or target

    try:
        logger.info("Cloning %s to %s", target, dst)

        # Clone and fetch
        remove_dir(dst)
        r = Repo.clone_from(repo, dst, progress=PrintProgress())
        r.git.fetch("--all", "--tags")

        # Checkout pinned commit
        commit = get_post_build_pinned_commit(target)
        logger.info("Checking out pinned commit %s", commit)
        r.git.checkout(commit)

        # Update submodules if requested
        if update_submodules and r.submodules:
            logger.info("Updating %d submodule(s)", len(r.submodules))
            for sm in r.submodules:
                sm.update(init=True, recursive=True, progress=PrintProgress())

        logger.info("Successfully cloned %s", target)
        return r

    except GitCommandError as e:
        logger.error("Git operation failed: %s", e)
        raise


def get_post_build_pinned_commit(name: str, prefix=".github/ci_commit_pins") -> str:
    path = Path(prefix) / f"{name}.txt"
    if not path.exists():
        raise FileNotFoundError(f"Pin file not found: {path}")
    return path.read_text(encoding="utf-8").strip()
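Usage note: `clone_external_repo` expects a pin file at `.github/ci_commit_pins/<target>.txt`. The call below mirrors `clone_vllm()` from `cli/lib/core/vllm.py` in this same PR; any other target would need its own pin file:

```python
# Mirrors clone_vllm() from this PR; run from the PyTorch repo root so the
# .github/ci_commit_pins/vllm.txt pin file resolves.
from cli.lib.common.git_helper import clone_external_repo

repo = clone_external_repo(
    target="vllm",
    repo="https://github.com/vllm-project/vllm.git",
    dst="vllm",
    update_submodules=True,
)
```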
							
								
								
									
.ci/lumen_cli/cli/lib/common/logger.py (new file, 14 lines) @ -0,0 +1,14 @@

"""
Logger Utility helpers for CLI tasks.
"""

import logging
import sys


def setup_logging(level: int = logging.INFO):
    logging.basicConfig(
        level=level,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        stream=sys.stdout,
    )
							
								
								
									
.ci/lumen_cli/cli/lib/common/path_helper.py (new file, 62 lines) @ -0,0 +1,62 @@

"""Path utility helpers for CLI tasks."""

import logging
import shutil
from pathlib import Path
from typing import Union


logger = logging.getLogger(__name__)


def get_path(path: Union[str, Path], resolve: bool = False) -> Path:
    """Convert to Path object, optionally resolving to absolute path."""
    if not path:
        raise ValueError("Path cannot be None or empty")
    result = Path(path)
    return result.resolve() if resolve else result


def ensure_dir_exists(path: Union[str, Path]) -> Path:
    """Create directory if it doesn't exist."""
    path_obj = get_path(path)
    path_obj.mkdir(parents=True, exist_ok=True)
    return path_obj


def remove_dir(path: Union[str, Path, None]) -> None:
    """Remove directory if it exists."""
    if not path:
        return
    path_obj = get_path(path)
    if path_obj.exists():
        shutil.rmtree(path_obj)


def force_create_dir(path: Union[str, Path]) -> Path:
    """Remove directory if exists, then create fresh empty directory."""
    remove_dir(path)
    return ensure_dir_exists(path)


def copy(src: Union[str, Path], dst: Union[str, Path]) -> None:
    """Copy file or directory from src to dst."""
    src_path = get_path(src, resolve=True)
    dst_path = get_path(dst, resolve=True)

    if not src_path.exists():
        raise FileNotFoundError(f"Source does not exist: {src_path}")

    dst_path.parent.mkdir(parents=True, exist_ok=True)

    if src_path.is_file():
        shutil.copy2(src_path, dst_path)
    elif src_path.is_dir():
        shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
    else:
        raise ValueError(f"Unsupported path type: {src_path}")


def is_path_exist(path: Union[str, Path, None]) -> bool:
    """Check if path exists."""
    return bool(path and get_path(path).exists())
							
								
								
									
.ci/lumen_cli/cli/lib/common/utils.py (new file, 79 lines) @ -0,0 +1,79 @@

"""
General Utility helpers for CLI tasks.
"""

import logging
import os
import shlex
import subprocess
import sys
from typing import Optional


logger = logging.getLogger(__name__)


def run_command(
    cmd: str,
    use_shell: bool = False,
    log_cmd: bool = True,
    cwd: Optional[str] = None,
    env: Optional[dict] = None,
    check: bool = True,
) -> int:
    """Run a command with optional shell execution."""
    if use_shell:
        args = cmd
        log_prefix = "[shell]"
        executable = "/bin/bash"
    else:
        args = shlex.split(cmd)
        log_prefix = "[cmd]"
        executable = None

    if log_cmd:
        display_cmd = cmd if use_shell else " ".join(args)
        logger.info("%s %s", log_prefix, display_cmd)

    run_env = {**os.environ, **(env or {})}

    proc = subprocess.run(
        args,
        shell=use_shell,
        executable=executable,
        stdout=sys.stdout,
        stderr=sys.stderr,
        cwd=cwd,
        env=run_env,
        check=False,
    )

    if check and proc.returncode != 0:
        logger.error(
            "%s Command failed (exit %s): %s", log_prefix, proc.returncode, cmd
        )
        raise subprocess.CalledProcessError(
            proc.returncode, args if not use_shell else cmd
        )

    return proc.returncode


def str2bool(value: Optional[str]) -> bool:
    """Convert environment variables to boolean values."""
    if not value:
        return False
    if not isinstance(value, str):
        raise ValueError(
            f"Expected a string value for boolean conversion, got {type(value)}"
        )
    value = value.strip().lower()

    true_value_set = {"1", "true", "t", "yes", "y", "on", "enable", "enabled", "found"}
    false_value_set = {"0", "false", "f", "no", "n", "off", "disable"}

    if value in true_value_set:
        return True
    if value in false_value_set:
        return False
    raise ValueError(f"Invalid string value for boolean conversion: {value}")
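A short sketch of the two helpers: `str2bool` for env-style flags and `run_command` for a streamed subprocess. The command and env var below are illustrative:

```python
# Sketch: str2bool maps common truthy/falsy strings; run_command raises
# subprocess.CalledProcessError on a non-zero exit because check defaults to True.
from cli.lib.common.utils import run_command, str2bool

assert str2bool("yes") is True
assert str2bool("off") is False

# Runs without a shell (the string is shlex-split); EXAMPLE_VAR is illustrative.
run_command("echo hello", env={"EXAMPLE_VAR": "1"})
```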
							
								
								
									
.ci/lumen_cli/cli/lib/core/vllm.py (new file, 263 lines) @ -0,0 +1,263 @@

import logging
import os
import textwrap
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

from cli.lib.common.cli_helper import BaseRunner
from cli.lib.common.docker_helper import local_image_exists
from cli.lib.common.envs_helper import (
    env_bool_field,
    env_path_field,
    env_str_field,
    with_params_help,
)
from cli.lib.common.git_helper import clone_external_repo
from cli.lib.common.path_helper import (
    copy,
    ensure_dir_exists,
    force_create_dir,
    get_path,
    is_path_exist,
)
from cli.lib.common.utils import run_command


logger = logging.getLogger(__name__)


# Default path for docker build artifacts
_DEFAULT_RESULT_PATH = "./shared"

# Temp folder inside the vllm workspace used to stage torch wheels for the docker build
_VLLM_TEMP_FOLDER = "tmp"


@dataclass
class VllmBuildParameters:
    """
    Parameters defining the vllm external input configurations.
    Combine with VllmDockerBuildArgs to define the vllm build environment
    """

    # USE_TORCH_WHEEL: when true, use local Torch wheels; requires TORCH_WHEELS_PATH.
    # Otherwise, the docker build pulls torch nightly during the build.
    # TORCH_WHEELS_PATH: directory containing local torch wheels when use_torch_whl is True
    use_torch_whl: bool = env_bool_field("USE_TORCH_WHEEL", True)
    torch_whls_path: Path = env_path_field("TORCH_WHEELS_PATH", "./dist")

    # USE_LOCAL_BASE_IMAGE: when true, use an existing local Docker base image; requires BASE_IMAGE.
    # Otherwise, pull the dockerfile's default image remotely.
    # BASE_IMAGE: name:tag (only needed when use_local_base_image is True)
    use_local_base_image: bool = env_bool_field("USE_LOCAL_BASE_IMAGE", True)
    base_image: str = env_str_field("BASE_IMAGE")

    # USE_LOCAL_DOCKERFILE: when true ("1"), use a local Dockerfile; requires DOCKERFILE_PATH.
    # Otherwise, use vllm's default dockerfile.torch_nightly for the build.
    # DOCKERFILE_PATH: path to the Dockerfile used when use_local_dockerfile is True
    use_local_dockerfile: bool = env_bool_field("USE_LOCAL_DOCKERFILE", True)
    dockerfile_path: Path = env_path_field(
        "DOCKERFILE_PATH", ".github/ci_configs/vllm/Dockerfile.tmp_vllm"
    )

    # OUTPUT_DIR: where docker buildx (local exporter) will write artifacts
    output_dir: Path = env_path_field("OUTPUT_DIR", "external/vllm")

    # --- Build args ----------------------------------------------------------
    target_stage: str = env_str_field("TARGET_STAGE", "export-wheels")

    tag_name: str = env_str_field("TAG", "vllm-wheels")

    cuda_version: str = env_str_field("CUDA_VERSION", "12.8.1")

    python_version: str = env_str_field("PYTHON_VERSION", "3.12")

    max_jobs: str = env_str_field("MAX_JOBS", "64")

    sccache_bucket: str = env_str_field("SCCACHE_BUCKET")

    sccache_region: str = env_str_field("SCCACHE_REGION")

    torch_cuda_arch_list: str = env_str_field("TORCH_CUDA_ARCH_LIST", "8.9")

    def __post_init__(self):
        checks = [
            (
                self.use_torch_whl,  # flag
                True,  # trigger_value
                "torch_whls_path",  # resource
                is_path_exist,  # check_func
                "TORCH_WHEELS_PATH is not provided, but USE_TORCH_WHEEL is set to 1",
            ),
            (
                self.use_local_base_image,
                True,
                "base_image",
                local_image_exists,
                f"BASE_IMAGE {self.base_image} not found, but USE_LOCAL_BASE_IMAGE is set to 1",
            ),
            (
                self.use_local_dockerfile,
                True,
                "dockerfile_path",
                is_path_exist,
                "DOCKERFILE_PATH does not exist, but USE_LOCAL_DOCKERFILE is set to 1",
            ),
        ]
        for flag, trigger_value, attr_name, check_func, error_msg in checks:
            value = getattr(self, attr_name)
            if flag == trigger_value:
                if not value or not check_func(value):
                    raise ValueError(error_msg)
            else:
                logger.info("flag %s is not set", flag)
        if not self.output_dir:
            raise ValueError("missing required output_dir")


@with_params_help(VllmBuildParameters)
class VllmBuildRunner(BaseRunner):
    """
    Build vLLM using docker buildx.

    Environment variable options:
        "USE_TORCH_WHEEL":      "1: use local wheels; 0: pull nightly from pypi",
        "TORCH_WHEELS_PATH":    "Path to local wheels (when USE_TORCH_WHEEL=1)",

        "USE_LOCAL_BASE_IMAGE": "1: use local base image; 0: default image",
        "BASE_IMAGE":           "name:tag of the base image the dockerfile depends on (when USE_LOCAL_BASE_IMAGE=1)",

        "USE_LOCAL_DOCKERFILE": "1: use local Dockerfile; 0: vllm repo default dockerfile.torch_nightly",
        "DOCKERFILE_PATH":      "Path to Dockerfile (when USE_LOCAL_DOCKERFILE=1)",

        "OUTPUT_DIR":           "e.g. './shared'",

        "TORCH_CUDA_ARCH_LIST": "e.g. '8.0' or '8.0;9.0'",
        "CUDA_VERSION":         "e.g. '12.8.1'",
        "PYTHON_VERSION":       "e.g. '3.12'",
        "MAX_JOBS":             "e.g. '64'",
        "SCCACHE_BUCKET":       "e.g. 'my-bucket'",
        "SCCACHE_REGION":       "e.g. 'us-west-2'",
    """

    def __init__(self, args=None):
        self.work_directory = "vllm"

    def run(self):
        """
        Main function to run the vllm build:
        1. prepare the vllm build environment
        2. prepare the docker build command args
        3. run docker build
        """
        inputs = VllmBuildParameters()
        clone_vllm()

        self.cp_dockerfile_if_exist(inputs)

        # Copy torch wheels from the repo root to the vllm workspace if they exist
        self.cp_torch_whls_if_exist(inputs)

        ensure_dir_exists(inputs.output_dir)

        cmd = self._generate_docker_build_cmd(inputs)
        logger.info("Running docker build: \n %s", cmd)
        run_command(cmd, cwd="vllm", env=os.environ.copy())

    def cp_torch_whls_if_exist(self, inputs: VllmBuildParameters) -> str:
        if not inputs.use_torch_whl:
            return ""
        tmp_dir = f"./{self.work_directory}/{_VLLM_TEMP_FOLDER}"
        tmp_path = Path(tmp_dir)
        force_create_dir(tmp_path)
        copy(inputs.torch_whls_path, tmp_dir)
        return tmp_dir

    def cp_dockerfile_if_exist(self, inputs: VllmBuildParameters):
        if not inputs.use_local_dockerfile:
            logger.info("using vllm default dockerfile.torch_nightly for build")
            return
        dockerfile_path = get_path(inputs.dockerfile_path, resolve=True)
        vllm_torch_dockerfile = Path(
            f"./{self.work_directory}/docker/Dockerfile.nightly_torch"
        )
        copy(dockerfile_path, vllm_torch_dockerfile)

    def get_result_path(self, path):
        """
        Get the absolute path of the result path
        """
        if not path:
            path = _DEFAULT_RESULT_PATH
        abs_path = get_path(path, resolve=True)
        return abs_path

    def _get_torch_wheel_path_arg(self, torch_whl_dir: Optional[Path]) -> str:
        if not torch_whl_dir:
            return ""
        return f"--build-arg TORCH_WHEELS_PATH={_VLLM_TEMP_FOLDER}"

    def _get_base_image_args(self, inputs: VllmBuildParameters) -> tuple[str, str, str]:
        """
        Returns:
            - base_image_arg: docker buildx arg string for the base image
            - final_base_image_arg: docker buildx arg string for the vllm-base stage
            - pull_flag: --pull=true or --pull=false depending on whether the image exists locally
        """
        if not inputs.use_local_base_image:
            return "", "", ""

        base_image = inputs.base_image

        # set both base image and final base image to the same local image
        base_image_arg = f"--build-arg BUILD_BASE_IMAGE={base_image}"
        final_base_image_arg = f"--build-arg FINAL_BASE_IMAGE={base_image}"

        if local_image_exists(base_image):
            pull_flag = "--pull=false"
            return base_image_arg, final_base_image_arg, pull_flag
        logger.info(
            "Local image not found: %s; will try to pull from remote", base_image
        )
        return base_image_arg, final_base_image_arg, ""

    def _generate_docker_build_cmd(
        self,
        inputs: VllmBuildParameters,
    ) -> str:
        base_image_arg, final_base_image_arg, pull_flag = self._get_base_image_args(
            inputs
        )
        torch_arg = self._get_torch_wheel_path_arg(inputs.torch_whls_path)

        return textwrap.dedent(
            f"""
            docker buildx build \
                --output type=local,dest={inputs.output_dir} \
                -f docker/Dockerfile.nightly_torch \
                {pull_flag} \
                {torch_arg} \
                {base_image_arg} \
                {final_base_image_arg} \
                --build-arg max_jobs={inputs.max_jobs} \
                --build-arg CUDA_VERSION={inputs.cuda_version} \
                --build-arg PYTHON_VERSION={inputs.python_version} \
                --build-arg USE_SCCACHE={int(bool(inputs.sccache_bucket and inputs.sccache_region))} \
                --build-arg SCCACHE_BUCKET_NAME={inputs.sccache_bucket} \
                --build-arg SCCACHE_REGION_NAME={inputs.sccache_region} \
                --build-arg torch_cuda_arch_list='{inputs.torch_cuda_arch_list}' \
                --target {inputs.target_stage} \
                -t {inputs.tag_name} \
                --progress=plain .
        """
        ).strip()


def clone_vllm():
    clone_external_repo(
        target="vllm",
        repo="https://github.com/vllm-project/vllm.git",
        dst="vllm",
        update_submodules=True,
    )
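Since all of `VllmBuildParameters` is driven by environment variables, an invocation might look like the sketch below. It is run from the PyTorch repo root; the variable names come from the docstring above, while the specific values are examples only:

```bash
# Illustrative invocation; values are examples, not mandated defaults.
export USE_TORCH_WHEEL=0          # pull torch nightly inside the build
export USE_LOCAL_BASE_IMAGE=0     # use the dockerfile's default base image
export USE_LOCAL_DOCKERFILE=1
export DOCKERFILE_PATH=.github/ci_configs/vllm/Dockerfile.tmp_vllm
export OUTPUT_DIR=./shared
export TORCH_CUDA_ARCH_LIST='8.0;9.0'
python -m cli.run build external vllm
```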
							
								
								
									
										38
									
								
								.ci/lumen_cli/cli/run.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								.ci/lumen_cli/cli/run.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,38 @@
 | 
			
		||||
# cli/run.py

import argparse
import logging

from cli.build_cli.register_build import register_build_commands
from cli.lib.common.logger import setup_logging


logger = logging.getLogger(__name__)


def main():
    # Define top-level parser
    parser = argparse.ArgumentParser(description="Lumen CLI")
    subparsers = parser.add_subparsers(dest="command", required=True)
    parser.add_argument(
        "--log-level", default="INFO", help="Log level (DEBUG, INFO, WARNING, ERROR)"
    )

    # register second-level subcommands
    register_build_commands(subparsers)

    # parse args after all options are registered
    args = parser.parse_args()

    # set up global logging
    setup_logging(getattr(logging, args.log_level.upper(), logging.INFO))
    logger.debug("Parsed args: %s", args)

    if hasattr(args, "func"):
        args.func(args)
    else:
        parser.print_help()


if __name__ == "__main__":
    main()
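
The entry point can be exercised without installing a console script; a minimal sketch, mirroring tests/test_app.py below and assuming the lumen_cli package is importable:

import sys
from unittest.mock import patch

from cli.run import main

# Equivalent to: python3 -m cli.run build --help
# (--help is used here so nothing is actually built; it exits with code 0)
with patch.object(sys, "argv", ["cli.run", "build", "--help"]):
    try:
        main()
    except SystemExit:
        pass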

.ci/lumen_cli/pyproject.toml (new file, 22 lines)
@@ -0,0 +1,22 @@
[project]
name = "lumen-ci"
version = "0.1.0"
dependencies = [
    "pyyaml==6.0.2",
    "GitPython==3.1.45",
    "docker==7.1.0",
    "pytest==7.3.2",
]

[tool.setuptools]
packages = ["cli"]

[tool.setuptools.package-dir]
cli = "cli"

[tool.ruff.lint]
# Enable preview mode for linting
preview = true

# Now you can select your preview rules, like RUF048
extend-select = ["RUF048"]

.ci/lumen_cli/tests/test_app.py (new file, 47 lines)
@@ -0,0 +1,47 @@
# tests/test_app.py
import io
import sys
import unittest
from contextlib import redirect_stderr, redirect_stdout
from unittest.mock import patch

from cli.run import main


class TestArgparseCLI(unittest.TestCase):
    @patch("cli.build_cli.register_build.VllmBuildRunner.run", return_value=None)
    @patch("cli.build_cli.register_build.VllmBuildRunner.__init__", return_value=None)
    def test_cli_run_build_external(self, mock_init, mock_run):
        from cli.run import main  # import after patches if needed

        test_args = ["cli.run", "build", "external", "vllm"]
        with patch.object(sys, "argv", test_args):
            # argparse may call sys.exit on error; capture to avoid test aborts
            try:
                main()
            except SystemExit:
                pass
        mock_init.assert_called_once()  # got constructed
        mock_run.assert_called_once_with()  # run() called

    def test_build_help(self):
        test_args = ["cli.run", "build", "--help"]

        with patch.object(sys, "argv", test_args):
            stdout = io.StringIO()
            stderr = io.StringIO()

            # --help always raises SystemExit(0)
            with self.assertRaises(SystemExit) as cm:
                with redirect_stdout(stdout), redirect_stderr(stderr):
                    main()

            self.assertEqual(cm.exception.code, 0)

            output = stdout.getvalue()
            self.assertIn("usage", output)
            self.assertIn("external", output)


if __name__ == "__main__":
    unittest.main()

.ci/lumen_cli/tests/test_cli_helper.py (new file, 115 lines)
@@ -0,0 +1,115 @@
import argparse
import io
import unittest
from contextlib import redirect_stderr
from unittest.mock import patch

from cli.lib.common.cli_helper import BaseRunner, register_targets, RichHelp, TargetSpec


# ---- Dummy runners for unittests ----
class FooRunner(BaseRunner):
    """Foo description from docstring."""

    def run(self) -> None:  # replaced by mock
        pass


class BarRunner(BaseRunner):
    def run(self) -> None:  # replaced by mock
        pass


def add_foo_args(p: argparse.ArgumentParser) -> None:
    p.add_argument("--x", type=int, required=True, help="x value")


def common_args(p: argparse.ArgumentParser) -> None:
    p.add_argument("--verbose", action="store_true", help="verbose flag")


def build_parser(specs: dict[str, TargetSpec]) -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(prog="app", formatter_class=RichHelp)
    register_targets(
        parser=parser,
        target_specs=specs,
        common_args=common_args,
    )
    return parser


def get_subparser(
    parser: argparse.ArgumentParser, name: str
) -> argparse.ArgumentParser:
    subparsers_action = next(
        a
        for a in parser._subparsers._group_actions  # type: ignore[attr-defined]
        if isinstance(a, argparse._SubParsersAction)
    )
    return subparsers_action.choices[name]


class TestRegisterTargets(unittest.TestCase):
    def test_metavar_lists_targets(self):
        specs: dict[str, TargetSpec] = {
            "foo": {"runner": FooRunner, "add_arguments": add_foo_args},
            "bar": {"runner": BarRunner},
        }
        parser = build_parser(specs)
        subparsers_action = next(
            a
            for a in parser._subparsers._group_actions  # type: ignore[attr-defined]
            if isinstance(a, argparse._SubParsersAction)
        )
        self.assertEqual(subparsers_action.metavar, "{foo,bar}")

    def test_add_arguments_and_common_args_present(self):
        specs: dict[str, TargetSpec] = {
            "foo": {"runner": FooRunner, "add_arguments": add_foo_args},
        }
        parser = build_parser(specs)
        foo = get_subparser(parser, "foo")
        help_text = foo.format_help()
        self.assertIn("--x", help_text)
        self.assertIn("--verbose", help_text)

    def test_runner_constructed_with_ns_and_run_called(self):
        specs: dict[str, TargetSpec] = {
            "foo": {"runner": FooRunner, "add_arguments": add_foo_args},
        }
        parser = build_parser(specs)

        with (
            patch.object(FooRunner, "__init__", return_value=None) as mock_init,
            patch.object(FooRunner, "run", return_value=None) as mock_run,
        ):
            ns = parser.parse_args(["foo", "--x", "3", "--verbose"])
            ns.func(ns)  # set by register_targets
            # __init__ received the Namespace
            self.assertEqual(mock_init.call_count, 1)
            (called_ns,), _ = mock_init.call_args
            self.assertIsInstance(called_ns, argparse.Namespace)
            # run() called with no args
            mock_run.assert_called_once_with()

    def test_runner_docstring_used_as_description_when_missing(self):
        specs: dict[str, TargetSpec] = {
            "foo": {"runner": FooRunner, "add_arguments": add_foo_args},
        }
        parser = build_parser(specs)
        foo = get_subparser(parser, "foo")
        help_text = foo.format_help()
        self.assertIn("Foo description from docstring.", help_text)

    def test_missing_target_raises_systemexit_with_usage(self):
        specs: dict[str, TargetSpec] = {"foo": {"runner": FooRunner}}
        parser = build_parser(specs)
        buf = io.StringIO()
        with self.assertRaises(SystemExit), redirect_stderr(buf):
            parser.parse_args([])
        err = buf.getvalue()
        self.assertIn("usage:", err)


if __name__ == "__main__":
    unittest.main()
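
cli_helper itself is not part of this diff; for orientation, here is a minimal sketch of what register_targets and TargetSpec plausibly look like, inferred only from the tests above (all names and details are assumptions, and RichHelp, presumably a help formatter, is omitted):

import argparse
from typing import Callable, Optional, TypedDict


class BaseRunner:
    def __init__(self, args: argparse.Namespace) -> None:
        self.args = args

    def run(self) -> None:
        raise NotImplementedError


class TargetSpec(TypedDict, total=False):
    runner: type[BaseRunner]
    add_arguments: Callable[[argparse.ArgumentParser], None]


def register_targets(
    parser: argparse.ArgumentParser,
    target_specs: dict[str, TargetSpec],
    common_args: Optional[Callable[[argparse.ArgumentParser], None]] = None,
) -> None:
    # metavar "{foo,bar}" and required=True match the expectations in the tests
    subparsers = parser.add_subparsers(
        dest="target", required=True, metavar="{" + ",".join(target_specs) + "}"
    )
    for name, spec in target_specs.items():
        runner = spec["runner"]
        # the runner docstring doubles as the subcommand description
        sub = subparsers.add_parser(name, description=runner.__doc__)
        if "add_arguments" in spec:
            spec["add_arguments"](sub)
        if common_args:
            common_args(sub)
        # ns.func(ns) constructs the runner with the Namespace and calls run()
        sub.set_defaults(func=lambda ns, r=runner: r(ns).run())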

.ci/lumen_cli/tests/test_docker_helper.py (new file, 75 lines)
@@ -0,0 +1,75 @@
import unittest
from unittest import mock
from unittest.mock import MagicMock

import docker.errors as derr
from cli.lib.common.docker_helper import _get_client, local_image_exists


class TestDockerImageHelpers(unittest.TestCase):
    def setUp(self):
        # Reset the singleton in the target module
        patcher = mock.patch("cli.lib.common.docker_helper._docker_client", None)
        self.addCleanup(patcher.stop)
        patcher.start()

    def test_local_image_exists_true(self):
        # Mock a docker client whose images.get returns an object (no exception)
        mock_client = MagicMock()
        mock_client.images.get.return_value = object()
        ok = local_image_exists("repo:tag", client=mock_client)
        self.assertTrue(ok)

    def test_local_image_exists_not_found_false(self):
        mock_client = MagicMock()
        # Raise docker.errors.NotFound
        mock_client.images.get.side_effect = derr.NotFound("nope")
        ok = local_image_exists("missing:latest", client=mock_client)
        self.assertFalse(ok)

    def test_local_image_exists_api_error_false(self):
        mock_client = MagicMock()
        mock_client.images.get.side_effect = derr.APIError("boom", None)

        ok = local_image_exists("broken:tag", client=mock_client)
        self.assertFalse(ok)

    def test_local_image_exists_uses_lazy_singleton(self):
        # Patch docker.from_env used by _get_client()
        with mock.patch(
            "cli.lib.common.docker_helper.docker.from_env"
        ) as mock_from_env:
            mock_docker_client = MagicMock()
            mock_from_env.return_value = mock_docker_client

            # First call should create and cache the client
            c1 = _get_client()
            self.assertIs(c1, mock_docker_client)
            mock_from_env.assert_called_once()

            # Second call should reuse cached client (no extra from_env calls)
            c2 = _get_client()
            self.assertIs(c2, mock_docker_client)
            mock_from_env.assert_called_once()  # still once

    def test_local_image_exists_without_client_param_calls_get_client_once(self):
        # Ensure _get_client is called and cached; local_image_exists should reuse it
        with mock.patch("cli.lib.common.docker_helper._get_client") as mock_get_client:
            mock_client = MagicMock()
            mock_get_client.return_value = mock_client

            # 1st call
            local_image_exists("repo:tag")
            # 2nd call
            local_image_exists("repo:tag2")

            # local_image_exists calls _get_client on each call,
            # but _get_client itself caches docker.from_env.
            self.assertEqual(mock_get_client.call_count, 2)
            self.assertEqual(mock_client.images.get.call_count, 2)
            mock_client.images.get.assert_any_call("repo:tag")
            mock_client.images.get.assert_any_call("repo:tag2")


if __name__ == "__main__":
    unittest.main()
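
The module under test is not included in this diff; a minimal sketch of the lazy-singleton pattern the tests imply (details are assumptions):

from typing import Optional

import docker
import docker.errors as derr

_docker_client: Optional[docker.DockerClient] = None  # module-level cache


def _get_client() -> docker.DockerClient:
    # Create the client on first use and reuse it afterwards
    global _docker_client
    if _docker_client is None:
        _docker_client = docker.from_env()
    return _docker_client


def local_image_exists(
    image_name: str, client: Optional[docker.DockerClient] = None
) -> bool:
    client = client or _get_client()
    try:
        client.images.get(image_name)
        return True
    except (derr.NotFound, derr.APIError):
        return False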

.ci/lumen_cli/tests/test_envs_helper.py (new file, 149 lines)
@@ -0,0 +1,149 @@
import os
import unittest
from dataclasses import dataclass
from pathlib import Path
from unittest.mock import patch

import cli.lib.common.envs_helper as m


class TestEnvHelpers(unittest.TestCase):
    def setUp(self):
        # Keep a copy of the original environment to restore later
        self._env_backup = dict(os.environ)

    def tearDown(self):
        # Restore environment to original state
        os.environ.clear()
        os.environ.update(self._env_backup)

    # -------- get_env --------
    def test_get_env_unset_returns_default(self):
        with patch.dict(os.environ, {}, clear=True):
            self.assertEqual(m.get_env("FOO", "default"), "default")

    def test_get_env_empty_returns_default(self):
        with patch.dict(os.environ, {"FOO": ""}, clear=True):
            self.assertEqual(m.get_env("FOO", "default"), "default")

    def test_get_env_set_returns_value(self):
        with patch.dict(os.environ, {"FOO": "bar"}, clear=True):
            self.assertEqual(m.get_env("FOO", "default"), "bar")

    def test_get_env_not_exist_returns_default(self):
        with patch.dict(os.environ, {"FOO": "bar"}, clear=True):
            self.assertEqual(m.get_env("TEST_NOT_EXIST", "default"), "default")

    def test_get_env_not_exist_without_default(self):
        with patch.dict(os.environ, {"FOO": "bar"}, clear=True):
            self.assertEqual(m.get_env("TEST_NOT_EXIST"), "")

    # -------- env_bool --------
    def test_env_bool_uses_default_when_unset(self):
        with patch.dict(os.environ, {}, clear=True):
            self.assertTrue(m.env_bool("FLAG", default=True))
            self.assertFalse(m.env_bool("FLAG", default=False))

    def test_env_bool_uses_str2bool_when_set(self):
        # Patch str2bool used by env_bool so we don't depend on its exact behavior
        def fake_str2bool(s: str) -> bool:
            return s.lower() in {"1", "true", "yes", "on", "y"}

        with (
            patch.dict(os.environ, {"FLAG": "yEs"}, clear=True),
            patch.object(m, "str2bool", fake_str2bool),
        ):
            self.assertTrue(m.env_bool("FLAG", default=False))

    # -------- env_path_optional / env_path --------
    def test_env_path_optional_unset_returns_none_by_default(self):
        with patch.dict(os.environ, {}, clear=True):
            self.assertIsNone(m.env_path_optional("P"))

    def test_env_path_optional_unset_returns_none_when_env_var_is_empty(self):
        with patch.dict(os.environ, {"P": ""}, clear=True):
            self.assertIsNone(m.env_path_optional("P"))

    def test_env_path_optional_unset_returns_default_str(self):
        # default as string; resolve=True by default -> absolute path
        default_str = "x/y"
        with patch.dict(os.environ, {}, clear=True):
            p = m.env_path_optional("P", default=default_str)
            self.assertIsInstance(p, Path)
            self.assertIsNotNone(p)
            if p:
                self.assertTrue(p.is_absolute())
                self.assertEqual(p.parts[-2:], ("x", "y"))

    def test_env_path_optional_unset_returns_default_path_no_resolve(self):
        d = Path("z")
        with patch.dict(os.environ, {}, clear=True):
            p = m.env_path_optional("P", default=d, resolve=False)
            self.assertEqual(p, d)

    def test_env_path_optional_respects_resolve_true(self):
        with patch.dict(os.environ, {"P": "a/b"}, clear=True):
            p = m.env_path_optional("P", resolve=True)
            self.assertIsInstance(p, Path)
            if p:
                self.assertTrue(p.is_absolute())

    def test_env_path_optional_respects_resolve_false(self):
        with patch.dict(os.environ, {"P": "rel/dir"}, clear=True):
            p = m.env_path_optional("P", resolve=False)
            self.assertEqual(p, Path("rel/dir"))
            if p:
                self.assertFalse(p.is_absolute())

    def test_env_path_raises_when_missing_and_default_none(self):
        with patch.dict(os.environ, {}, clear=True):
            with self.assertRaises(ValueError):
                m.env_path("P", None, resolve=True)

    def test_env_path_returns_path_when_present(self):
        tmp = Path("./b").resolve()
        with patch.dict(os.environ, {"P": str(tmp)}, clear=True):
            p = m.env_path("P", None, resolve=True)
            self.assertEqual(p, tmp)

    # -------- dataclass field helpers --------
    def test_dataclass_fields_read_env_at_instantiation(self):
        @dataclass
        class Cfg:
            flag: bool = m.env_bool_field("FLAG", default=False)
            out: Path = m.env_path_field("OUT", default="ab", resolve=True)
            name: str = m.env_str_field("NAME", default="anon")

        # First instantiation
        with patch.dict(
            os.environ, {"FLAG": "true", "OUT": "outdir", "NAME": "alice"}, clear=True
        ):
            cfg1 = Cfg()
            self.assertTrue(cfg1.flag)
            self.assertIsInstance(cfg1.out, Path)
            self.assertTrue(cfg1.out.is_absolute())
            self.assertEqual(cfg1.name, "alice")
            cfg1.name = "bob"  # change instance value
            self.assertEqual(cfg1.name, "bob")  # change is reflected

        # Change env; new instance should reflect new values
        with patch.dict(os.environ, {"FLAG": "false", "NAME": ""}, clear=True):
            cfg2 = Cfg()
            self.assertFalse(cfg2.flag)  # str2bool("false") -> False
            self.assertTrue("ab" in str(cfg2.out))
            self.assertIsInstance(cfg2.out, Path)
            self.assertTrue(cfg2.out.is_absolute())
            self.assertEqual(cfg2.name, "anon")  # empty -> fallback to default

    def test_dataclass_path_field_with_default_value(self):
        @dataclass
        class C2:
            out: Path = m.env_path_field("OUT", default="some/dir", resolve=False)

        with patch.dict(os.environ, {}, clear=True):
            c = C2()
            self.assertEqual(c.out, Path("some/dir"))


if __name__ == "__main__":
    unittest.main()
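
The dataclass field helpers are the interesting part here: because they return dataclasses.field(default_factory=...), the environment is read each time the dataclass is instantiated, not when the class is defined, which is exactly what the first test above checks. A minimal sketch, with all details assumed:

import os
from dataclasses import field
from pathlib import Path


def get_env(name: str, default: str = "") -> str:
    # Unset or empty values fall back to the default
    return os.environ.get(name) or default


def env_str_field(name: str, default: str = ""):
    return field(default_factory=lambda: get_env(name, default))


def env_path_field(name: str, default: str, resolve: bool = True):
    def _factory() -> Path:
        p = Path(get_env(name, default))
        return p.resolve() if resolve else p

    return field(default_factory=_factory)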

.ci/lumen_cli/tests/test_path_helper.py (new file, 122 lines)
@@ -0,0 +1,122 @@
# tests/test_path_helper.py
# Run: pytest -q

import os
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory

from cli.lib.common.path_helper import (
    copy,
    ensure_dir_exists,
    force_create_dir,
    get_path,
    is_path_exist,
    remove_dir,
)


class TestPathHelper(unittest.TestCase):
    def setUp(self):
        self.tmpdir = TemporaryDirectory()
        self.tmp_path = Path(self.tmpdir.name)

    def tearDown(self):
        self.tmpdir.cleanup()

    # -------- get_path --------
    def test_get_path_returns_path_for_str(self):
        # Use relative path to avoid absolute-ness
        rel_str = "sub/f.txt"
        os.chdir(self.tmp_path)
        p = get_path(rel_str, resolve=False)
        self.assertIsInstance(p, Path)
        self.assertFalse(p.is_absolute())
        self.assertEqual(str(p), rel_str)

    def test_get_path_resolves(self):
        rel_str = "sub/f.txt"
        p = get_path(str(self.tmp_path / rel_str), resolve=True)
        self.assertTrue(p.is_absolute())
        self.assertTrue(str(p).endswith(rel_str))

    def test_get_path_with_path_input(self):
        p_in = self.tmp_path / "sub/f.txt"
        p_out = get_path(p_in, resolve=False)
        self.assertTrue(str(p_out) == str(p_in))

    def test_get_path_with_none_raises(self):
        with self.assertRaises(ValueError):
            get_path(None)  # type: ignore[arg-type]

    def test_get_path_invalid_type_raises(self):
        with self.assertRaises(TypeError):
            get_path(123)  # type: ignore[arg-type]

    # -------- ensure_dir_exists / force_create_dir / remove_dir --------
    def test_ensure_dir_exists_creates_and_is_idempotent(self):
        d = self.tmp_path / "made"
        ensure_dir_exists(d)
        self.assertTrue(d.exists() and d.is_dir())
        ensure_dir_exists(d)

    def test_force_create_dir_clears_existing(self):
        d = self.tmp_path / "fresh"
        (d / "inner").mkdir(parents=True)
        (d / "inner" / "f.txt").write_text("x")
        force_create_dir(d)
        self.assertTrue(d.exists())
        self.assertEqual(list(d.iterdir()), [])

    def test_remove_dir_none_is_noop(self):
        remove_dir(None)  # type: ignore[arg-type]

    def test_remove_dir_nonexistent_is_noop(self):
        ghost = self.tmp_path / "ghost"
        remove_dir(ghost)

    def test_remove_dir_accepts_str(self):
        d = self.tmp_path / "to_rm"
        d.mkdir()
        remove_dir(str(d))
        self.assertFalse(d.exists())

    # -------- copy --------
    def test_copy_file_to_file(self):
        src = self.tmp_path / "src.txt"
        dst = self.tmp_path / "out" / "dst.txt"
        src.write_text("hello")
        copy(src, dst)
        self.assertEqual(dst.read_text(), "hello")

    def test_copy_dir_to_new_dir(self):
        src = self.tmp_path / "srcdir"
        (src / "a").mkdir(parents=True)
        (src / "a" / "f.txt").write_text("content")
        dst = self.tmp_path / "destdir"
        copy(src, dst)
        self.assertEqual((dst / "a" / "f.txt").read_text(), "content")

    def test_copy_dir_into_existing_dir_overwrite_true_merges(self):
        src = self.tmp_path / "srcdir"
        dst = self.tmp_path / "destdir"
        (src / "x").mkdir(parents=True)
        (src / "x" / "new.txt").write_text("new")
        dst.mkdir()
        (dst / "existing.txt").write_text("old")
        copy(src, dst)
        self.assertEqual((dst / "existing.txt").read_text(), "old")
        self.assertEqual((dst / "x" / "new.txt").read_text(), "new")

    def test_is_str_path_exist(self):
        p = self.tmp_path / "x.txt"
        p.write_text("1")
        self.assertTrue(is_path_exist(str(p)))
        self.assertTrue(is_path_exist(p))
        self.assertFalse(is_path_exist(str(self.tmp_path / "missing")))
        self.assertFalse(is_path_exist(self.tmp_path / "missing"))
        self.assertFalse(is_path_exist(""))


if __name__ == "__main__":
    unittest.main()
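
The copy behavior the tests pin down (file copies create missing parent directories; directory copies merge into an existing destination) maps naturally onto shutil; a minimal sketch, assuming the real helper does something similar:

import shutil
from pathlib import Path


def copy(src, dst) -> None:
    src, dst = Path(src), Path(dst)
    if src.is_dir():
        # dirs_exist_ok=True merges into an existing destination directory
        shutil.copytree(src, dst, dirs_exist_ok=True)
    else:
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(src, dst)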

.ci/lumen_cli/tests/test_vllm.py (new file, 181 lines)
@@ -0,0 +1,181 @@
import os
import tempfile
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch

import cli.lib.core.vllm as vllm


class TestVllmBuildParameters(unittest.TestCase):
    @patch("cli.lib.core.vllm.local_image_exists", return_value=True)
    @patch("cli.lib.core.vllm.is_path_exist", return_value=True)
    @patch(
        "cli.lib.common.envs_helper.env_path_optional",
        side_effect=lambda name, default=None, resolve=True: {
            "DOCKERFILE_PATH": Path("/abs/vllm/Dockerfile"),
            "TORCH_WHEELS_PATH": Path("/abs/dist"),
            "OUTPUT_DIR": Path("/abs/shared"),
        }.get(name, Path(default) if default is not None else None),
    )
    @patch.dict(
        os.environ,
        {
            "USE_TORCH_WHEEL": "1",
            "USE_LOCAL_BASE_IMAGE": "1",
            "USE_LOCAL_DOCKERFILE": "1",
            "BASE_IMAGE": "my/image:tag",
            "DOCKERFILE_PATH": "vllm/Dockerfile",
            "TORCH_WHEELS_PATH": "dist",
            "OUTPUT_DIR": "shared",
        },
        clear=True,
    )
    def test_params_success_normalizes_and_validates(
        self, mock_env_path, mock_is_path, mock_local_img
    ):
        params = vllm.VllmBuildParameters()
        self.assertEqual(params.torch_whls_path, Path("/abs/dist"))
        self.assertEqual(params.dockerfile_path, Path("/abs/vllm/Dockerfile"))
        self.assertEqual(params.output_dir, Path("/abs/shared"))
        self.assertEqual(params.base_image, "my/image:tag")

    @patch("cli.lib.core.vllm.is_path_exist", return_value=False)
    @patch.dict(
        os.environ, {"USE_TORCH_WHEEL": "1", "TORCH_WHEELS_PATH": "dist"}, clear=True
    )
    def test_params_missing_torch_whls_raises(self, _is_path):
        with tempfile.TemporaryDirectory() as td:
            os.chdir(td)
            with self.assertRaises(ValueError) as cm:
                vllm.VllmBuildParameters(
                    use_local_base_image=False,
                    use_local_dockerfile=False,
                )
        err = cm.exception
        self.assertIn("TORCH_WHEELS_PATH", str(err))

    @patch("cli.lib.core.vllm.local_image_exists", return_value=False)
    @patch.dict(
        os.environ, {"USE_LOCAL_BASE_IMAGE": "1", "BASE_IMAGE": "img:tag"}, clear=True
    )
    def test_params_missing_local_base_image_raises(self, _local_img):
        with tempfile.TemporaryDirectory() as td:
            os.chdir(td)
            with self.assertRaises(ValueError) as cm:
                vllm.VllmBuildParameters(
                    use_torch_whl=False,
                    use_local_dockerfile=False,
                )
        err = cm.exception
        self.assertIn("BASE_IMAGE", str(err))

    @patch("cli.lib.core.vllm.is_path_exist", return_value=False)
    @patch.dict(
        os.environ,
        {"USE_LOCAL_DOCKERFILE": "1", "DOCKERFILE_PATH": "Dockerfile"},
        clear=True,
    )
    def test_params_missing_dockerfile_raises(self, _is_path):
        with tempfile.TemporaryDirectory() as td:
            os.chdir(td)
            with self.assertRaises(ValueError) as cm:
                vllm.VllmBuildParameters(
                    use_torch_whl=False,
                    use_local_base_image=False,
                )
        err = cm.exception
        self.assertIn("DOCKERFILE_PATH", str(err))

    @patch("cli.lib.core.vllm.is_path_exist", return_value=False)
    @patch.dict(
        os.environ,
        {"OUTPUT_DIR": ""},
        clear=True,
    )
    def test_params_missing_output_dir(self, _is_path):
        with self.assertRaises(FileNotFoundError):
            vllm.VllmBuildParameters()


class TestBuildCmdAndRun(unittest.TestCase):
    @patch("cli.lib.core.vllm.local_image_exists", return_value=True)
    def test_generate_docker_build_cmd_includes_bits(self, _exists):
        runner = vllm.VllmBuildRunner()
        # Craft inputs that simulate a prepared build
        inputs = MagicMock()
        inputs.output_dir = Path("/abs/out")
        inputs.use_local_base_image = True
        inputs.base_image = "img:tag"
        inputs.torch_whls_path = Path("./vllm/tmp")
        inputs.max_jobs = 64
        inputs.cuda_version = "12.8.1"
        inputs.python_version = "3.12"
        inputs.sccache_bucket = "my-bucket"
        inputs.sccache_region = "us-west-2"
        inputs.torch_cuda_arch_list = "8.0;9.0"
        inputs.target_stage = "export-wheels"
        inputs.tag_name = "vllm-wheels"

        cmd = runner._generate_docker_build_cmd(inputs)
        squashed = " ".join(cmd.split())  # normalize whitespace for matching

        self.assertIn("--output type=local,dest=/abs/out", squashed)
        self.assertIn("-f docker/Dockerfile.nightly_torch", squashed)
        self.assertIn("--pull=false", squashed)
        self.assertIn("--build-arg TORCH_WHEELS_PATH=tmp", squashed)
        self.assertIn("--build-arg BUILD_BASE_IMAGE=img:tag", squashed)
        self.assertIn("--build-arg FINAL_BASE_IMAGE=img:tag", squashed)
        self.assertIn("--build-arg max_jobs=64", squashed)
        self.assertIn("--build-arg CUDA_VERSION=12.8.1", squashed)
        self.assertIn("--build-arg PYTHON_VERSION=3.12", squashed)
        self.assertIn("--build-arg USE_SCCACHE=1", squashed)
        self.assertIn("--build-arg SCCACHE_BUCKET_NAME=my-bucket", squashed)
        self.assertIn("--build-arg SCCACHE_REGION_NAME=us-west-2", squashed)
        self.assertIn("--build-arg torch_cuda_arch_list='8.0;9.0'", squashed)
        self.assertIn("--target export-wheels", squashed)
        self.assertIn("-t vllm-wheels", squashed)

    @patch("cli.lib.core.vllm.run_command")
    @patch("cli.lib.core.vllm.ensure_dir_exists")
    @patch("cli.lib.core.vllm.clone_vllm")
    @patch.object(
        vllm.VllmBuildRunner,
        "_generate_docker_build_cmd",
        return_value="docker buildx ...",
    )
    @patch.dict(
        os.environ,
        {
            # Make __post_init__ validations pass cheaply
            "USE_TORCH_WHEEL": "0",
            "USE_LOCAL_BASE_IMAGE": "0",
            "USE_LOCAL_DOCKERFILE": "0",
            "OUTPUT_DIR": "shared",
        },
        clear=True,
    )
    def test_run_calls_clone_prepare_and_build(
        self, mock_gen, mock_clone, mock_ensure, mock_run
    ):
        # Stub parameters instance so we avoid FS/Docker accesses in run()
        params = MagicMock()
        params.output_dir = Path("shared")
        params.use_local_dockerfile = False
        params.use_torch_whl = False

        with patch("cli.lib.core.vllm.VllmBuildParameters", return_value=params):
            runner = vllm.VllmBuildRunner()
            runner.run()

        mock_clone.assert_called_once()
        mock_ensure.assert_called_once_with(Path("shared"))
        mock_gen.assert_called_once_with(params)
        mock_run.assert_called_once()
        # ensure we run in vllm workdir
        _, kwargs = mock_run.call_args
        assert kwargs.get("cwd") == "vllm"


if __name__ == "__main__":
    unittest.main()

@@ -5,10 +5,6 @@ set -ex
SCRIPTPATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

case "${GPU_ARCH_TYPE:-BLANK}" in
-    BLANK)
-        # Legacy behavior for CircleCI
-        bash "${SCRIPTPATH}/build_cuda.sh"
-        ;;
    cuda)
        bash "${SCRIPTPATH}/build_cuda.sh"
        ;;

@@ -134,6 +134,7 @@ if [[ $CUDA_VERSION == 12* ]]; then
            "/usr/local/cuda/lib64/libnvrtc-builtins.so"
            "/usr/local/cuda/lib64/libcufile.so.0"
            "/usr/local/cuda/lib64/libcufile_rdma.so.1"
+            "/usr/local/cuda/lib64/libnvshmem_host.so.3"
            "/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12"
            "/usr/local/cuda/extras/CUPTI/lib64/libnvperf_host.so"
        )

@@ -152,6 +153,7 @@ if [[ $CUDA_VERSION == 12* ]]; then
            "libcudart.so.12"
            "libnvrtc.so.12"
            "libnvrtc-builtins.so"
+            "libnvshmem_host.so.3"
            "libcufile.so.0"
            "libcufile_rdma.so.1"
            "libcupti.so.12"

@@ -92,6 +92,27 @@ if [[ "$BUILD_ENVIRONMENT" == *aarch64* ]]; then
  export ACL_ROOT_DIR=/ComputeLibrary
fi

+if [[ "$BUILD_ENVIRONMENT" == *riscv64* ]]; then
+  if [[ -f /opt/riscv-cross-env/bin/activate ]]; then
+    # shellcheck disable=SC1091
+    source /opt/riscv-cross-env/bin/activate
+  else
+    echo "Activation file not found"
+    exit 1
+  fi
+
+  export CMAKE_CROSSCOMPILING=TRUE
+  export CMAKE_SYSTEM_NAME=Linux
+  export CMAKE_SYSTEM_PROCESSOR=riscv64
+
+  export USE_CUDA=0
+  export USE_MKLDNN=0
+
+  export SLEEF_TARGET_EXEC_USE_QEMU=ON
+  sudo chown -R jenkins /var/lib/jenkins/workspace /opt
+
+fi

if [[ "$BUILD_ENVIRONMENT" == *libtorch* ]]; then
  POSSIBLE_JAVA_HOMES=()
  POSSIBLE_JAVA_HOMES+=(/usr/local)
@@ -209,7 +230,7 @@ fi

# Do not change workspace permissions for ROCm and s390x CI jobs
# as it can leave workspace with bad permissions for cancelled jobs
-if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *s390x* && -d /var/lib/jenkins/workspace ]]; then
+if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *s390x* && "$BUILD_ENVIRONMENT" != *riscv64* && -d /var/lib/jenkins/workspace ]]; then
  # Workaround for dind-rootless userid mapping (https://github.com/pytorch/ci-infra/issues/96)
  WORKSPACE_ORIGINAL_OWNER_ID=$(stat -c '%u' "/var/lib/jenkins/workspace")
  cleanup_workspace() {
@@ -254,8 +275,7 @@ else
    # XLA test build fails when WERROR=1
    # set only when building other architectures
    # or building non-XLA tests.
-    if [[ "$BUILD_ENVIRONMENT" != *rocm*  &&
-          "$BUILD_ENVIRONMENT" != *xla* ]]; then
+    if [[ "$BUILD_ENVIRONMENT" != *rocm*  && "$BUILD_ENVIRONMENT" != *xla* && "$BUILD_ENVIRONMENT" != *riscv64* ]]; then
      # Install numpy-2.0.2 for builds which are backward compatible with 1.X
      python -mpip install numpy==2.0.2

@@ -392,7 +412,7 @@ if [[ "$BUILD_ENVIRONMENT" != *libtorch* && "$BUILD_ENVIRONMENT" != *bazel* ]];
  # don't do this for libtorch as libtorch is C++ only and thus won't have python tests run on its build
  python tools/stats/export_test_times.py
fi
-# don't do this for bazel or s390x as they don't use sccache
-if [[ "$BUILD_ENVIRONMENT" != *s390x* && "$BUILD_ENVIRONMENT" != *-bazel-* ]]; then
+# don't do this for bazel or s390x or riscv64 as they don't use sccache
+if [[ "$BUILD_ENVIRONMENT" != *s390x* && "$BUILD_ENVIRONMENT" != *riscv64* && "$BUILD_ENVIRONMENT" != *-bazel-* ]]; then
  print_sccache_stats
fi

@@ -178,6 +178,10 @@ checkout_install_torchbench() {
  # soxr comes from https://github.com/huggingface/transformers/pull/39429
  pip install transformers==4.54.0 soxr==0.5.0

+  # https://github.com/pytorch/pytorch/issues/160689 to remove torchao because
+  # its current version 0.12.0 doesn't work with transformers 4.54.0
+  pip uninstall -y torchao
+
  echo "Print all dependencies after TorchBench is installed"
  python -mpip freeze
  popd

@@ -37,7 +37,7 @@ IF "%CUDA_PATH_V126%"=="" (
)

IF "%BUILD_VISION%" == "" (
-    set TORCH_CUDA_ARCH_LIST=6.1;7.0;7.5;8.0;8.6;9.0
+    set TORCH_CUDA_ARCH_LIST=5.0;6.0;6.1;7.0;7.5;8.0;8.6;9.0
    set TORCH_NVCC_FLAGS=-Xfatbin -compress-all
) ELSE (
    set NVCC_FLAGS=-D__CUDA_NO_HALF_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_90,code=compute_90

@@ -133,6 +133,25 @@ EXTRA_CONDA_INSTALL_FLAGS=""
CONDA_ENV_CREATE_FLAGS=""
RENAME_WHEEL=true
case $desired_python in
+    3.14t)
+        echo "Using 3.14 deps"
+        SETUPTOOLS_PINNED_VERSION=">=70.1.0"
+        PYYAML_PINNED_VERSION=">=6.0.1"
+        NUMPY_PINNED_VERSION="=2.1.0"
+        CONDA_ENV_CREATE_FLAGS="python-freethreading"
+        EXTRA_CONDA_INSTALL_FLAGS="-c conda-forge/label/python_rc -c conda-forge"
+        desired_python="3.14.0rc1"
+        RENAME_WHEEL=false
+        ;;
+    3.14)
+        echo "Using 3.14 deps"
+        SETUPTOOLS_PINNED_VERSION=">=70.1.0"
+        PYYAML_PINNED_VERSION=">=6.0.1"
+        NUMPY_PINNED_VERSION="=2.1.0"
+        EXTRA_CONDA_INSTALL_FLAGS="-c conda-forge/label/python_rc -c conda-forge"
+        desired_python="3.14.0rc1"
+        RENAME_WHEEL=false
+        ;;
    3.13t)
        echo "Using 3.13 deps"
        SETUPTOOLS_PINNED_VERSION=">=70.1.0"

.github/actionlint.yaml (vendored, 1 change)
@@ -54,6 +54,7 @@ self-hosted-runner:
    - linux.rocm.gpu.2
    - linux.rocm.gpu.4
    # gfx942 runners
    - linux.rocm.gpu.gfx942.1
    - linux.rocm.gpu.gfx942.2
    - linux.rocm.gpu.gfx942.4
    - rocm-docker

.github/actions/build-external-packages/action.yml (vendored, new file, 80 lines)
@@ -0,0 +1,80 @@
# .github/actions/build-external-packages/action.yml
name: Build External packages

description: build external packages for PyTorch

inputs:
  cuda-arch-list:
    description: TORCH_CUDA_ARCH_LIST (e.g., "8.0;8.9;9.0")
    type: string
    required: true
    default: ""
  docker-image:
    description: Base image to use
    type: string
    required: true
  build-targets:
    description: Build targets
    type: string
    required: true
  torch-wheel-dir:
    description: Directory containing the built torch wheel
    type: string
    required: false
    default: dist
  output-dir:
    description: Directory to store build artifact
    default: external
    type: string
    required: false

outputs:
  build_time:
    description: "Total build time in seconds"
    value: ${{ steps.build-external.outputs.build_time }}
  output_dir:
    description: "Directory where build artifact is stored"
    value: ${{ steps.build-external.outputs.output_dir }}

runs:
  using: composite
  steps:
    - name: Build external packages in sequence
      id: build-external
      env:
        SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2
        SCCACHE_REGION: us-east-1
        TORCH_CUDA_ARCH_LIST: ${{ inputs.cuda-arch-list }}
        BASE_IMAGE: ${{ inputs.docker-image }}
        BUILD_TARGETS: ${{ inputs.build-targets }}
        PARENT_OUTPUT_DIR: ${{ inputs.output-dir }}
      shell: bash
      run: |
        set -euo pipefail
        python3 --version
        docker images
        START_TIME=$(date +%s)
        (
          cd .ci/lumen_cli
          python3 -m pip install -e .
        )
        MAX_JOBS="$(nproc --ignore=6)"
        export MAX_JOBS

        # Split the comma-separated list and build each target
        IFS=',' read -ra TARGETS <<< "$BUILD_TARGETS"
        for target in "${TARGETS[@]}"; do
          OUTPUT_DIR="$PARENT_OUTPUT_DIR/$target"
          export OUTPUT_DIR
          echo "Building external package: $target in directory $OUTPUT_DIR"
          python3 -m cli.run build external "$target"

        done

        END_TIME=$(date +%s)
        {
          echo "build_time=$((END_TIME - START_TIME))"
          if [ -d "$PARENT_OUTPUT_DIR" ]; then
            echo "output_dir=$PARENT_OUTPUT_DIR"
          fi
        } >> "$GITHUB_OUTPUT"

.github/actions/setup-rocm/action.yml (vendored, 5 changes)
@@ -59,11 +59,6 @@ runs:
            echo "$msg"
 | 
			
		||||
            exit 1
 | 
			
		||||
        fi
 | 
			
		||||
        if [[ $ngpu -eq 1 ]]; then
 | 
			
		||||
            echo "Error: only 1 GPU detected, at least 2 GPUs are needed for distributed jobs"
 | 
			
		||||
            echo "$msg"
 | 
			
		||||
            exit 1
 | 
			
		||||
        fi
 | 
			
		||||
 | 
			
		||||
    - name: Runner diskspace health check
 | 
			
		||||
      uses: pytorch/pytorch/.github/actions/diskspace-cleanup@main
 | 
			
		||||

.github/ci_commit_pins/audio.txt (vendored, 2 changes)
@@ -1 +1 @@
-bdb88e1d66f272cad72156c90ac8428ca61a601c
+f92ceca80df7a36194468665d62b0f791b1826c5

.github/ci_commit_pins/vllm.txt (vendored, 2 changes)
@@ -1 +1 @@
-458e74eb907f96069e6d8a4f3c9f457001fef2ea
+0ca2393b47e72c4424a49aa3b32c7c5d0e378a72

.github/ci_configs/vllm/Dockerfile.tmp_vllm (vendored, new file, 414 lines)
@@ -0,0 +1,414 @@
# TODO(elainwy): remove this file after the torch nightly dockerfile is in sync in vllm repo
 | 
			
		||||
# The vLLM Dockerfile is used to construct vLLM image against torch nightly and torch main that can be directly used for testing
 | 
			
		||||
 | 
			
		||||
ARG CUDA_VERSION=12.8.1
 | 
			
		||||
ARG PYTHON_VERSION=3.12
 | 
			
		||||
 | 
			
		||||
# BUILD_BASE_IMAGE: used to setup python build xformers, and vllm wheels, It can be replaced with a different base image from local machine,
 | 
			
		||||
# by default, it uses the torch-nightly-base stage from this docker image
 | 
			
		||||
ARG BUILD_BASE_IMAGE=torch-nightly-base
 | 
			
		||||
 | 
			
		||||
# FINAL_BASE_IMAGE: used to set up vllm-instaled environment and build flashinfer,
 | 
			
		||||
# by default, it uses devel-ubuntu22.04 official image.
 | 
			
		||||
ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04

#################### TORCH NIGHTLY BASE IMAGE ####################
# A base image for building vLLM with devel ubuntu 22.04, mainly used to build vllm in the vllm buildkite CI
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS torch-nightly-base

ARG CUDA_VERSION=12.8.1
ARG PYTHON_VERSION=3.12
ARG TARGETPLATFORM
ENV DEBIAN_FRONTEND=noninteractive

RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
    echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment

# Install Python and other dependencies if not already present
RUN if ! command -v python3 >/dev/null || ! python3 --version | grep -q "${PYTHON_VERSION}"; then \
      echo "Installing Python ${PYTHON_VERSION}..." && \
      echo 'tzdata tzdata/Areas select America' | debconf-set-selections && \
      echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections && \
      apt-get update -y && \
      apt-get install -y ccache software-properties-common git curl sudo && \
      for i in 1 2 3; do \
        add-apt-repository -y ppa:deadsnakes/ppa && break || \
        { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
      done && \
      apt-get update -y && \
      apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv && \
      update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 && \
      update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} && \
      ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config && \
      curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION}; \
   else \
      echo "Python ${PYTHON_VERSION} already present, skipping setup."; \
   fi \
   && python3 --version && python3 -m pip --version

# Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519,
# which was causing spam when compiling the CUTLASS kernels
RUN current_gcc_version=$(gcc -dumpversion | cut -f1 -d.) && \
    if [ "$current_gcc_version" -lt 10 ]; then \
      echo "GCC version is $current_gcc_version, installing gcc-10..."; \
      apt-get update && \
      apt-get install -y gcc-10 g++-10 && \
      update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 100 && \
      update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-10 100; \
    else \
      echo "GCC version is $current_gcc_version, no need to install gcc-10."; \
    fi && \
    gcc --version && g++ --version

# install uv for faster pip installs
RUN --mount=type=cache,target=/root/.cache/uv \
    python3 -m pip install uv==0.8.4

ENV UV_HTTP_TIMEOUT=500
ENV UV_INDEX_STRATEGY="unsafe-best-match"

#################### TORCH NIGHTLY BASE IMAGE ####################
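
# A single stage can also be built on its own with --target; a minimal sketch
# for materializing just this base image (the tag and path are illustrative):
#   docker build --target torch-nightly-base \
#     -f Dockerfile.tmp_vllm -t torch-nightly-base:local .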

#################### BASE BUILD IMAGE ####################
# A base image for building vLLM with torch nightly or torch wheels;
# prepare the basic build environment
FROM ${BUILD_BASE_IMAGE} AS base
USER root

# Re-declare ARGs after FROM so they are visible in this stage
# (without this, $CUDA_VERSION below would expand to an empty string)
ARG CUDA_VERSION=12.8.1
ARG PYTHON_VERSION=3.12

# Workaround for https://github.com/openai/triton/issues/2507 and
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
# this won't be needed for future versions of this docker image
# or future versions of triton.
RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/

# Install uv for faster pip installs if not already present
RUN --mount=type=cache,target=/root/.cache/uv \
    if ! python3 -m uv --version >/dev/null 2>&1; then \
        python3 -m pip install uv==0.8.4; \
    fi
ENV UV_HTTP_TIMEOUT=500
ENV UV_INDEX_STRATEGY="unsafe-best-match"

WORKDIR /workspace

# install build and runtime dependencies
COPY requirements/common.txt requirements/common.txt
COPY use_existing_torch.py use_existing_torch.py
COPY pyproject.toml pyproject.toml

# install build and runtime dependencies without a stable torch version
RUN python3 use_existing_torch.py

# Default mount path placeholder; this just avoids the mount error.
# Change to a different vllm folder if this does not exist anymore.
ARG TORCH_WHEELS_PATH="./requirements"
ARG PINNED_TORCH_VERSION

# Install torch, torchaudio and torchvision based on the input:
# if TORCH_WHEELS_PATH is the default "./requirements", it will pull the nightly versions using pip;
# otherwise, it will use the wheels from TORCH_WHEELS_PATH on the host machine
RUN --mount=type=bind,source=${TORCH_WHEELS_PATH},target=/dist \
    --mount=type=cache,target=/root/.cache/uv \
    if [ -n "$TORCH_WHEELS_PATH" ] && [ "$TORCH_WHEELS_PATH" != "./requirements" ] && [ -d "/dist" ] && ls /dist/torch*.whl >/dev/null 2>&1; then \
        torch_whl=$(find /dist -maxdepth 1 -name 'torch-*.whl' -print -quit); \
        vision_whl=$(find /dist/vision -name 'torchvision*.whl' | head -n1 | xargs); \
        audio_whl=$(find /dist/audio -name 'torchaudio*.whl' | head -n1 | xargs); \
        uv pip install --system "${torch_whl}[opt-einsum]"; \
        uv pip install --system "${vision_whl}"; \
        uv pip install --system "${audio_whl}"; \
    elif [ -n "$PINNED_TORCH_VERSION" ]; then \
        echo "[INFO] Installing pinned torch nightly version: $PINNED_TORCH_VERSION"; \
        uv pip install --system "$PINNED_TORCH_VERSION" --index-url https://download.pytorch.org/whl/nightly/cu128; \
    else \
        echo "[INFO] Installing the latest torch nightly"; \
        uv pip install --system torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128; \
    fi
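
# The first branch above implies a host-side layout roughly like this when
# passing prebuilt wheels (the ./dist path and wheel names are illustrative):
#   ./dist/torch-*.whl
#   ./dist/vision/torchvision-*.whl
#   ./dist/audio/torchaudio-*.whl
# together with a build invocation such as:
#   docker build --build-arg TORCH_WHEELS_PATH=./dist -f Dockerfile.tmp_vllm .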

# Install numba 0.61.2 for the cuda environment
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system numba==0.61.2

# Install common dependencies from vllm common.txt
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system -r requirements/common.txt

# Must be set before installing xformers, so the correct version of xformers is installed.
ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
ARG max_jobs=16
ENV MAX_JOBS=${max_jobs}

# Build xformers with cuda and torch nightly/wheel,
# following the official xformers guidance: https://github.com/facebookresearch/xformers#build
ARG XFORMERS_COMMIT=f2de641ef670510cadab099ce6954031f52f191c
ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \
    --mount=type=cache,target=/root/.cache/uv \
    echo 'git clone xformers...' \
    && git clone https://github.com/facebookresearch/xformers.git --recursive \
    && cd xformers \
    && git checkout ${XFORMERS_COMMIT} \
    && git submodule update --init --recursive \
    && echo 'finished git clone xformers...' \
    && rm -rf build \
    && python3 setup.py bdist_wheel --dist-dir=../xformers-dist --verbose \
    && cd .. \
    && rm -rf xformers
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system xformers-dist/*.whl --verbose

# The build can take a long time, and the torch nightly version fetched from the URL can differ in the next docker stage.
# Track the nightly torch version used in the build so that, when setting up the runtime environment, we can make sure the version is the same.
RUN uv pip freeze | grep -i '^torch\|^torchvision\|^torchaudio' > torch_build_versions.txt
RUN cat torch_build_versions.txt

RUN pip freeze | grep -E 'torch|xformers|torchvision|torchaudio'

#################### BASE BUILD IMAGE ####################
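
# torch_build_versions.txt ends up holding pinned nightly specifiers, roughly
# of this shape (the versions shown are illustrative, not real pins):
#   torch==2.9.0.dev20250801+cu128
#   torchvision==0.24.0.dev20250801+cu128
#   torchaudio==2.8.0.dev20250801+cu128
# Later stages reinstall exactly these with:
#   uv pip install --system $(cat torch_build_versions.txt | xargs) \
#     --index-url https://download.pytorch.org/whl/nightly/cu128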

#################### WHEEL BUILD IMAGE ####################
# Image used to build the vllm wheel
FROM base AS build
ARG TARGETPLATFORM

ENV UV_HTTP_TIMEOUT=500
ENV UV_INDEX_STRATEGY="unsafe-best-match"

COPY . .

RUN python3 use_existing_torch.py

RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system -r requirements/build.txt

ARG GIT_REPO_CHECK=0
RUN --mount=type=bind,source=.git,target=.git \
    if [ "$GIT_REPO_CHECK" != "0" ]; then bash tools/check_repo.sh ; fi

# Max jobs used by Ninja to build extensions
ARG max_jobs=16
ENV MAX_JOBS=${max_jobs}
ARG nvcc_threads=2
ENV NVCC_THREADS=$nvcc_threads
ARG torch_cuda_arch_list='8.0;8.6;8.9;9.0'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}

ARG USE_SCCACHE
ARG SCCACHE_BUCKET_NAME=vllm-build-sccache
ARG SCCACHE_REGION_NAME=us-west-2
ARG SCCACHE_S3_NO_CREDENTIALS=0

# if USE_SCCACHE is set, use sccache to speed up compilation
RUN --mount=type=cache,target=/root/.cache/uv \
    --mount=type=bind,source=.git,target=.git \
    if [ "$USE_SCCACHE" = "1" ]; then \
        echo "Installing sccache..." \
        && curl -L -o sccache.tar.gz https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz \
        && tar -xzf sccache.tar.gz \
        && sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \
        && rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \
        && export SCCACHE_BUCKET=${SCCACHE_BUCKET_NAME} \
        && export SCCACHE_REGION=${SCCACHE_REGION_NAME} \
        && export SCCACHE_S3_NO_CREDENTIALS=${SCCACHE_S3_NO_CREDENTIALS} \
        && export SCCACHE_IDLE_TIMEOUT=0 \
        && export CMAKE_BUILD_TYPE=Release \
        && sccache --show-stats \
        && python3 setup.py bdist_wheel --dist-dir=vllm-dist --py-limited-api=cp38 \
        && sccache --show-stats; \
    fi

ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \
    --mount=type=cache,target=/root/.cache/uv \
    --mount=type=bind,source=.git,target=.git \
    if [ "$USE_SCCACHE" != "1" ]; then \
        # Clean any existing CMake artifacts
        rm -rf .deps && \
        mkdir -p .deps && \
        python3 setup.py bdist_wheel --dist-dir=vllm-dist --py-limited-api=cp38; \
    fi

RUN echo "[DEBUG] Listing current directory:" && \
    ls -al && \
    echo "[DEBUG] Showing torch_build_versions.txt content:" && \
    cat torch_build_versions.txt

#################### WHEEL BUILD IMAGE ####################
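
# To route compilation through sccache, the stage above is driven with the
# corresponding build args; a sketch (the bucket and region shown are just the
# defaults declared above):
#   docker build --target build \
#     --build-arg USE_SCCACHE=1 \
#     --build-arg SCCACHE_BUCKET_NAME=vllm-build-sccache \
#     --build-arg SCCACHE_REGION_NAME=us-west-2 \
#     -f Dockerfile.tmp_vllm .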

################### VLLM INSTALLED IMAGE ####################
# Set up a clean environment for vLLM tests and the api server, using ubuntu22.04 with AOT flashinfer
FROM ${FINAL_BASE_IMAGE} AS vllm-base
USER root

# Re-declare ARGs after FROM so they are visible in this stage
ARG CUDA_VERSION=12.8.1
ARG PYTHON_VERSION=3.12

# environment preparation starts here
WORKDIR /workspace

RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
    echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment

# Install Python and other dependencies if not already present
RUN if ! command -v python3 >/dev/null || ! python3 --version | grep -q "${PYTHON_VERSION}"; then \
      echo "Installing Python ${PYTHON_VERSION}..." && \
      echo 'tzdata tzdata/Areas select America' | debconf-set-selections && \
      echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections && \
      apt-get update -y && \
      apt-get install -y ccache software-properties-common git curl sudo && \
      for i in 1 2 3; do \
        add-apt-repository -y ppa:deadsnakes/ppa && break || \
        { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
      done && \
      apt-get update -y && \
      apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv && \
      update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 && \
      update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} && \
      ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config && \
      curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION}; \
   else \
      echo "Python ${PYTHON_VERSION} already present, skipping setup."; \
   fi \
   && python3 --version && python3 -m pip --version

# Get the torch versions and wheels used in previous stages, for consistency
COPY --from=base /workspace/torch_build_versions.txt ./torch_build_versions.txt
COPY --from=base /workspace/xformers-dist /wheels/xformers
COPY --from=build /workspace/vllm-dist /wheels/vllm
RUN echo "[DEBUG] Listing current directory before torch install step:" && \
    ls -al && \
    echo "[DEBUG] Showing torch_build_versions.txt content:" && \
    cat torch_build_versions.txt

# Workaround for https://github.com/openai/triton/issues/2507 and
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
# this won't be needed for future versions of this docker image
# or future versions of triton.
RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/

# Install uv for faster pip installs if not already present
RUN --mount=type=cache,target=/root/.cache/uv \
    if ! python3 -m uv --version > /dev/null 2>&1; then \
        python3 -m pip install uv==0.8.4; \
    fi
ENV UV_HTTP_TIMEOUT=500
ENV UV_INDEX_STRATEGY="unsafe-best-match"

# Default mount path placeholder; this just avoids the mount error
ARG TORCH_WHEELS_PATH="./requirements"
# Install torch, torchaudio and torchvision:
# if TORCH_WHEELS_PATH is the default "./requirements", it will pull the nightly versions pinned in torch_build_versions.txt using pip;
# otherwise, it will use the wheels from TORCH_WHEELS_PATH on the host machine
RUN --mount=type=bind,source=${TORCH_WHEELS_PATH},target=/dist \
    --mount=type=cache,target=/root/.cache/uv \
    if [ -n "$TORCH_WHEELS_PATH" ] && [ "$TORCH_WHEELS_PATH" != "./requirements" ] && [ -d "/dist" ] && ls /dist/torch*.whl >/dev/null 2>&1; then \
        torch_whl=$(find /dist -maxdepth 1 -name 'torch-*.whl' -print -quit); \
        vision_whl=$(find /dist/vision -name 'torchvision*.whl' | head -n1 | xargs); \
        audio_whl=$(find /dist/audio -name 'torchaudio*.whl' | head -n1 | xargs); \
        echo "Found: '${torch_whl}' '${audio_whl}' '${vision_whl}'"; \
        uv pip install --system "${torch_whl}[opt-einsum]"; \
        uv pip install --system "${vision_whl}"; \
        uv pip install --system "${audio_whl}"; \
    else \
        echo "[INFO] Installing torch versions from torch_build_versions.txt"; \
        uv pip install --system $(cat torch_build_versions.txt | xargs) --index-url https://download.pytorch.org/whl/nightly/cu128; \
    fi

# Install the vllm wheel from the previous stage
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system /wheels/vllm/*.whl --verbose

# Install the xformers wheel from the previous stage
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system /wheels/xformers/*.whl --verbose

# Build flashinfer from source.
ARG torch_cuda_arch_list='8.0;8.9;9.0a'
# install packages needed to build flashinfer;
# see issue: https://github.com/flashinfer-ai/flashinfer/issues/738
RUN pip install build==1.3.0
RUN pip freeze | grep -E 'setuptools|packaging|build'

ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
# Build flashinfer for torch nightly from source (takes around 10 mins)
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
# Keep this in sync with https://github.com/vllm-project/vllm/blob/main/requirements/cuda.txt
ARG FLASHINFER_GIT_REF="v0.2.9rc2"
RUN --mount=type=cache,target=/root/.cache/uv \
    git clone --depth 1 --recursive --shallow-submodules \
        --branch ${FLASHINFER_GIT_REF} \
        ${FLASHINFER_GIT_REPO} flashinfer \
    && echo "Building FlashInfer with AOT for arches: ${torch_cuda_arch_list}" \
    && cd flashinfer \
    && python3 -m flashinfer.aot \
    && python3 -m build --no-isolation --wheel --outdir ../wheels/flashinfer \
    && cd .. \
    && rm -rf flashinfer

# install flashinfer python
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system wheels/flashinfer/*.whl --verbose

# Logging to confirm the torch versions
RUN pip freeze | grep -E 'torch|xformers|vllm|flashinfer'
################### VLLM INSTALLED IMAGE ####################
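
# A quick consistency check on the finished image (the tag is illustrative) is
# to compare the runtime torch against the pinned torch_build_versions.txt:
#   docker run --rm vllm-torch-nightly:local \
#     python3 -c "import torch, vllm; print(torch.__version__, vllm.__version__)"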

#################### UNITTEST IMAGE #############################
FROM vllm-base AS test

ENV UV_HTTP_TIMEOUT=500
ENV UV_INDEX_STRATEGY="unsafe-best-match"

COPY tests/ tests/
COPY examples examples
COPY benchmarks benchmarks
COPY ./vllm/collect_env.py .
COPY requirements/common.txt requirements/common.txt
COPY use_existing_torch.py use_existing_torch.py
COPY pyproject.toml pyproject.toml
# Install build and runtime dependencies without a stable torch version
COPY requirements/nightly_torch_test.txt requirements/nightly_torch_test.txt

RUN python3 use_existing_torch.py

# install packages
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system -r requirements/common.txt
# enable fast downloads from hf (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system hf_transfer
ENV HF_HUB_ENABLE_HF_TRANSFER=1

# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system -e tests/vllm_test_utils

RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system -r requirements/nightly_torch_test.txt

# Workaround for #17068
# pinned commit for v2.2.4
RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install --system --no-build-isolation "git+https://github.com/state-spaces/mamba@95d8aba8a8c75aedcaa6143713b11e745e7cd0d9#egg=mamba-ssm"

# Logging to confirm the torch versions
RUN pip freeze | grep -E 'torch|xformers|vllm|flashinfer'

# Logging to confirm all the packages are installed
RUN pip freeze

#################### UNITTEST IMAGE #############################

#################### EXPORT STAGE ####################
FROM scratch AS export-wheels

# Just copy the wheels we prepared in previous stages
COPY --from=base /workspace/xformers-dist /wheels/xformers
COPY --from=build /workspace/vllm-dist /wheels/vllm
COPY --from=vllm-base /workspace/wheels/flashinfer /wheels/flashinfer-python
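
# Because this stage starts FROM scratch, the wheels can be extracted without
# running a container; a sketch, assuming BuildKit (the dest path is
# illustrative):
#   docker buildx build --target export-wheels \
#     -f Dockerfile.tmp_vllm --output type=local,dest=./wheels-out .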

1	.github/pytorch-probot.yml (vendored)
@@ -22,6 +22,7 @@ ciflow_push_tags:
- ciflow/rocm
- ciflow/rocm-mi300
- ciflow/s390
+- ciflow/riscv64
- ciflow/slow
- ciflow/trunk
- ciflow/unstable
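
As a sketch of how these tags are consumed (the PR number here is illustrative), pytorchbot reacts to a ciflow/<name> label by pushing a tag of the form ciflow/<name>/<pr-number>, which triggers the matching workflows on tag push, e.g.:

  git push origin HEAD:refs/tags/ciflow/riscv64/12345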

5	.github/requirements/conda-env-macOS-ARM64 (vendored, deleted)
@@ -1,5 +0,0 @@
-# Not pinning certifi so that we can always get the latest certificates
-certifi
-pip=23.2.1
-pkg-config=0.29.2
-wheel=0.37.1

31	.github/scripts/amd/package_triton_wheel.sh (vendored)
@@ -1,3 +1,4 @@
#!/bin/bash
set -ex

# If ROCM_HOME isn't available, use ROCM_PATH if set, else /opt/rocm
@@ -50,29 +51,15 @@ do
    cp $lib $TRITON_ROCM_DIR/lib/
done

# Required ROCm libraries
if [[ "${MAJOR_VERSION}" == "6" ]]; then
    libamdhip="libamdhip64.so.6"
else
    libamdhip="libamdhip64.so.5"
fi

# Required ROCm libraries - ROCm 6.0
ROCM_SO=(
    "${libamdhip}"
    "libhsa-runtime64.so.1"
    "libdrm.so.2"
    "libdrm_amdgpu.so.1"
    "libamdhip64.so"
    "libhsa-runtime64.so"
    "libdrm.so"
    "libdrm_amdgpu.so"
    "libamd_comgr.so"
    "librocprofiler-register.so"
)
if [[ $ROCM_INT -ge 60400 ]]; then
    ROCM_SO+=("libamd_comgr.so.3")
else
    ROCM_SO+=("libamd_comgr.so.2")
fi

if [[ $ROCM_INT -ge 60100 ]]; then
    ROCM_SO+=("librocprofiler-register.so.0")
fi

for lib in "${ROCM_SO[@]}"
do
@@ -94,10 +81,6 @@ do
    fi

    cp $file_path $TRITON_ROCM_DIR/lib
    # When running locally, and not building a wheel, we need to satisfy shared object requests that don't look for versions
    LINKNAME=$(echo $lib | sed -e 's/\.so.*/.so/g')
    ln -sf $lib $TRITON_ROCM_DIR/lib/$LINKNAME

done

# Copy Include Files

16	.github/scripts/amd/patch_triton_wheel.sh (vendored)
@@ -19,15 +19,13 @@ replace_needed_sofiles() {
    find $1 -name '*.so*' -o -name 'ld.lld' | while read sofile; do
        origname=$2
        patchedname=$3
        if [[ "$origname" != "$patchedname" ]]; then
            set +e
            origname=$($PATCHELF_BIN --print-needed $sofile | grep "$origname.*")
            ERRCODE=$?
            set -e
            if [ "$ERRCODE" -eq "0" ]; then
                echo "patching $sofile entry $origname to $patchedname"
                $PATCHELF_BIN --replace-needed $origname $patchedname $sofile
            fi
        set +e
        origname=$($PATCHELF_BIN --print-needed $sofile | grep "$origname.*")
        ERRCODE=$?
        set -e
        if [ "$ERRCODE" -eq "0" ]; then
            echo "patching $sofile entry $origname to $patchedname"
            $PATCHELF_BIN --replace-needed $origname $patchedname $sofile
        fi
    done
}

@@ -315,7 +315,7 @@ def generate_wheels_matrix(
            if gpu_arch_type == "cpu-s390x" and python_version == "3.13t":
                continue
            # TODO: Enable python 3.14 on non-Linux OSes
            if os != "linux" and (
            if os not in ["linux", "macos-arm64"] and (
                python_version == "3.14" or python_version == "3.14t"
            ):
                continue
BIN	.github/scripts/gql_mocks.json.gz (vendored)
Binary file not shown.

3	.github/scripts/test_trymerge.py (vendored)
@@ -70,6 +70,9 @@ def mock_query(
    if key in mocked_queries:
        return mocked_queries[key]

    # TODO: Remove me once https://github.com/pytorch/pytorch/issues/160489 is resolved
    raise ValueError(f"Key {key} could not be found in gql_mocks")

    try:
        rc = fallback_function(*args)
    except HTTPError as err:

4	.github/scripts/trymerge.py (vendored)
@@ -108,10 +108,6 @@ GH_CHECKSUITES_FRAGMENT = """
fragment PRCheckSuites on CheckSuiteConnection {
  edges {
    node {
      app {
        name
        databaseId
      }
      workflowRun {
        workflow {
          name

2	.github/scripts/windows/build_triton.bat (vendored)
@@ -10,7 +10,7 @@ if "%PY_VERS%" == "3.13t" (
    call conda create -n %PYTHON_PREFIX% -y -c=conda-forge python=%PY_VERS%
)
:: Fix cmake version for issue https://github.com/pytorch/pytorch/issues/150480
call conda run -n %PYTHON_PREFIX% pip install wheel pybind11 certifi cython cmake==3.31.6 setuptools==72.1.0 ninja
call conda run -n %PYTHON_PREFIX% pip install wheel pybind11 certifi cython cmake==3.31.6 setuptools==72.1.0 ninja==1.11.1.4

dir "%VC_INSTALL_PATH%"

@@ -110,12 +110,33 @@ jobs:
          # Create new "clean" conda environment for testing

          SMOKE_TEST_PARAMS=""
          if [[ $DESIRED_PYTHON == "3.13t" ]]; then
            conda create -yn "test_conda_env" python="3.13" python-freethreading -c conda-forge
            SMOKE_TEST_PARAMS="--torch-compile-check disabled"
          else
            conda create -yn "test_conda_env" python="$DESIRED_PYTHON"
          fi

          EXTRA_CONDA_INSTALL_FLAGS=""
          CONDA_ENV_CREATE_FLAGS=""
          # shellcheck disable=SC2153
          case $DESIRED_PYTHON in
            3.14t)
              CONDA_ENV_CREATE_FLAGS="python-freethreading"
              EXTRA_CONDA_INSTALL_FLAGS="-c conda-forge/label/python_rc -c conda-forge"
              desired_python="3.14.0rc1"
              ;;
            3.14)
              EXTRA_CONDA_INSTALL_FLAGS="-c conda-forge/label/python_rc -c conda-forge"
              desired_python="3.14.0rc1"
              ;;
            3.13t)
              CONDA_ENV_CREATE_FLAGS="python-freethreading"
              EXTRA_CONDA_INSTALL_FLAGS="-c conda-forge"
              desired_python="3.13"
              ;;
            *)
              # shellcheck disable=SC2153
              desired_python=${DESIRED_PYTHON}
              ;;
          esac

          # shellcheck disable=SC2086
          conda create -yn "test_conda_env" python="$desired_python" ${CONDA_ENV_CREATE_FLAGS} ${EXTRA_CONDA_INSTALL_FLAGS}
          conda activate test_conda_env
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v
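
The case block above maps the requested interpreter string to a concrete conda spec. An abridged standalone sketch of the same mapping (the 3.14 non-freethreaded arm behaves like 3.14t minus python-freethreading):

  DESIRED_PYTHON=3.14t
  case $DESIRED_PYTHON in
    3.14t) conda create -yn test_conda_env python=3.14.0rc1 python-freethreading \
             -c conda-forge/label/python_rc -c conda-forge ;;
    3.13t) conda create -yn test_conda_env python=3.13 python-freethreading -c conda-forge ;;
    *)     conda create -yn test_conda_env python=$DESIRED_PYTHON ;;
  esac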

26	.github/workflows/_linux-build.yml (vendored)
@@ -287,10 +287,36 @@ jobs:
          # comes from https://github.com/pytorch/test-infra/pull/6058
          TOTAL_MEMORY_WITH_SWAP=$(("${TOTAL_AVAILABLE_MEMORY_IN_GB%.*}" + 3))

          if [[ ${BUILD_ENVIRONMENT} == *"riscv64"* ]]; then
            # EC2 specific setup for RISC-V emulation
            # Ensure binfmt_misc is available
            echo "Mounting binfmt_misc filesystem"
            sudo mount binfmt_misc -t binfmt_misc /proc/sys/fs/binfmt_misc 2>/dev/null || true

            echo "QEMU registration: multiarch/qemu-user-static"
            docker run --rm --privileged multiarch/qemu-user-static --reset -p yes || true

            # Final verification
            echo "Checking binfmt_misc status:"
            ls -la /proc/sys/fs/binfmt_misc/ 2>/dev/null || echo "Cannot access binfmt_misc directory"

            if [ -f /proc/sys/fs/binfmt_misc/qemu-riscv64 ]; then
              echo "qemu-riscv64 registration successful"
            else
              echo "qemu-riscv64 registration failed - proceeding without emulation"
              echo "This may cause RISC-V builds to fail"
            fi

            RISCV_DOCKER_ARGS="--privileged"
          else
            RISCV_DOCKER_ARGS=
          fi

          # detached container should get cleaned up by teardown_ec2_linux
          # Used for JENKINS_USER and DOCKER_SHELL_CMD, which can be empty
          # shellcheck disable=SC2086
          container_name=$(docker run \
            ${RISCV_DOCKER_ARGS} \
            -e BUILD_ENVIRONMENT \
            -e MAX_JOBS="$(nproc --ignore=2)" \
            -e PR_NUMBER \
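
Once qemu-riscv64 is registered through binfmt_misc, cross-architecture containers run transparently on the x86_64 host. A quick smoke check (the image name is illustrative):

  docker run --rm --platform linux/riscv64 riscv64/ubuntu:latest uname -m   # expect: riscv64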

10	.github/workflows/_rocm-test.yml (vendored)
@@ -88,6 +88,16 @@ jobs:
      - name: Setup ROCm
        uses: ./.github/actions/setup-rocm

      - name: Runner check GPU count (distributed jobs)
        if: ${{ contains(matrix.config, 'distributed') }}
        shell: bash
        run: |
          ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx')
          if [[ $ngpu -lt 4 ]]; then
            echo "Error: only $ngpu GPU(s) detected, at least 4 GPUs are needed for distributed jobs"
            exit 1
          fi

      - name: configure aws credentials
        id: aws_creds
        uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0

3	.github/workflows/check-labels.yml (vendored)
@@ -34,8 +34,7 @@ jobs:
      contents: read
      pull-requests: write
    name: Check labels
    # Disabling the job until https://github.com/pytorch/pytorch/issues/159825 is resolved
    if: github.repository_owner == 'pytorch' && false
    if: github.repository_owner == 'pytorch'
    runs-on: linux.24_04.4x
    steps:
      - name: Checkout PyTorch

@@ -7,8 +7,7 @@ on:

jobs:
  ghstack-mergeability-check:
    # Disabling the job until https://github.com/pytorch/pytorch/issues/159825 is resolved
    if: github.repository_owner == 'pytorch' && false
    if: github.repository_owner == 'pytorch'
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

3	.github/workflows/docker-builds.yml (vendored)
@@ -74,7 +74,8 @@ jobs:
          pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-linter,
          # Executorch pin needs update
          # pytorch-linux-jammy-py3-clang12-executorch,
          pytorch-linux-jammy-py3.12-triton-cpu
          pytorch-linux-jammy-py3.12-triton-cpu,
          pytorch-linux-noble-riscv64-py3.12-gcc14
        ]
        include:
          - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11

488	.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml (generated, vendored)
@@ -115,12 +115,33 @@ jobs:
          # Create new "clean" conda environment for testing

          SMOKE_TEST_PARAMS=""
          if [[ $DESIRED_PYTHON == "3.13t" ]]; then
            conda create -yn "test_conda_env" python="3.13" python-freethreading -c conda-forge
            SMOKE_TEST_PARAMS="--torch-compile-check disabled"
          else
            conda create -yn "test_conda_env" python="$DESIRED_PYTHON"
          fi

          EXTRA_CONDA_INSTALL_FLAGS=""
          CONDA_ENV_CREATE_FLAGS=""
          # shellcheck disable=SC2153
          case $DESIRED_PYTHON in
            3.14t)
              CONDA_ENV_CREATE_FLAGS="python-freethreading"
              EXTRA_CONDA_INSTALL_FLAGS="-c conda-forge/label/python_rc -c conda-forge"
              desired_python="3.14.0rc1"
              ;;
            3.14)
              EXTRA_CONDA_INSTALL_FLAGS="-c conda-forge/label/python_rc -c conda-forge"
              desired_python="3.14.0rc1"
              ;;
            3.13t)
              CONDA_ENV_CREATE_FLAGS="python-freethreading"
              EXTRA_CONDA_INSTALL_FLAGS="-c conda-forge"
              desired_python="3.13"
              ;;
            *)
              # shellcheck disable=SC2153
              desired_python=${DESIRED_PYTHON}
              ;;
          esac

          # shellcheck disable=SC2086
          conda create -yn "test_conda_env" python="$desired_python" ${CONDA_ENV_CREATE_FLAGS} ${EXTRA_CONDA_INSTALL_FLAGS}
          conda activate test_conda_env
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v

(The following five hunks each make the same conda test-environment change as the hunk above, repeated verbatim for the remaining wheel jobs:)
@@ -239,12 +260,33 @@ jobs:
@@ -363,12 +405,33 @@ jobs:
@@ -487,12 +550,33 @@ jobs:
@@ -611,12 +695,33 @@ jobs:
@@ -735,12 +840,33 @@ jobs:

@@ -774,3 +900,293 @@
 | 
			
		||||
    secrets:
 | 
			
		||||
      github-token: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
    uses: ./.github/workflows/_binary-upload.yml
 | 
			
		||||
  wheel-py3_14-cpu-build:
 | 
			
		||||
    if: ${{ github.repository_owner == 'pytorch' }}
    runs-on: macos-14-xlarge
    timeout-minutes: 240
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
      PACKAGE_TYPE: wheel
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: cpu
      GPU_ARCH_TYPE: cpu
      SKIP_ALL_TESTS: 1
      DESIRED_PYTHON: "3.14"
    steps:
      # NOTE: These environment variables are put here so that they can be applied on every job equally
      #       They are also here because setting them at a workflow level doesn't give us access to the
      #       runner.temp variable, which we need.
      - name: Populate binary env
        shell: bash
        run: |
          # shellcheck disable=SC2129
          echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
          # shellcheck disable=SC2129
          echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
          # shellcheck disable=SC2129
          echo "MAC_PACKAGE_WORK_DIR=${RUNNER_TEMP}" >> "${GITHUB_ENV}"
      - name: Install conda and dependencies
        run: |
          # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on
          curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" "https://repo.anaconda.com/miniconda/Miniconda3-py310_23.5.2-0-MacOSX-$(uname -m).sh"
          chmod +x "${RUNNER_TEMP}/conda.sh"
          /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda"
          echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}"
          if [ -d "/Applications/Xcode_14.3.1.app" ]; then
            echo "DEVELOPER_DIR=/Applications/Xcode_14.3.1.app/Contents/Developer" >> "${GITHUB_ENV}"
          elif [ -d "/Applications/Xcode_13.3.1.app" ]; then
            echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}"
          fi
      - name: Checkout PyTorch
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
          submodules: recursive
          path: pytorch
          show-progress: false
      - name: Clean PyTorch checkout
        run: |
          # Remove any artifacts from the previous checkouts
          git clean -fxd
        working-directory: pytorch
      - name: Populate binary env
        run: |
          # shellcheck disable=SC1091
          source "${RUNNER_TEMP}/anaconda/bin/activate"
          "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
      - name: Build PyTorch binary
        run: |
          # shellcheck disable=SC1091
          source "${RUNNER_TEMP}/anaconda/bin/activate"
          set -eux -o pipefail
          # shellcheck disable=SC1090
          source "${BINARY_ENV_FILE:-/Users/distiller/project/env}"
          mkdir -p "$PYTORCH_FINAL_PACKAGE_DIR"

          # Build
          USE_PYTORCH_METAL_EXPORT=1
          USE_COREML_DELEGATE=1
          TORCH_PACKAGE_NAME="${TORCH_PACKAGE_NAME//-/_}"
          export USE_PYTORCH_METAL_EXPORT
          export USE_COREML_DELEGATE
          export TORCH_PACKAGE_NAME
          "${PYTORCH_ROOT}/.ci/wheel/build_wheel.sh"
      - name: Test PyTorch wheel
        run: |
          # shellcheck disable=SC1091
          source "${RUNNER_TEMP}/anaconda/bin/activate"
          set -eux -o pipefail
          # shellcheck disable=SC1090
          source "${BINARY_ENV_FILE:-/Users/distiller/project/env}"
          pip uninstall -y "$TORCH_PACKAGE_NAME" || true
          pip uninstall -y "$TORCH_PACKAGE_NAME" || true

          # Create new "clean" conda environment for testing

          SMOKE_TEST_PARAMS=""

          EXTRA_CONDA_INSTALL_FLAGS=""
          CONDA_ENV_CREATE_FLAGS=""
          # shellcheck disable=SC2153
          case $DESIRED_PYTHON in
            3.14t)
              CONDA_ENV_CREATE_FLAGS="python-freethreading"
              EXTRA_CONDA_INSTALL_FLAGS="-c conda-forge/label/python_rc -c conda-forge"
              desired_python="3.14.0rc1"
              ;;
            3.14)
              EXTRA_CONDA_INSTALL_FLAGS="-c conda-forge/label/python_rc -c conda-forge"
              desired_python="3.14.0rc1"
              ;;
            3.13t)
              CONDA_ENV_CREATE_FLAGS="python-freethreading"
              EXTRA_CONDA_INSTALL_FLAGS="-c conda-forge"
              desired_python="3.13"
              ;;
            *)
              # shellcheck disable=SC2153
              desired_python=${DESIRED_PYTHON}
              ;;
          esac

          # shellcheck disable=SC2086
          conda create -yn "test_conda_env" python="$desired_python" ${CONDA_ENV_CREATE_FLAGS} ${EXTRA_CONDA_INSTALL_FLAGS}
          conda activate test_conda_env
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v

          # shellcheck disable=SC2086
          python "${PYTORCH_ROOT}/.ci/pytorch/smoke_test/smoke_test.py" --package torchonly ${SMOKE_TEST_PARAMS}
      - uses: actions/upload-artifact@v4.4.0
        if: always()
        with:
          name: wheel-py3_14-cpu
          retention-days: 14
          if-no-files-found: error
          path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
  wheel-py3_14-cpu-upload:  # Uploading
    if: ${{ github.repository_owner == 'pytorch' }}
    permissions:
      id-token: write
      contents: read
    needs: wheel-py3_14-cpu-build
    with:
      PYTORCH_ROOT: /pytorch
      PACKAGE_TYPE: wheel
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: cpu
      GPU_ARCH_TYPE: cpu
      DOCKER_IMAGE: manylinux2_28-builder
      DOCKER_IMAGE_TAG_PREFIX: cpu
      DESIRED_PYTHON: "3.14"
      build_name: wheel-py3_14-cpu
      use_s3: False
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
    uses: ./.github/workflows/_binary-upload.yml
  wheel-py3_14t-cpu-build:
    if: ${{ github.repository_owner == 'pytorch' }}
    runs-on: macos-14-xlarge
    timeout-minutes: 240
    env:
      PYTORCH_ROOT: ${{ github.workspace }}/pytorch
      PACKAGE_TYPE: wheel
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: cpu
      GPU_ARCH_TYPE: cpu
      SKIP_ALL_TESTS: 1
      DESIRED_PYTHON: "3.14t"
    steps:
      # NOTE: These environment variables are put here so that they can be applied on every job equally
      #       They are also here because setting them at a workflow level doesn't give us access to the
      #       runner.temp variable, which we need.
      - name: Populate binary env
        shell: bash
        run: |
          # shellcheck disable=SC2129
          echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}"
          # shellcheck disable=SC2129
          echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}"
          # shellcheck disable=SC2129
          echo "MAC_PACKAGE_WORK_DIR=${RUNNER_TEMP}" >> "${GITHUB_ENV}"
      - name: Install conda and dependencies
        run: |
          # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on
          curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" "https://repo.anaconda.com/miniconda/Miniconda3-py310_23.5.2-0-MacOSX-$(uname -m).sh"
          chmod +x "${RUNNER_TEMP}/conda.sh"
          /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda"
          echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}"
          if [ -d "/Applications/Xcode_14.3.1.app" ]; then
            echo "DEVELOPER_DIR=/Applications/Xcode_14.3.1.app/Contents/Developer" >> "${GITHUB_ENV}"
          elif [ -d "/Applications/Xcode_13.3.1.app" ]; then
            echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}"
          fi
      - name: Checkout PyTorch
        uses: actions/checkout@v4
        with:
          ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
          submodules: recursive
          path: pytorch
          show-progress: false
      - name: Clean PyTorch checkout
        run: |
          # Remove any artifacts from the previous checkouts
          git clean -fxd
        working-directory: pytorch
      - name: Populate binary env
        run: |
          # shellcheck disable=SC1091
          source "${RUNNER_TEMP}/anaconda/bin/activate"
          "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh"
      - name: Build PyTorch binary
        run: |
          # shellcheck disable=SC1091
          source "${RUNNER_TEMP}/anaconda/bin/activate"
          set -eux -o pipefail
          # shellcheck disable=SC1090
          source "${BINARY_ENV_FILE:-/Users/distiller/project/env}"
          mkdir -p "$PYTORCH_FINAL_PACKAGE_DIR"

          # Build
          USE_PYTORCH_METAL_EXPORT=1
          USE_COREML_DELEGATE=1
          TORCH_PACKAGE_NAME="${TORCH_PACKAGE_NAME//-/_}"
          export USE_PYTORCH_METAL_EXPORT
          export USE_COREML_DELEGATE
          export TORCH_PACKAGE_NAME
          "${PYTORCH_ROOT}/.ci/wheel/build_wheel.sh"
      - name: Test PyTorch wheel
        run: |
          # shellcheck disable=SC1091
          source "${RUNNER_TEMP}/anaconda/bin/activate"
          set -eux -o pipefail
          # shellcheck disable=SC1090
          source "${BINARY_ENV_FILE:-/Users/distiller/project/env}"
          pip uninstall -y "$TORCH_PACKAGE_NAME" || true
          pip uninstall -y "$TORCH_PACKAGE_NAME" || true

          # Create new "clean" conda environment for testing

          SMOKE_TEST_PARAMS=""

          EXTRA_CONDA_INSTALL_FLAGS=""
          CONDA_ENV_CREATE_FLAGS=""
          # shellcheck disable=SC2153
          case $DESIRED_PYTHON in
            3.14t)
              CONDA_ENV_CREATE_FLAGS="python-freethreading"
              EXTRA_CONDA_INSTALL_FLAGS="-c conda-forge/label/python_rc -c conda-forge"
              desired_python="3.14.0rc1"
              ;;
            3.14)
              EXTRA_CONDA_INSTALL_FLAGS="-c conda-forge/label/python_rc -c conda-forge"
              desired_python="3.14.0rc1"
              ;;
            3.13t)
              CONDA_ENV_CREATE_FLAGS="python-freethreading"
              EXTRA_CONDA_INSTALL_FLAGS="-c conda-forge"
              desired_python="3.13"
              ;;
            *)
              # shellcheck disable=SC2153
              desired_python=${DESIRED_PYTHON}
              ;;
          esac

          # shellcheck disable=SC2086
          conda create -yn "test_conda_env" python="$desired_python" ${CONDA_ENV_CREATE_FLAGS} ${EXTRA_CONDA_INSTALL_FLAGS}
          conda activate test_conda_env
          pip install "$PYTORCH_FINAL_PACKAGE_DIR"/*.whl numpy -v

          # shellcheck disable=SC2086
          python "${PYTORCH_ROOT}/.ci/pytorch/smoke_test/smoke_test.py" --package torchonly ${SMOKE_TEST_PARAMS}
      - uses: actions/upload-artifact@v4.4.0
        if: always()
        with:
          name: wheel-py3_14t-cpu
          retention-days: 14
          if-no-files-found: error
          path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}"
  wheel-py3_14t-cpu-upload:  # Uploading
    if: ${{ github.repository_owner == 'pytorch' }}
    permissions:
      id-token: write
      contents: read
    needs: wheel-py3_14t-cpu-build
    with:
      PYTORCH_ROOT: /pytorch
      PACKAGE_TYPE: wheel
      # TODO: This is a legacy variable that we eventually want to get rid of in
      #       favor of GPU_ARCH_VERSION
      DESIRED_CUDA: cpu
      GPU_ARCH_TYPE: cpu
      DOCKER_IMAGE: manylinux2_28-builder
      DOCKER_IMAGE_TAG_PREFIX: cpu
      DESIRED_PYTHON: "3.14t"
      build_name: wheel-py3_14t-cpu
      use_s3: False
    secrets:
      github-token: ${{ secrets.GITHUB_TOKEN }}
    uses: ./.github/workflows/_binary-upload.yml

 .github/workflows/h100-cutlass-backend.yml | 5 (vendored)
@@ -4,9 +4,12 @@ on:
  pull_request:
    paths:
      - .github/workflows/h100-cutlass-backend.yml
      - torch/_inductor/codegen/cuda/**
      - test/inductor/test_cutlass_backend.py
      - test/inductor/test_cutlass_evt.py
  workflow_dispatch:
  schedule:
    - cron: 22 9 * * *  # every 24 hours about 2:22am PDT
    - cron: 22 9,21 * * *  # every 12 hours
  push:
    tags:
      - ciflow/h100-cutlass-backend/*

@@ -88,23 +88,23 @@ jobs:
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks
      test-matrix: |
        { include: [
          { config: "inductor_huggingface_perf_rocm", shard: 1, num_shards: 4, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "inductor_huggingface_perf_rocm", shard: 2, num_shards: 4, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "inductor_huggingface_perf_rocm", shard: 3, num_shards: 4, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "inductor_huggingface_perf_rocm", shard: 4, num_shards: 4, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "inductor_timm_perf_rocm", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "inductor_timm_perf_rocm", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "inductor_timm_perf_rocm", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "inductor_timm_perf_rocm", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "inductor_timm_perf_rocm", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "inductor_torchbench_perf_rocm", shard: 1, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "inductor_torchbench_perf_rocm", shard: 2, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "inductor_torchbench_perf_rocm", shard: 3, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "inductor_torchbench_perf_rocm", shard: 4, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "inductor_torchbench_perf_rocm", shard: 5, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "inductor_torchbench_perf_rocm", shard: 6, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "inductor_torchbench_perf_rocm", shard: 7, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "inductor_torchbench_perf_rocm", shard: 8, num_shards: 8, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "inductor_huggingface_perf_rocm", shard: 1, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm", shard: 2, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm", shard: 3, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_huggingface_perf_rocm", shard: 4, num_shards: 4, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 1, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 2, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 3, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 4, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_timm_perf_rocm", shard: 5, num_shards: 5, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 1, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 2, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 3, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 4, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 5, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 6, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 7, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor_torchbench_perf_rocm", shard: 8, num_shards: 8, runner: "linux.rocm.gpu.gfx942.1" },
        ]}
    secrets: inherit

 .github/workflows/inductor-periodic.yml | 30 (vendored)
@@ -81,21 +81,21 @@ jobs:
      sync-tag: rocm-build
      test-matrix: |
        { include: [
          { config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "dynamo_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "dynamo_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "dynamo_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "dynamo_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "dynamic_aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "dynamic_aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "dynamo_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "dynamo_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "dynamo_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "dynamo_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "dynamo_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "dynamic_aot_eager_torchbench", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "dynamic_aot_eager_torchbench", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "dynamic_aot_eager_huggingface", shard: 1, num_shards: 1, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
        ]}
    secrets: inherit

 .github/workflows/inductor-rocm-mi300.yml | 4 (vendored)
@@ -47,8 +47,8 @@ jobs:
      docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
      test-matrix: |
        { include: [
          { config: "inductor", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "inductor", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "inductor", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "inductor", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx942.1" },
        ]}
    secrets: inherit

 .github/workflows/pull.yml | 1 (vendored)
@@ -251,7 +251,6 @@ jobs:
      build-environment: linux-jammy-py3.13-clang12
      docker-image: ${{ needs.linux-jammy-py3_13-clang12-build.outputs.docker-image }}
      test-matrix: ${{ needs.linux-jammy-py3_13-clang12-build.outputs.test-matrix }}
      timeout-minutes: 600
    secrets: inherit

  linux-jammy-cuda12_8-cudnn9-py3_9-clang12-build:

 .github/workflows/riscv64.yml | 24 (vendored, new file)
@@ -0,0 +1,24 @@
name: riscv64

on:
  push:
    tags:
      - ciflow/riscv64/*
  workflow_dispatch:

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
  cancel-in-progress: true

permissions: read-all

jobs:
  pytorch-linux-noble-riscv64-py3_12-gcc14-cross-build:
    if: github.repository_owner == 'pytorch'
    name: pytorch-linux-noble-riscv64-py3_12-gcc14-cross-build
    uses: ./.github/workflows/_linux-build.yml
    with:
      build-environment: linux-noble-riscv64-py3.12-gcc14
      docker-image-name: pytorch-linux-noble-riscv64-py3.12-gcc14
      runner: linux.2xlarge
    secrets: inherit
 .github/workflows/rocm-mi300.yml | 12 (vendored)
@@ -48,12 +48,12 @@ jobs:
      sync-tag: rocm-build
      test-matrix: |
        { include: [
          { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.gfx942.2" },
          { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" },
          { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" },
        ]}
    secrets: inherit

 .github/workflows/tools-unit-tests.yml | 70 (vendored, new file)
@@ -0,0 +1,70 @@
name: test-scripts-and-ci-tools

on:
  push:
    branches:
      - main
    paths:
      - scripts/lumen_cli/**
      - .github/workflows/tools-unit-tests.yml
  pull_request:
    paths:
      - scripts/lumen_cli/**
      - .github/workflows/tools-unit-tests.yml

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}
  cancel-in-progress: true

jobs:
  lumen-cli-unit-tests-python312:
    permissions:
      contents: read
      pull-requests: write
    if: ${{ github.repository_owner == 'pytorch' }}
    runs-on: ubuntu-latest
    steps:
      - name: Checkout pytorch
        uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
        with:
          submodules: true
          fetch-depth: 0
      - name: Setup Python
        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
        with:
          python-version: '3.12'
          cache: pip

      - name: Run tests
        continue-on-error: true
        run: |
          set -ex
          python3 -m venv /tmp/venv
          source /tmp/venv/bin/activate
          pip install -e .ci/lumen_cli/
          pytest -v -s .ci/lumen_cli/tests/*

  lumen-cli-compatible-python39:
    permissions:
      contents: read
      pull-requests: write
    if: ${{ github.repository_owner == 'pytorch' }}
    runs-on: ubuntu-latest
    steps:
      - name: Checkout pytorch
        uses: pytorch/pytorch/.github/actions/checkout-pytorch@main
        with:
          submodules: true
          fetch-depth: 0
      - name: Setup Python
        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
        with:
          python-version: '3.9'
          cache: 'pip'
      - name: Run tests
        continue-on-error: true
        run: |
          set -ex
          python3 -m venv /tmp/venv
          source /tmp/venv/bin/activate
          pip install -e .ci/lumen_cli/
@@ -1196,7 +1196,7 @@ if(APPLE)
    string(
      APPEND
      CMAKE_SHARED_LINKER_FLAGS
      " -weak_framework Foundation -weak_framework MetalPerformanceShaders -weak_framework MetalPerformanceShadersGraph -weak_framework Metal"
      " -weak_framework Foundation -weak_framework MetalPerformanceShaders -weak_framework MetalPerformanceShadersGraph -weak_framework Metal -weak_framework IOKit"
    )
    # To suppress MPSGraph availability warnings
    append_cxx_flag_if_supported("-Wno-unguarded-availability-new"

@@ -2,7 +2,7 @@

## Demo applications and tutorials

Please refer to [pytorch-labs/executorch-examples](https://github.com/pytorch-labs/executorch-examples/tree/main/dl3/android/DeepLabV3Demo) for the Android demo app based on [ExecuTorch](https://github.com/pytorch/executorch).
Please refer to [meta-pytorch/executorch-examples](https://github.com/meta-pytorch/executorch-examples/tree/main/dl3/android/DeepLabV3Demo) for the Android demo app based on [ExecuTorch](https://github.com/pytorch/executorch).

Please join our [Discord](https://discord.com/channels/1334270993966825602/1349854760299270284) for any questions.

@@ -216,6 +216,7 @@ TORCH_LIBRARY_IMPL(aten, AutocastMPS, m) {
  KERNEL_MPS(_convolution, lower_precision_fp)
  KERNEL_MPS(conv1d, lower_precision_fp)
  KERNEL_MPS(conv2d, lower_precision_fp)
  KERNEL_MPS(conv3d, lower_precision_fp)
  KERNEL_MPS(conv_tbc, lower_precision_fp)
  KERNEL_MPS(conv_transpose1d, lower_precision_fp)
  KERNEL_MPS(conv_transpose2d, input, lower_precision_fp)

@@ -1,6 +1,7 @@
#pragma once

#include <c10/core/Allocator.h>
#include <c10/core/AllocatorConfig.h>
#include <c10/core/Stream.h>
#include <c10/core/thread_pool.h>
#include <c10/util/flat_hash_map.h>
@@ -251,6 +252,7 @@ struct CachingHostAllocatorImpl {
    auto* block = reinterpret_cast<B*>(ctx);

    std::optional<std::vector<E>> events;
    ska::flat_hash_set<S> streams;
    {
      std::lock_guard<std::mutex> g(block->mutex_);
      block->allocated_ = false;
@@ -259,14 +261,19 @@ struct CachingHostAllocatorImpl {
      } else {
        events = std::vector<E>();
        events->reserve(block->streams_.size());
        for (auto stream : block->streams_) {
          record_stream(events, stream);
        }
        block->event_count_ += events->size();
        block->event_count_ += block->streams_.size();
        // Move out streams to avoid holding the mutex during event recording
        streams = std::move(block->streams_);
        block->streams_.clear();
      }
    }

    // Event recording must be done outside the mutex to avoid potential
    // deadlocks (e.g., when Python GIL is involved)
    for (auto stream : streams) {
      record_stream(events, stream);
    }

    if (!events) {
      auto index = size_index(block->size_);
      std::lock_guard<std::mutex> g(free_list_[index].mutex_);
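The hunk above reworks the free path so that event recording happens only after the block mutex is released, as the new comment explains. Below is a minimal standalone sketch of that "collect under the lock, act outside it" pattern; Block, Stream and record_event are simplified hypothetical stand-ins, not the real allocator types:

#include <cstddef>
#include <mutex>
#include <unordered_set>

using Stream = int;  // stand-in for the allocator's stream type

struct Block {
  std::mutex mutex_;
  std::unordered_set<Stream> streams_;
  std::size_t event_count_ = 0;
};

// Placeholder for the real event-recording call, which may itself take locks
// (e.g. the Python GIL) and therefore must not run under block.mutex_.
void record_event(Stream) {}

void free_block(Block& block) {
  std::unordered_set<Stream> streams;
  {
    std::lock_guard<std::mutex> g(block.mutex_);
    block.event_count_ += block.streams_.size();  // account for events up front
    streams = std::move(block.streams_);          // move shared state out of the lock
    block.streams_.clear();
  }
  // Recording outside the mutex avoids lock-order inversions with whatever
  // locks record_event() may acquire internally.
  for (Stream s : streams) {
    record_event(s);
  }
}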
@@ -345,7 +352,8 @@ struct CachingHostAllocatorImpl {
  }

  virtual bool pinned_use_background_threads() {
    return false;
    return c10::CachingAllocator::AcceleratorAllocatorConfig::
        pinned_use_background_threads();
  }

  virtual void copy_data(void* dest [[maybe_unused]], const void* src [[maybe_unused]], std::size_t count [[maybe_unused]]) const {

@@ -6,6 +6,8 @@
#include <c10/core/DispatchKeySet.h>
#include <c10/util/TypeList.h>
#include <c10/util/intrusive_ptr.h>
#include <atomic>
#include <memory>
#include <type_traits>

namespace c10 {
@@ -17,6 +19,9 @@ class OperatorHandle;
struct OperatorKernel;
class KernelFunction;

class KernelToken;
class SafeKernelFunction;

template <typename T>
using has_symint = std::disjunction<
    std::is_same<c10::SymInt, T>,
@@ -90,6 +95,12 @@ class TORCH_API KernelFunction final {
      BoxedKernel::BoxedKernelFunction_withDispatchKeys;

  KernelFunction();
  ~KernelFunction();

  KernelFunction(const KernelFunction&) = default;
  KernelFunction& operator=(const KernelFunction&) = default;

  KernelFunction(KernelFunction&&) noexcept = default;

  // Fast path for dispatch to allow not touching the boxed kernel in
  // the common case where unboxed is available.
@@ -262,6 +273,13 @@ class TORCH_API KernelFunction final {
  // For testing internal invariants only
  bool _equalsBoxedAndUnboxed(const KernelFunction&) const;

  // Register a token to be invalidated when this KernelFunction is destroyed
  void registerToken(std::weak_ptr<KernelToken> token) const;

  // List of tokens that need to be invalidated when this KernelFunction is
  // destroyed
  mutable std::vector<std::weak_ptr<KernelToken>> tokens_;

 private:
  explicit KernelFunction(
      std::unique_ptr<OperatorKernel> functor,
@@ -278,6 +296,47 @@ class TORCH_API KernelFunction final {
  void* sym_unboxed_kernel_func_;
};

// Token held by SafeKernelFunction that gets invalidated when KernelFunction is
// destroyed
class KernelToken {
 public:
  bool isValid() const;
  void invalidate();

 private:
  std::atomic<bool> invalid_{false};
};

class SafeKernelFunction {
 public:
  SafeKernelFunction(
      const KernelFunction* kernel,
      std::string debug,
      std::shared_ptr<OperatorHandle> opHandle);

  // Safe callBoxed - checks token validity first
  void callBoxed(
      const OperatorHandle& opHandle,
      DispatchKeySet dispatchKeySet,
      Stack* stack) const;

  // Get debug information
  const std::string& debug() const {
    return debug_;
  }

  // Get the OpHandle that lives on this SafeKernelFunction
  const OperatorHandle& opHandle() const {
    return *opHandle_;
  }

 private:
  KernelFunction kernel_;
  std::shared_ptr<KernelToken> token_;
  std::string debug_;
  std::shared_ptr<OperatorHandle> opHandle_;
};

} // namespace c10

#include <ATen/core/boxing/KernelFunction_impl.h>

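The KernelToken/SafeKernelFunction pair above is a weak-reference handshake: the kernel keeps weak_ptrs to the tokens it has handed out and flips them in its destructor, and each safe handle checks its token before calling through. A minimal self-contained sketch of the same pattern, using simplified stand-in classes rather than the dispatcher types:

#include <atomic>
#include <cassert>
#include <memory>
#include <vector>

class Token {
 public:
  bool isValid() const { return !invalid_.load(std::memory_order_acquire); }
  void invalidate() { invalid_.store(true, std::memory_order_release); }
 private:
  std::atomic<bool> invalid_{false};
};

// Plays the role of the kernel: invalidates all registered tokens on destruction.
class Owner {
 public:
  ~Owner() {
    for (auto& weak : tokens_) {
      if (auto token = weak.lock()) token->invalidate();
    }
  }
  void registerToken(std::weak_ptr<Token> t) { tokens_.push_back(std::move(t)); }
 private:
  std::vector<std::weak_ptr<Token>> tokens_;
};

// Plays the role of the safe handle: refuses to run once its token is invalid.
class SafeHandle {
 public:
  explicit SafeHandle(Owner& owner) : token_(std::make_shared<Token>()) {
    owner.registerToken(token_);
  }
  bool call() const { return token_->isValid(); }  // the real code raises via TORCH_CHECK
 private:
  std::shared_ptr<Token> token_;
};

int main() {
  auto owner = std::make_unique<Owner>();
  SafeHandle h(*owner);
  assert(h.call());   // owner alive: token valid
  owner.reset();      // destroying the owner invalidates the token
  assert(!h.call());  // handle now reports invalid instead of touching freed state
}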
@@ -24,6 +24,14 @@ inline KernelFunction::KernelFunction()
      unboxed_kernel_func_(nullptr),
      sym_unboxed_kernel_func_(nullptr) {}

inline KernelFunction::~KernelFunction() {
  for (auto& weak_token : tokens_) {
    if (auto token = weak_token.lock()) {
      token->invalidate();
    }
  }
}

inline KernelFunction::KernelFunction(
    std::unique_ptr<OperatorKernel> functor,
    InternalBoxedKernelFunction* boxed_kernel_func,
@@ -157,6 +165,11 @@ C10_ALWAYS_INLINE Return KernelFunction::call(
      std::forward<Args>(args)...);
}

inline void KernelFunction::registerToken(
    std::weak_ptr<KernelToken> token) const {
  tokens_.push_back(std::move(token));
}

inline KernelFunction KernelFunction::makeFromBoxedKernel(
    BoxedKernel boxed_fn) {
  return KernelFunction(
@@ -317,4 +330,38 @@ KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
          std::forward<Lambda>(lambda)));
}

inline bool KernelToken::isValid() const {
  return !invalid_.load(std::memory_order_acquire);
}

inline void KernelToken::invalidate() {
  invalid_.store(true, std::memory_order_release);
}

inline SafeKernelFunction::SafeKernelFunction(
    const KernelFunction* kernel,
    std::string debug,
    std::shared_ptr<OperatorHandle> opHandle)
    : kernel_(kernel ? *kernel : KernelFunction()),
      token_(std::make_shared<KernelToken>()),
      debug_(std::move(debug)),
      opHandle_(std::move(opHandle)) {
  // Register the token with the original kernel so it gets invalidated when the
  // kernel is destroyed
  if (kernel) {
    kernel->registerToken(token_);
  }
}

inline void SafeKernelFunction::callBoxed(
    const OperatorHandle& opHandle,
    DispatchKeySet dispatchKeySet,
    Stack* stack) const {
  TORCH_CHECK(
      token_ && token_->isValid(),
      "SafeKernelFunction has been invalidated ",
      debug_);
  kernel_.callBoxed(opHandle, dispatchKeySet, stack);
}

} // namespace c10

@@ -487,6 +487,10 @@ class TORCH_API OperatorHandle {
    return operatorDef_->op.hasComputedKernelForDispatchKey(k);
  }

  SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const {
    return operatorDef_->op.getComputedKernelForDispatchKey(k);
  }

  std::string dumpComputedTable() const {
    return operatorDef_->op.dumpComputedTable();
  }

@@ -315,6 +315,42 @@ const AnnotatedKernel* OperatorEntry::getKernelForDispatchKey(DispatchKey dispat
  return nullptr;
}

SafeKernelFunction OperatorEntry::getComputedKernelForDispatchKey(
    DispatchKey k) const {
  TORCH_CHECK(
      !isAliasDispatchKey(k),
      "Alias keys do not have runtime kernel registrations.");
  const auto dispatch_ix = getDispatchTableIndexForDispatchKey(k);
  TORCH_CHECK(
      dispatchTable_[dispatch_ix].isValid(),
      "no kernel for ",
      k,
      " for ",
      name_);

  // Get the KernelFunction object from kernels_ to pass to SafeKernelFunction

  // The KernelFunction object in dispatchTable_ is a copy of the KernelFunction
  // in the AnnotatedKernel in kernels_. A KernelFunction is only truly
  // deregistered when the kernel is removed from kernels_. However, the
  // KernelFunction in dispatchTable_ might be removed before it is deregistered
  // (when a newer kernel is registered). Therefore, here we want to return a
  // SafeKernelFunction that is backed by the original KernelFunction in
  // kernels_, so that we only invalidate it when the kernel is deregistered.
  auto [annotatedKernel, _] =
      computeDispatchTableEntryWithDebug(c10::Dispatcher::singleton(), k);

  // Use findSchemaOrThrow to get OpHandle for the OperatorEntry
  auto& dispatcher = c10::Dispatcher::singleton();
  auto opHandle = dispatcher.findSchemaOrThrow(
      name_.name.c_str(), name_.overload_name.c_str());

  return SafeKernelFunction(
      &annotatedKernel.kernel,
      annotatedKernel.debug,
      std::make_shared<OperatorHandle>(opHandle));
}

const std::vector<at::Tag>& OperatorEntry::getTags() const {
  #if defined C10_MOBILE
    TORCH_CHECK(false, "tags are not saved for Mobile");

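Put together, a caller can now snapshot a per-key kernel and invoke it later, with a clear error if the registration has since been torn down. A hypothetical caller-side sketch follows; the operator name, include path, and spelling of Stack are illustrative assumptions, not part of this diff:

// Hypothetical usage sketch: fetch a dispatch-key-specific kernel once, then
// call it later. If the kernel was deregistered in the meantime, callBoxed()
// raises a descriptive error instead of dereferencing freed memory.
#include <ATen/core/dispatch/Dispatcher.h>

void call_cpu_kernel_safely(torch::jit::Stack* stack) {
  auto& dispatcher = c10::Dispatcher::singleton();
  // "aten::add.Tensor" is just an example operator.
  auto op = dispatcher.findSchemaOrThrow("aten::add", "Tensor");
  c10::SafeKernelFunction safe =
      op.getComputedKernelForDispatchKey(c10::DispatchKey::CPU);
  // ... registrations may change between here and the call ...
  safe.callBoxed(safe.opHandle(),
                 c10::DispatchKeySet(c10::DispatchKey::CPU),
                 stack);
}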
@@ -217,6 +217,8 @@ class TORCH_API OperatorEntry final {
  const KernelFunction& kernelForDispatchKey(DispatchKey k) const;
  // Returns true if the "computed table" has an entry for a particular key.
  bool hasComputedKernelForDispatchKey(DispatchKey k) const;
  // Returns a KernelFunction corresponding to the kernel in dispatchTable
  SafeKernelFunction getComputedKernelForDispatchKey(DispatchKey k) const;
  // Returns all the operator tags added at the time of registration
  const std::vector<at::Tag>& getTags() const;
  void setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback);

@@ -161,11 +161,6 @@ struct CUDACachingHostAllocatorImpl
    return true;
  }

  bool pinned_use_background_threads() override {
    return c10::CachingAllocator::AcceleratorAllocatorConfig::
        pinned_use_background_threads();
  }

  EventPool::Event create_event_internal(DeviceIndex idx) {
    // Leak the event pool to avoid shutdown issue.
    static auto* event_pool = new EventPool();

@@ -54,7 +54,7 @@

// There were many bc-breaking changes in major version release of CCCL v3.0.0
// Please see https://nvidia.github.io/cccl/cccl/3.0_migration_guide.html
#if CUB_VERSION >= 300000
#if CUB_VERSION >= 200800
#define CUB_V3_PLUS() true
#else
#define CUB_V3_PLUS() false

@@ -55,6 +55,17 @@ class TORCH_API MPSDevice {
   */
  bool isMacOS13Plus(MacOSVersion version) const;

  /**
   * Returns device name
   */
  std::string getName() const;

  /**
   * Returns number of GPU cores.
   * 1 Core = 16 ExecutionUnit x 8 ALU x 24 threads
   */
  unsigned getCoreCount() const;

  ~MPSDevice();

 private:

@@ -85,10 +85,36 @@ bool MPSDevice::isMacOS13Plus(MacOSVersion version) const {
  }
}

std::string MPSDevice::getName() const {
  @autoreleasepool {
    return [[_mtl_device name] UTF8String];
  }
}

unsigned MPSDevice::getCoreCount() const {
  io_iterator_t iterator = 0;
  io_registry_entry_t entry = 0;
  int core_count = 0;
  auto matchingDict = IOServiceMatching("AGXAccelerator");
  TORCH_INTERNAL_ASSERT(matchingDict, "Failed to create matching dict");
  const auto status = IOServiceGetMatchingServices(kIOMainPortDefault, matchingDict, &iterator);
  TORCH_INTERNAL_ASSERT(status == KERN_SUCCESS);
  while ((entry = IOIteratorNext(iterator)) != 0) {
    auto property = IORegistryEntryCreateCFProperty(entry, CFSTR("gpu-core-count"), kCFAllocatorDefault, 0);
    auto found = CFNumberGetValue(static_cast<CFNumberRef>(property), kCFNumberIntType, &core_count);
    CFRelease(property);
    IOObjectRelease(entry);
    if (found) {
      break;
    }
  }
  IOObjectRelease(iterator);
  return core_count;
}

at::Allocator* GetMPSAllocator(bool useSharedAllocator) {
  return getIMPSAllocator(useSharedAllocator);
}

bool is_available() {
  return MPSDevice::getInstance()->device() != nil;
}

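For reference, a hypothetical caller of the two new accessors; the namespace and header path follow the usual ATen layout but are assumptions here rather than part of the diff:

// Hypothetical usage sketch: report the Metal device name and core count.
// getCoreCount() above walks the IOKit registry for the "gpu-core-count"
// property of the AGXAccelerator service, which is why the CMake change in
// this same set of commits links the IOKit framework.
#include <ATen/mps/MPSDevice.h>
#include <iostream>

void print_mps_info() {
  auto* device = at::mps::MPSDevice::getInstance();
  std::cout << device->getName() << ": "
            << device->getCoreCount() << " GPU cores\n";
}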
@@ -362,20 +362,24 @@ inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Ten
    return false;
  }

  bool can_use_miopen_channels_last_2d = false;
  // TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
  // See #64427
  static std::optional<bool> PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC");
  static bool suggest_nhwc = PYTORCH_MIOPEN_SUGGEST_NHWC && *PYTORCH_MIOPEN_SUGGEST_NHWC;

  auto input_memory_format = input.suggest_memory_format();
  auto weight_memory_format = weight.suggest_memory_format();
  auto weight_ndim = weight.ndimension();

  can_use_miopen_channels_last_2d = PYTORCH_MIOPEN_SUGGEST_NHWC &&  *PYTORCH_MIOPEN_SUGGEST_NHWC && (
            ( (input_memory_format  == at::MemoryFormat::ChannelsLast) ||
            (weight_memory_format == at::MemoryFormat::ChannelsLast) )
        );
  bool can_use_miopen_channels_last_2d = suggest_nhwc && (weight_ndim == 4) && (
    (input_memory_format  == at::MemoryFormat::ChannelsLast) ||
    (weight_memory_format == at::MemoryFormat::ChannelsLast)
  );

  bool can_use_miopen_channels_last_3d = false;
  bool can_use_miopen_channels_last_3d = suggest_nhwc && (weight_ndim == 5) && (
    (input_memory_format  == at::MemoryFormat::ChannelsLast3d) ||
    (weight_memory_format == at::MemoryFormat::ChannelsLast3d)
  );

  return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d;
}

@@ -1422,7 +1422,7 @@ static inline at::MemoryFormat determine_backend_memory_format(
      if (detail::getCUDAHooks().compiledWithMIOpen() && miopen_conv_use_channels_last(input, weight)) {
        TORCH_INTERNAL_ASSERT((k == 4 || k == 5),
            "Expected 4D or 5D input for miopen memory format selection in determine_backend_memory_format()");
        backend_memory_format = (k == 5) ? at::MemoryFormat::Contiguous /*at::MemoryFormat::ChannelsLast3d*/ : at::MemoryFormat::ChannelsLast;
        backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
      }
      break;
    case ConvBackend::Mkldnn:

@@ -520,6 +520,15 @@ BatchNormBackend _select_batch_norm_backend(
    return BatchNormBackend::Cudnn;
  }

  // TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM once ROCm officially supports NHWC in MIOpen
  // See https://github.com/pytorch/pytorch/issues/64427.
  // non static variable is used to be able to change environment variable in runtime for testing
  // enabled by default for ROCm >= 7.0.0 with miopen 3.5
  int miopen_version = detail::getCUDAHooks().compiledWithMIOpen() ? detail::getCUDAHooks().versionMIOpen() : 0;
  bool is_miopen_3_4 = miopen_version >= 30400;  // ROCm 6.4
  bool is_miopen_3_5 = miopen_version >= 30500;  // ROCm 7.0
  bool PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM").value_or(is_miopen_3_5);

  if (
      detail::getCUDAHooks().compiledWithMIOpen()
      && cudnn_enabled
@@ -527,13 +536,15 @@ BatchNormBackend _select_batch_norm_backend(
      && input.dim() <= MIOPEN_DIM_MAX
      && input.dim() >= 3
      && input.scalar_type() != at::kDouble
      && (detail::getCUDAHooks().versionMIOpen() >= 30400 || input.scalar_type() != at::kBFloat16)
      && (is_miopen_3_4 || input.scalar_type() != at::kBFloat16)
      && weight.scalar_type() == at::kFloat // only FP32 weight for FP32 or FP16/BF16(mixed) input
      && weight.defined() && bias.defined()
      && ((running_mean.defined() && running_var.defined())
        || (!running_mean.defined() && !running_var.defined() && training))
      && input.suggest_memory_format() != MemoryFormat::ChannelsLast
      && input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
      && (input.suggest_memory_format() == MemoryFormat::Contiguous
          || (is_miopen_3_5 && PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM &&
              (input.suggest_memory_format() == MemoryFormat::ChannelsLast
               || input.suggest_memory_format() == MemoryFormat::ChannelsLast3d)))
  ) {
    return BatchNormBackend::Miopen;
  }

@@ -411,7 +411,8 @@ Tensor fbgemm_pack_gemm_matrix_fp16(const Tensor& weight) {
Tensor fbgemm_linear_fp16_weight_fp32_activation(
    const Tensor& input,
    const Tensor& packed_weight,
-    const std::optional<Tensor>& bias) {
+    const std::optional<Tensor>& bias,
+    at::Tensor& output) {
  TORCH_WARN_ONCE("fbgemm_linear_fp16_weight_fp32_activation is deprecated "
                  "and will be removed in a future PyTorch release.")

@@ -436,9 +437,11 @@ Tensor fbgemm_linear_fp16_weight_fp32_activation(
  // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
  const int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
  const int64_t N = packed_weight_fp16.numCols();

  std::vector<int64_t> output_size = input.sizes().vec();
  output_size.back() = N;
-  Tensor output = at::empty(output_size, input.options().dtype(at::kFloat));
+  // Resize output Tensor
+  output.resize_(output_size);

  // Call the fp16 gemm interface
  fbgemm::cblas_gemm_compute(
@@ -460,6 +463,14 @@ Tensor fbgemm_linear_fp16_weight_fp32_activation(
  return output;
}

+Tensor fbgemm_linear_fp16_weight_fp32_activation(
+    const Tensor& input,
+    const Tensor& packed_weight,
+    const std::optional<Tensor>& bias) {
+  at::Tensor output = at::empty({0}, input.options().dtype(at::kFloat));
+  return at::native::fbgemm_linear_fp16_weight_fp32_activation(input, packed_weight, bias, output);
+}

Tensor fbgemm_linear_fp16_weight(
    const Tensor& input,
    const Tensor& packed_weight,
@@ -468,6 +479,15 @@ Tensor fbgemm_linear_fp16_weight(
      input, packed_weight, bias);
}

+Tensor fbgemm_linear_fp16_weight(
+    const Tensor& input,
+    const Tensor& packed_weight,
+    const Tensor& bias,
+    at::Tensor& output) {
+  return at::native::fbgemm_linear_fp16_weight_fp32_activation(
+      input, packed_weight, bias, output);
+}

#else // USE_FBGEMM

Tensor fbgemm_linear_int8_weight_fp32_activation(
@@ -554,6 +574,21 @@ Tensor fbgemm_pack_gemm_matrix_fp16(const Tensor& weight) {
      false, "This PyTorch installation was not built with FBGEMM operators");
}

+Tensor fbgemm_linear_fp16_weight_fp32_activation(
+    const Tensor& input,
+    const Tensor& packed_weight,
+    const std::optional<Tensor>& bias,
+    at::Tensor& output) {
+  TORCH_WARN_ONCE("fbgemm_linear_fp16_weight_fp32_activation is deprecated "
+                  "and will be removed in a future PyTorch release.")
+
+  // We make a strong guarantee that models using these operators will have the
+  // same numerics across different machines. Therefore, we do not provide a
+  // fallback path and rather fail loudly if we cannot run FBGEMM.
+  TORCH_CHECK(
+      false, "This PyTorch installation was not built with FBGEMM operators");
+}

Tensor fbgemm_linear_fp16_weight_fp32_activation(
    const Tensor& input,
    const Tensor& packed_weight,
@@ -568,6 +603,21 @@ Tensor fbgemm_linear_fp16_weight_fp32_activation(
      false, "This PyTorch installation was not built with FBGEMM operators");
}

+Tensor fbgemm_linear_fp16_weight(
+    const Tensor& input,
+    const Tensor& packed_weight,
+    const Tensor& bias,
+    at::Tensor& output) {
+  TORCH_WARN_ONCE("fbgemm_linear_fp16_weight is deprecated "
+                  "and will be removed in a future PyTorch release.")
+
+  // We make a strong guarantee that models using these operators will have the
+  // same numerics across different machines. Therefore, we do not provide a
+  // fallback path and rather fail loudly if we cannot run FBGEMM.
+  TORCH_CHECK(
+      false, "This PyTorch installation was not built with FBGEMM operators");
+}

Tensor fbgemm_linear_fp16_weight(
    const Tensor& input,
    const Tensor& packed_weight,
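The fbgemm changes follow the standard out-variant split: the value-returning overload becomes a thin wrapper that allocates a result and forwards to the overload that writes into a caller-supplied tensor. A minimal sketch of the same API shape under simplified assumptions (the Buf type and both functions are hypothetical):

#include <vector>

// Hypothetical stand-in for a tensor-like buffer.
struct Buf {
  std::vector<float> data;
  void resize(size_t n) { data.resize(n); }
};

// Out-variant: writes into a caller-provided buffer and returns it.
Buf& linear_out(const Buf& input, Buf& output) {
  output.resize(input.data.size());
  for (size_t i = 0; i < input.data.size(); ++i)
    output.data[i] = input.data[i] * 2.0f;  // placeholder for the real gemm
  return output;
}

// Legacy value-returning overload: allocate, then delegate to the out-variant.
Buf linear(const Buf& input) {
  Buf output;
  return linear_out(input, output);
}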
@@ -220,6 +220,8 @@ static void check_argmax_argmin(
    const char* name,
    const Tensor& self,
    const std::optional<int64_t>& dim) {
+  TORCH_CHECK(!self.is_complex(), name, ": does not support complex input");
+  TORCH_CHECK(!(self.scalar_type() == kBool), name, ": does not support bool input");
  if (dim.has_value()) {
    auto dim_ = maybe_wrap_dim(dim.value(), self.dim());
    native::zero_numel_check_dims(self, dim_, name);
@@ -59,6 +59,8 @@ TORCH_META_FUNC(topk)
      "selected index k out of range");
  int64_t sliceSize = self.dim() == 0 ? 1 : self.size(dim);
  TORCH_CHECK(k >= 0 && k <= sliceSize, "k not in range for dimension");
+  TORCH_CHECK(!self.is_complex(), " topk does not support complex dtypes on CPU");
+  TORCH_CHECK(!(self.scalar_type() == kBool), "topk does not support bool dtypes on CPU");

  // Build the output size, which is the dim being selected set to
  // size k
@@ -74,11 +76,7 @@ TORCH_META_FUNC2(sort, stable)
(const Tensor& self, std::optional<bool> stable, int64_t dim, bool descending) {
  maybe_wrap_dim(dim, self.dim());

-  const auto self_dtype = self.dtype();
-  TORCH_CHECK_VALUE(
-    self_dtype != ScalarType::ComplexFloat &&
-    self_dtype != ScalarType::ComplexDouble,
-    "Sort currently does not support complex dtypes on CPU.");
+  TORCH_CHECK(!self.is_complex(), " Sort does not support complex dtypes on CPU");

  // See issue: https://github.com/pytorch/pytorch/issues/65863
  // Strides should be dense, so as not to allocate too much memory.
@@ -226,8 +226,9 @@ C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
  using traits = function_traits<func_t>;
  constexpr auto io_size = calc_io_size<func_t>();
-#ifdef __gfx942__
-  constexpr int tws = (io_size >= 2) ? 8 : 16;
+#if defined(USE_ROCM) && defined(__gfx942__)
+  // Similar check in launch_vectorized_kernel() as well. Both should be in sync.
+  constexpr int tws = 16;
#else
  constexpr int tws = elems_per_thread<io_size>();
#endif
@@ -296,7 +297,8 @@ static inline void launch_vectorized_kernel(
  int vec_size = memory::can_vectorize_up_to<func_t>(data);
  c10::DeviceIndex curDevice = -1;
  AT_CUDA_CHECK(c10::cuda::GetDevice(&curDevice));
-  int tws = at::detail::getCUDAHooks().isGPUArch({"gfx942"}, curDevice) ? ((io_size >= 2) ? 8 : 16) : elems_per_thread<io_size>();
+  // Similar check in vectorized_elementwise_kernel() as well. Both should be in sync.
+  int tws = at::detail::getCUDAHooks().isGPUArch({"gfx942"}, curDevice) ? 16 : elems_per_thread<io_size>();
#else
  using cpp_type = typename function_traits<func_t>::result_type;
  const uint16_t max_vec_size = memory::can_vectorize_up_to<func_t>(data);
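The two tws (thread work size) computations must agree because the host launcher sizes the grid from the same per-thread element count that the kernel unrolls over; that is what the cross-referencing comments enforce. A rough, hypothetical mirror of the launch arithmetic:

#include <cstdio>

// Hypothetical mirror of the launch arithmetic: each block covers
// num_threads * tws elements, so the grid size follows from N and tws.
int grid_blocks(int N, int num_threads, int tws) {
  int block_work_size = num_threads * tws;
  return (N + block_work_size - 1) / block_work_size;  // ceil division
}

int main() {
  // If the launcher assumed tws = 16 but the kernel unrolled with tws = 8,
  // half of the tail elements would never be touched.
  printf("blocks for N=1<<20, 256 threads, tws=16: %d\n",
         grid_blocks(1 << 20, 256, 16));
  return 0;
}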
@@ -209,6 +209,10 @@ struct ReduceConfig {
  int values_per_thread() const {
    return div_up(num_inputs, step_input);
  }

+  int mock_values_per_thread(int parallelism) {
+    return div_up(num_inputs, step_input * parallelism);
+  }
};

std::ostream& operator<<(std::ostream& out, const ReduceConfig& config);
@@ -1058,7 +1062,7 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
  // In such case, values in each loaded vector always correspond to different outputs.
  if (fastest_moving_stride == sizeof(scalar_t)) {
#ifdef USE_ROCM
-    if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1) {
+    if (reduction_on_fastest_striding_dimension && dim0 >= 128 && iter.num_reduce_dims() == 1) {
#else
    if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= input_vec_size) {
#endif
@@ -1166,8 +1170,17 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
    else if (config.ctas_per_output < 16)
      config.ctas_per_output = 1;
    bool is_channel_last = iter.tensor_base(1).is_contiguous(at::MemoryFormat::ChannelsLast);
-    if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last)
+    if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last) {
      config.ctas_per_output = 4;
+      int vpt = config.values_per_thread();
+      // Capping the number of values per thread to 2048 for now
+      // based on known use cases.
+      while (vpt >= 2048) {
+        config.ctas_per_output *= 2;
+        // Computes the new values per thread without side effects
+        vpt = config.mock_values_per_thread(config.ctas_per_output);
+      }
+    }
#endif
    if (config.ctas_per_output > 1) {
      config.input_mult[2] = config.split_input(config.ctas_per_output);
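The capping loop doubles ctas_per_output until the per-thread workload falls below 2048, previewing each candidate with mock_values_per_thread so the config is not mutated mid-search. The same arithmetic as a standalone sketch with made-up sizes:

#include <cstdio>

int div_up(int a, int b) { return (a + b - 1) / b; }

int main() {
  // Hypothetical reduction: ~1M inputs per output, 64 inputs per step.
  int num_inputs = 1 << 20;
  int step_input = 64;
  int ctas_per_output = 4;  // starting point from the heuristic above

  int vpt = div_up(num_inputs, step_input);  // values_per_thread()
  while (vpt >= 2048) {
    ctas_per_output *= 2;
    // mock_values_per_thread(): preview without committing the split
    vpt = div_up(num_inputs, step_input * ctas_per_output);
  }
  // Prints ctas_per_output=16, values_per_thread=1024 for these sizes.
  printf("ctas_per_output=%d, values_per_thread=%d\n", ctas_per_output, vpt);
  return 0;
}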
@@ -1304,7 +1304,7 @@ at::Tensor _convert_weight_to_int4pack_cuda(
  constexpr int32_t kKTileSize = 16;

  // GPT-FAST assumes nTileSize of 8 for quantized weight tensor.
-  // See https://github.com/pytorch-labs/gpt-fast/blob/091515ab5b06f91c0d6a3b92f9c27463f738cc9b/quantize.py#L510
+  // See https://github.com/meta-pytorch/gpt-fast/blob/091515ab5b06f91c0d6a3b92f9c27463f738cc9b/quantize.py#L510
  // Torch dynamo also requires the torch ops has the same output shape for each device.
  // See https://github.com/pytorch/pytorch/blob/ec284d3a74ec1863685febd53687d491fd99a161/torch/_meta_registrations.py#L3263
  constexpr int32_t kNTileSizeTensor = 8;
@@ -762,7 +762,7 @@ Tensor miopen_convolution_forward(

  auto memory_format = at::MemoryFormat::Contiguous;
  if (miopen_conv_use_channels_last(*input, *weight)) {
-    memory_format = (weight->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+    memory_format = (weight->ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
  }

  Tensor output_t = at::detail::empty_cuda(
@@ -870,7 +870,7 @@ Tensor miopen_depthwise_convolution_forward(

  auto memory_format = at::MemoryFormat::Contiguous;
  if (miopen_conv_use_channels_last(*input, *weight)) {
-    memory_format = (weight->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+    memory_format = (weight->ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
  }

  Tensor output_t = at::detail::empty_cuda(
@@ -1070,7 +1070,7 @@ Tensor miopen_depthwise_convolution_backward_weight(

  auto memory_format = at::MemoryFormat::Contiguous;
  if (miopen_conv_use_channels_last(*input, *grad_output)) {
-    memory_format = (input->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+    memory_format = (input->ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
  }

  Tensor grad_output_contig_t = grad_output->contiguous(memory_format);
@@ -1123,7 +1123,7 @@ Tensor miopen_convolution_backward_weight(

  auto memory_format = at::MemoryFormat::Contiguous;
  if (miopen_conv_use_channels_last(*input, *grad_output)) {
-    memory_format = (input->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+    memory_format = (input->ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
  }

  Tensor grad_output_contig_t = grad_output->contiguous(memory_format);
@@ -1196,7 +1196,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> miopen_convolution_transpose_backwa
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, std::array<bool,3> output_mask) {

-  Tensor grad_output = grad_output_t.contiguous();
+  Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());

  Tensor grad_input, grad_weight, grad_bias;
  if (output_mask[0]) {
@@ -1276,7 +1276,7 @@ Tensor miopen_convolution_backward_input(

  auto memory_format = at::MemoryFormat::Contiguous;
  if (miopen_conv_use_channels_last(*grad_output, *weight)) {
-    memory_format = (weight->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+    memory_format = (weight->ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
  }

  Tensor grad_input_t = at::detail::empty_cuda(
@@ -1383,7 +1383,7 @@ Tensor miopen_depthwise_convolution_backward_input(

  auto memory_format = at::MemoryFormat::Contiguous;
  if (miopen_conv_use_channels_last(*grad_output, *weight)) {
-    memory_format = (weight->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
+    memory_format = (weight->ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
  }

  Tensor grad_input_t = at::detail::empty_cuda(
@@ -1446,7 +1446,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> miopen_depthwise_convolution_backwa
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, std::array<bool,3> output_mask) {

-  Tensor grad_output = grad_output_t.contiguous();
+  Tensor grad_output = grad_output_t.contiguous(input.suggest_memory_format());

  Tensor grad_input, grad_weight, grad_bias;
  if (output_mask[0]) {
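Each of these hunks enables the previously commented-out 3D channels-last path: 5-d weights (or inputs) now select ChannelsLast3d instead of silently falling back to Contiguous. A free-standing sketch of the selection rule (enum stubbed locally; the helper name is hypothetical):

enum class MemoryFormat { Contiguous, ChannelsLast, ChannelsLast3d };

// Mirrors the fixed selection logic: 4-d tensors use NHWC, while 5-d
// tensors now use NDHWC instead of silently staying contiguous.
MemoryFormat pick_conv_memory_format(int ndim, bool use_channels_last) {
  if (!use_channels_last) {
    return MemoryFormat::Contiguous;
  }
  return (ndim == 5) ? MemoryFormat::ChannelsLast3d : MemoryFormat::ChannelsLast;
}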
@@ -1,3 +1,4 @@
+#include <ATen/Context.h>
#include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils.h>
@@ -49,7 +50,7 @@ bool check_no_grad(sdp::sdp_params const& params, bool debug) {
  return !any_inputs_require_grad || !gradmode_enabled;
}

-bool use_overrideable_xpu(sdp::sdp_params const& params, bool debug) {
+bool can_use_overrideable_attention(sdp::sdp_params const& params, bool debug) {
  constexpr auto supported_dtypes = c10::array_of<at::ScalarType>(
      at::kFloat, at::kBFloat16, at::kHalf); // double is not supported

@@ -73,6 +74,42 @@ bool use_overrideable_xpu(sdp::sdp_params const& params, bool debug) {
  return sdp::check_tensor_dtype(params, supported_dtypes, debug);
}

+bool can_use_flash_attention(sdp::sdp_params const& params, bool debug) {
+  // Currently, XPU falls back from flash attention to the overrideable backend
+  return can_use_overrideable_attention(params, debug);
+}
+
+bool can_use_cudnn_attention(sdp::sdp_params const& params, bool debug) {
+  if (debug) {
+    TORCH_WARN("XPU don't support SDPA cudnn attention backend.");
+  }
+  return false;
+}
+
+bool can_use_mem_efficien_attention(sdp::sdp_params const& params, bool debug) {
+  if (debug) {
+    TORCH_WARN("XPU don't support SDPA mem efficient attention backend.");
+  }
+  return false;
+}
+
+bool priority_order_init = false;
+
+std::array<sdp::SDPBackend, sdp::num_backends> priority_order(
+    sdp::sdp_params const& params) {
+  if (!priority_order_init) {
+    priority_order_init = true;
+    const std::vector<int64_t> priority_order = {
+        static_cast<int64_t>(at::SDPBackend::overrideable),
+        static_cast<int64_t>(at::SDPBackend::math),
+        static_cast<int64_t>(at::SDPBackend::flash_attention),
+        static_cast<int64_t>(at::SDPBackend::efficient_attention),
+        static_cast<int64_t>(at::SDPBackend::cudnn_attention)};
+    at::globalContext().setSDPPriorityOrder(priority_order);
+  }
+  return at::globalContext().sDPPriorityOrder();
+}
+
sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
  // This function defines the priority order of the different sdp backends
  // 1. Flash Attention
@@ -85,20 +122,16 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
  }

  // Get ideal kernel ordering
-  const std::array<sdp::SDPBackend, 3> priority_order{
-      sdp::SDPBackend::overrideable,
-      sdp::SDPBackend::math,
-      sdp::SDPBackend::flash_attention,
-  };
+  const auto ordering = priority_order(kernel_params);

  // Because TORCHCHECK checks if condition is true we negate debug so that
  // The statements will be printed when debug is true
  bool print_debug = false;
-  for (auto& backend : priority_order) {
+  for (auto& backend : ordering) {
    switch (backend) {
      case sdp::SDPBackend::overrideable:
        if (ctx.userEnabledOverrideableSDP() &&
-            use_overrideable_xpu(kernel_params, print_debug)) {
+            can_use_overrideable_attention(kernel_params, print_debug)) {
          return sdp::SDPBackend::overrideable;
        }
        break;
@@ -109,25 +142,43 @@ sdp::SDPBackend select_sdp_backend_xpu(sdp::sdp_params const& kernel_params) {
        break;
      case sdp::SDPBackend::flash_attention:
        if (ctx.userEnabledFlashSDP() &&
-            use_overrideable_xpu(kernel_params, print_debug)) {
-          TORCH_WARN(
-              "Flash Attention is not supported on XPU, falling back to overrideable kernel.");
+            can_use_flash_attention(kernel_params, print_debug)) {
+          TORCH_WARN_ONCE(
+              "SDPA Flash Attention backend is not supported on XPU, falling back to OVERRIDEABLE backend.");
          return sdp::SDPBackend::overrideable;
        }
        break;
+      case sdp::SDPBackend::cudnn_attention:
+        if (ctx.userEnabledCuDNNSDP() &&
+            can_use_cudnn_attention(kernel_params, print_debug)) {
+          TORCH_CHECK(false, "Invalid backend");
+        }
+        break;
+      case sdp::SDPBackend::efficient_attention:
+        if (ctx.userEnabledMemEfficientSDP() &&
+            can_use_mem_efficien_attention(kernel_params, print_debug)) {
+          TORCH_CHECK(false, "Invalid backend");
+        }
+        break;
      default:
        TORCH_CHECK(false, "Invalid backend");
    }
  }
  // If we have gotten to this point then two things have happened:
-  // 1. use_overrideable_xpu did not satisfy the constraints to be ran
+  // 1. can_use_overrideable_attention did not satisfy the constraints to be ran
  // 2. The user has explicitly disabled the math kernel
  // We then re-run the kernel checks with debug enabled to print out the
  // reason why the kernel was not selected

  print_debug = true;
-  TORCH_WARN("OneDNN kernel not used because:");
-  use_overrideable_xpu(kernel_params, print_debug);
+  TORCH_WARN("Flash attention kernel not used because:");
+  can_use_flash_attention(kernel_params, print_debug);
+  TORCH_WARN("Overrideable attention kernel not used because:");
+  can_use_overrideable_attention(kernel_params, print_debug);
+  TORCH_WARN("CuDNN attention kernel not used because:");
+  can_use_cudnn_attention(kernel_params, print_debug);
+  TORCH_WARN("Memory Efficient attention kernel not used because:");
+  can_use_mem_efficien_attention(kernel_params, print_debug);
  TORCH_CHECK(!print_debug, "No available kernel. Aborting execution.")
  return sdp::SDPBackend::error;
}
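The new priority_order helper seeds the global context with a default ordering exactly once and thereafter reads it back, so priority orders set by the user through the context still win. A hedged sketch of that lazy-init pattern in isolation (all names hypothetical):

#include <vector>

enum class Backend { Overrideable, Math, Flash, Efficient, Cudnn };

// Hypothetical stand-in for the global context's stored ordering.
static std::vector<Backend> g_order;
static bool g_order_init = false;

const std::vector<Backend>& backend_order() {
  if (!g_order_init) {
    g_order_init = true;
    // Seed the default once; later external writes to g_order are respected.
    g_order = {Backend::Overrideable, Backend::Math, Backend::Flash,
               Backend::Efficient, Backend::Cudnn};
  }
  return g_order;
}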
aten/src/ATen/native/mps/kernels/GridSampler.h (new file, +25 lines)
@@ -0,0 +1,25 @@
#pragma once
#include <c10/metal/common.h>

#ifdef __METAL__
enum class GridSamplerInterpolation { Bilinear, Nearest, Bicubic };
enum class GridSamplerPadding { Zeros, Border, Reflection };
#else
#include <ATen/native/GridSamplerUtils.h>
using at::native::GridSamplerInterpolation;
using at::native::GridSamplerPadding;
#endif

template <unsigned N = 5, typename idx_type_t = int32_t>
struct GridSamplerParams {
  int32_t sampler_dims;
  ::c10::metal::array<idx_type_t, N> output_sizes;
  ::c10::metal::array<idx_type_t, N> output_strides;
  ::c10::metal::array<idx_type_t, N> input_sizes;
  ::c10::metal::array<idx_type_t, N> input_strides;
  ::c10::metal::array<idx_type_t, N> grid_sizes;
  ::c10::metal::array<idx_type_t, N> grid_strides;
  GridSamplerInterpolation interpolation_mode;
  GridSamplerPadding padding_mode;
  bool align_corners;
};

aten/src/ATen/native/mps/kernels/GridSampler.metal (new file, +329 lines)
@@ -0,0 +1,329 @@
#include <ATen/native/mps/kernels/GridSampler.h>
#include <c10/metal/utils.h>
#include <metal_array>
#include <metal_stdlib>

using namespace metal;
using namespace c10::metal;

struct GridSamplerOffsets {
  int32_t output;
  int32_t input;
  int32_t grid;

  GridSamplerOffsets() : output(0), input(0), grid(0) {}
};

// Find offsets into the tensors that this thread will operate on,
// based on the thread ID.
static GridSamplerOffsets find_grid_sampler_offsets(
    constant int32_t* output_sizes,
    constant int32_t* output_strides,
    constant int32_t* input_sizes,
    constant int32_t* input_strides,
    constant int32_t* grid_sizes,
    constant int32_t* grid_strides,
    int32_t sampler_dims,
    uint tid) {
  auto dims = sampler_dims + 2;
  auto output_idx = static_cast<int32_t>(tid);
  GridSamplerOffsets offsets;

  for (auto dim = dims - 1; dim >= 0; dim--) {
    auto dim_idx = output_idx % output_sizes[dim];
    output_idx = output_idx / output_sizes[dim];

    // Select the output element that this thread will calculate.
    // output shape:
    //   2 sampler dims: (N, C, Hout, Wout)
    //   3 sampler dims: (N, C, Dout, Hout, Wout)
    offsets.output += output_strides[dim] * dim_idx;

    // Select the batch and channel for the input.
    // input shape:
    //   2 sampler dims: (N, C, Hin, Win)
    //   3 sampler dims: (N, C, Din, Hin, Win)
    if (dim < 2) {
      offsets.input += input_strides[dim] * dim_idx;
    }

    // Select the grid coordinates for the output element.
    // grid shape:
    //   2 sampler dims: (N, Hout, Wout, 2)
    //   3 sampler dims: (N, Dout, Hout, Wout, 3)
    if (dim == 0) {
      offsets.grid += grid_strides[dim] * dim_idx;
    } else if (dim >= 2) {
      offsets.grid += grid_strides[dim - 1] * dim_idx;
    }
  }

  return offsets;
}

// Mod function which gives positive output when `a` is negative
static int32_t mod(int32_t a, int32_t b) {
  auto r = a % b;
  return r + (r < 0 ? b : 0);
}

// Sentinel index value to indicate zero padding
constant int32_t IDX_ZERO = -1;

// Apply padding to an index into the input
static int32_t pad_input_index(
    int32_t idx,
    int32_t input_size,
    GridSamplerPadding padding_mode,
    bool align_corners) {
  int32_t idx_padded = idx;

  if (padding_mode == GridSamplerPadding::Zeros) {
    idx_padded = (idx < 0) ? IDX_ZERO : idx_padded;
    idx_padded = (idx >= input_size) ? IDX_ZERO : idx_padded;

  } else if (padding_mode == GridSamplerPadding::Border) {
    idx_padded = (idx < 0) ? 0 : idx_padded;
    idx_padded = (idx >= input_size) ? input_size - 1 : idx_padded;

  } else if (padding_mode == GridSamplerPadding::Reflection) {
    auto scale_length = align_corners ? (input_size - 1) : input_size;
    auto idx_mod = mod(idx, scale_length);
    auto idx_mod_reverse = (input_size - 1) - idx_mod;
    bool is_reverse = (abs(idx - idx_mod) / scale_length) % 2 == 1;
    idx_padded = is_reverse ? idx_mod_reverse : idx_mod;
  }
  return idx_padded;
}

template <int32_t dims, typename T>
T get_tensor_val(
    constant T* input,
    constant int32_t* input_strides,
    int32_t indices[dims]) {
  bool found_idx_zero = false;
  int32_t offset = 0;

  for (auto dim = 0; dim < dims; dim++) {
    auto idx = indices[dim];
    found_idx_zero = found_idx_zero || (idx == IDX_ZERO);
    offset += (found_idx_zero ? 0 : idx) * input_strides[dim];
  }

  return found_idx_zero ? 0 : input[offset];
}

// This function performs 3D linear interpolation for one value. One way to
// think of how this works is to imagine a unit cube where each corner of the
// cube has one scalar value associated with it. Inside the cube, the values
// change linearly, so the gradient is constant. The values associated with each
// corner are given by the `input`, indexed at all eight different combinations
// of the `left_indices` and `right_indices`. Given a 3D coordinate anywhere
// within the cube, specified by the `scales` argument, we must calculate the
// value associated with that position.
template <typename T>
T interpolate_linear_3d(
    constant T* input,
    constant int32_t* input_strides,
    int32_t left_indices[3],
    int32_t right_indices[3],
    opmath_t<T> scales[3]) {
  int32_t a_idx[3] = {left_indices[0], left_indices[1], left_indices[2]};
  int32_t b_idx[3] = {left_indices[0], left_indices[1], right_indices[2]};
  int32_t c_idx[3] = {left_indices[0], right_indices[1], left_indices[2]};
  int32_t d_idx[3] = {left_indices[0], right_indices[1], right_indices[2]};
  int32_t e_idx[3] = {right_indices[0], left_indices[1], left_indices[2]};
  int32_t f_idx[3] = {right_indices[0], left_indices[1], right_indices[2]};
  int32_t g_idx[3] = {right_indices[0], right_indices[1], left_indices[2]};
  int32_t h_idx[3] = {right_indices[0], right_indices[1], right_indices[2]};
  auto a =
      static_cast<opmath_t<T>>(get_tensor_val<3>(input, input_strides, a_idx));
  auto b =
      static_cast<opmath_t<T>>(get_tensor_val<3>(input, input_strides, b_idx));
  auto c =
      static_cast<opmath_t<T>>(get_tensor_val<3>(input, input_strides, c_idx));
  auto d =
      static_cast<opmath_t<T>>(get_tensor_val<3>(input, input_strides, d_idx));
  auto e =
      static_cast<opmath_t<T>>(get_tensor_val<3>(input, input_strides, e_idx));
  auto f =
      static_cast<opmath_t<T>>(get_tensor_val<3>(input, input_strides, f_idx));
  auto g =
      static_cast<opmath_t<T>>(get_tensor_val<3>(input, input_strides, g_idx));
  auto h =
      static_cast<opmath_t<T>>(get_tensor_val<3>(input, input_strides, h_idx));

  auto scale0_right = scales[0];
  auto scale1_right = scales[1];
  auto scale2_right = scales[2];
  auto scale0_left = 1 - scale0_right;
  auto scale1_left = 1 - scale1_right;
  auto scale2_left = 1 - scale2_right;

  return static_cast<T>(
      scale0_left * scale1_left * scale2_left * a +
      scale0_left * scale1_left * scale2_right * b +
      scale0_left * scale1_right * scale2_left * c +
      scale0_left * scale1_right * scale2_right * d +
      scale0_right * scale1_left * scale2_left * e +
      scale0_right * scale1_left * scale2_right * f +
      scale0_right * scale1_right * scale2_left * g +
      scale0_right * scale1_right * scale2_right * h);
}

// Calculates a single output element.
// `input` shape:
//    2 sampler dims: (Hin, Win)
//    3 sampler dims: (Din, Hin, Win)
// `coords` values:
//    2 sampler dims: (Wcoord, Hcoord)
//    3 sampler dims: (Wcoord, Hcoord, Dcoord)
template <typename T>
void grid_sampler_single_element(
    device T* output,
    constant T* input,
    constant T* coords,
    int32_t dims,
    constant int32_t* input_sizes,
    constant int32_t* input_strides,
    GridSamplerInterpolation interpolation_mode,
    GridSamplerPadding padding_mode,
    bool align_corners) {
  int32_t left_indices[3];
  int32_t right_indices[3];
  opmath_t<T> scales[3];

  // For each dimension, find the pair of indices in the corresponding dimension
  // of `input` which surround the grid coordinate in that dimension. We'll do
  // this by mapping different coordinate spaces onto each other. There are
  // basically three different coordinate spaces to keep in mind:
  //
  //  * aligned grid space
  //    - `-1` refers to the leftmost input value.
  //    - `1` refers to the rightmost input value.
  //
  //  * unaligned grid space
  //    - `-1` refers to the midpoint between the leftmost input value and
  //      a padding value to the left of that.
  //    - `1` refers to the midpoint between the rightmost input value and
  //      a padding value to the right of that.
  //
  //  * input index space
  //    - `n` refers to the n-th value of the input.
  //    - `0` refers to the leftmost input value.
  //    - `N-1` refers to the rightmost input value.
  //
  // If `align_corners == False`, then the coordinates are in unaligned grid
  // space, and we will map it onto aligned grid space. If `align_corners ==
  // True`, then coordinates are already in aligned grid space.
  //
  // Then we will map unaligned grid space onto input index space, making it
  // relatively simple to find the two input indices that surround the
  // coordinate.
  for (auto coord_dim = 0; coord_dim < dims; coord_dim++) {
    auto input_dim = dims - coord_dim - 1;
    auto input_size = input_sizes[input_dim];
    auto coord = static_cast<opmath_t<T>>(coords[coord_dim]);

    // Interpret nan as -1
    coord = isnan(coord) ? -1 : coord;

    if (!align_corners) {
      // Map unaligned grid space to aligned grid space
      auto corner_alignment_factor = static_cast<opmath_t<T>>(input_size) /
          static_cast<opmath_t<T>>(input_size - 1);
      coord = coord * corner_alignment_factor;
    }

    // Map aligned grid space to input index space
    coord = (coord + 1) * (static_cast<opmath_t<T>>(input_size - 1) / 2);

    // Get the input indices surrounding the coordinate, apply padding to them,
    // and obtain the scaling factor between the two for interpolation.
    auto left_idx = static_cast<int32_t>(floor(coord));
    auto right_idx = static_cast<int32_t>(ceil(coord));
    left_indices[input_dim] =
        pad_input_index(left_idx, input_size, padding_mode, align_corners);
    right_indices[input_dim] =
        pad_input_index(right_idx, input_size, padding_mode, align_corners);

    auto scale = coord - left_idx;

    if (interpolation_mode == GridSamplerInterpolation::Nearest) {
      // TODO: For some reason, rounding the scale to 0 or 1 and then using
      // linear interpolation seems to work perfectly with zero padding mode,
      // but we get flaky failures with border and reflection padding modes.
      // Need to investigate and fix it.
      scale = (scale <= 0.5) ? 0 : 1;
    }
    scales[input_dim] = scale;
  }

  // Now that we have the bounding indices and scale factor for each dimension
  // of the input, we can interpolate.
  if (dims == 3) {
    *output = interpolate_linear_3d(
        input, input_strides, left_indices, right_indices, scales);
  }
}

template <typename T>
kernel void grid_sampler(
    device T* output [[buffer(0)]],
    constant T* input [[buffer(1)]],
    constant T* grid [[buffer(2)]],
    constant GridSamplerParams<5>& params [[buffer(3)]],
    uint tid [[thread_position_in_grid]]) {
  auto output_sizes = params.output_sizes.data();
  auto output_strides = params.output_strides.data();
  auto input_sizes = params.input_sizes.data();
  auto input_strides = params.input_strides.data();
  auto grid_sizes = params.grid_sizes.data();
  auto grid_strides = params.grid_strides.data();
  auto sampler_dims = params.sampler_dims;

  auto offsets = find_grid_sampler_offsets(
      output_sizes,
      output_strides,
      input_sizes,
      input_strides,
      grid_sizes,
      grid_strides,
      sampler_dims,
      tid);

  output += offsets.output;
  input += offsets.input;
  auto coords = grid + offsets.grid;

  input_sizes += 2;
  input_strides += 2;

  auto interpolation_mode = params.interpolation_mode;
  auto padding_mode = params.padding_mode;
  auto align_corners = params.align_corners;

  grid_sampler_single_element(
      output,
      input,
      coords,
      sampler_dims,
      input_sizes,
      input_strides,
      interpolation_mode,
      padding_mode,
      align_corners);
}

#define REGISTER_GRID_SAMPLER_OP(DTYPE)                     \
  template [[host_name("grid_sampler_" #DTYPE)]]            \
  kernel void grid_sampler<DTYPE>(                          \
      device DTYPE * output [[buffer(0)]],                  \
      constant DTYPE * input [[buffer(1)]],                 \
      constant DTYPE * grid [[buffer(2)]],                  \
      constant GridSamplerParams<5> & params [[buffer(3)]], \
      uint tid [[thread_position_in_grid]]);

REGISTER_GRID_SAMPLER_OP(float);
REGISTER_GRID_SAMPLER_OP(half);
REGISTER_GRID_SAMPLER_OP(bfloat);
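The unit-cube comment in interpolate_linear_3d and the coordinate-space remapping in grid_sampler_single_element reduce to two short formulas; a sketch in the kernel's own terms, where N is the input size along one dimension and s_d is the per-dimension scale:

% Coordinate remapping: unaligned -> aligned grid space, then -> input index space
c_{\text{aligned}} = c \cdot \frac{N}{N-1} \quad (\text{align\_corners} = \text{false}), \qquad
i = (c_{\text{aligned}} + 1) \cdot \frac{N-1}{2}

% Trilinear interpolation over the 8 cube corners v_{jkl}, with j,k,l \in \{0,1\}:
v = \sum_{j,k,l \in \{0,1\}} w_0(s_0,j)\, w_1(s_1,k)\, w_2(s_2,l)\, v_{jkl},
\qquad w_d(s,1) = s,\; w_d(s,0) = 1 - s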
@@ -1,7 +1,10 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/mps/MPSProfiler.h>
+#include <ATen/native/GridSamplerUtils.h>
#include <ATen/native/Pool.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>
+#include <ATen/native/mps/kernels/GridSampler.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -9,9 +12,17 @@
#else
#include <ATen/ops/grid_sampler_2d.h>
#include <ATen/ops/grid_sampler_2d_native.h>
+#include <ATen/ops/grid_sampler_3d_native.h>
#endif

namespace at::native {

+#ifndef PYTORCH_JIT_COMPILE_SHADERS
+static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
+#else
+#include <ATen/native/mps/GridSampler_metallib.h>
+#endif
+
namespace mps {
static void grid_sampler_2d_mps_impl(Tensor& output,
                                     const Tensor& input,
@@ -120,6 +131,96 @@ static void grid_sampler_2d_mps_impl(Tensor& output,
    runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
  }
}

+static void grid_sampler_template(Tensor& output,
+                                  const Tensor& input,
+                                  const Tensor& grid,
+                                  int64_t _interpolation_mode,
+                                  int64_t _padding_mode,
+                                  bool align_corners,
+                                  int32_t sampler_dims,
+                                  const std::string& op_name) {
+  check_grid_sampler_common(input, grid);
+  switch (sampler_dims) {
+    case 2:
+      check_grid_sampler_2d(input, grid);
+      break;
+    case 3:
+      check_grid_sampler_3d(input, grid, _interpolation_mode);
+      break;
+    default:
+      TORCH_INTERNAL_ASSERT(false, "Only 2D and 3D sampling are supported, but got: ", sampler_dims);
+  }
+  TORCH_CHECK(input.scalar_type() == grid.scalar_type(),
+              "expected input and grid to have the same type, but got ",
+              input.scalar_type(),
+              " and ",
+              grid.scalar_type());
+
+  auto interpolation_mode = static_cast<GridSamplerInterpolation>(_interpolation_mode);
+  auto padding_mode = static_cast<GridSamplerPadding>(_padding_mode);
+
+  switch (interpolation_mode) {
+    case GridSamplerInterpolation::Bilinear:
+      break;
+    case GridSamplerInterpolation::Nearest:
+      TORCH_CHECK(false, op_name, ": Unsupported Nearest interpolation");
+      break;
+    case GridSamplerInterpolation::Bicubic:
+      TORCH_CHECK(false, op_name, ": Unsupported Bicubic interpolation");
+      break;
+    default:
+      TORCH_CHECK(false, op_name, ": Unrecognised interpolation mode: ", _interpolation_mode);
+  }
+
+  switch (padding_mode) {
+    case GridSamplerPadding::Zeros:
+    case GridSamplerPadding::Border:
+    case GridSamplerPadding::Reflection:
+      break;
+    default:
+      TORCH_CHECK(false, op_name, ": Unrecognised Padding Mode: ", _padding_mode);
+  }
+
+  auto input_size = input.sizes();
+  auto grid_size = grid.sizes();
+  output.resize_({input_size[0], input_size[1], grid_size[1], grid_size[2], grid_size[3]}, MemoryFormat::Contiguous);
+
+  auto dims = input.dim();
+
+  GridSamplerParams<5> params;
+  params.sampler_dims = sampler_dims;
+  params.padding_mode = padding_mode;
+  params.interpolation_mode = interpolation_mode;
+  params.align_corners = align_corners;
+
+  for (const auto dim : c10::irange(dims)) {
+    params.output_sizes[dim] = safe_downcast<int32_t, int64_t>(output.size(dim));
+    params.output_strides[dim] = safe_downcast<int32_t, int64_t>(output.stride(dim));
+    params.input_sizes[dim] = safe_downcast<int32_t, int64_t>(input.size(dim));
+    params.input_strides[dim] = safe_downcast<int32_t, int64_t>(input.stride(dim));
+    params.grid_sizes[dim] = safe_downcast<int32_t, int64_t>(grid.size(dim));
+    params.grid_strides[dim] = safe_downcast<int32_t, int64_t>(grid.stride(dim));
+  }
+
+  auto num_threads = output.numel();
+  MPSStream* mpsStream = getCurrentMPSStream();
+
+  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
+    @autoreleasepool {
+      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
+      auto pso = lib.getPipelineStateForFunc("grid_sampler_" + scalarToMetalTypeString(input));
+
+      getMPSProfiler().beginProfileKernel(pso, op_name, {input, grid});
+      [computeEncoder setComputePipelineState:pso];
+      mtl_setArgs(computeEncoder, output, input, grid, params);
+
+      mtl_dispatch1DJob(computeEncoder, pso, num_threads);
+      getMPSProfiler().endProfileKernel(pso);
+    }
+  });
+}
+
} // namespace mps

Tensor grid_sampler_2d_mps(const Tensor& input,
@@ -135,4 +236,21 @@ Tensor grid_sampler_2d_mps(const Tensor& input,
  return output;
}

+Tensor grid_sampler_3d_mps(const Tensor& input,
+                           const Tensor& grid,
+                           int64_t interpolation_mode,
+                           int64_t padding_mode,
+                           bool align_corners) {
+  auto output = at::empty({0}, input.options(), MemoryFormat::Contiguous);
+  mps::grid_sampler_template(output,
+                             input,
+                             grid,
+                             interpolation_mode,
+                             padding_mode,
+                             align_corners,
+                             /*sampler_dims=*/3,
+                             /*op_name=*/"grid_sampler_3d");
+  return output;
+}

} // namespace at::native
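With the dispatch entry registered below, 3D grid sampling on MPS routes through the ordinary public API. A hedged usage sketch (shapes illustrative; requires an MPS-enabled build):

#include <ATen/ATen.h>

int main() {
  // N=1, C=2, D=4, H=4, W=4 input and a matching (N, Dout, Hout, Wout, 3)
  // grid with coordinates in [-1, 1].
  auto input = at::randn({1, 2, 4, 4, 4}, at::device(at::kMPS));
  auto grid = at::rand({1, 3, 3, 3, 3}, at::device(at::kMPS)) * 2 - 1;
  // interpolation_mode=0 (bilinear), padding_mode=0 (zeros)
  auto out = at::grid_sampler_3d(input, grid, /*interpolation_mode=*/0,
                                 /*padding_mode=*/0, /*align_corners=*/false);
  return 0;
}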
@@ -2931,6 +2931,7 @@
  dispatch:
    CPU: grid_sampler_3d_cpu
    CUDA: grid_sampler_3d_cuda
+    MPS: grid_sampler_3d_mps
  autogen: grid_sampler_3d.out

# `grid_sampler_3d_backward` takes in `output_mask` to optimize performance for
@@ -3447,8 +3448,12 @@

- func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor? bias) -> Tensor

+- func: fbgemm_linear_fp16_weight_fp32_activation.out(Tensor input, Tensor packed_weight, Tensor? bias, Tensor(a!) output) -> Tensor
+
- func: fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor

+- func: fbgemm_linear_fp16_weight.out(Tensor input, Tensor packed_weight, Tensor bias, Tensor(a!) output) -> Tensor
+
- func: fbgemm_pack_quantized_matrix(Tensor input) -> Tensor

- func: fbgemm_pack_quantized_matrix.KN(Tensor input, int K, int N) -> Tensor
@@ -7462,7 +7467,7 @@
- func: indices(Tensor(a) self) -> Tensor(a)
  variants: method
  dispatch:
-    SparseCPU, SparseCUDA, SparseMeta: indices_sparse
+    SparseCPU, SparseCUDA, SparseMPS, SparseMeta: indices_sparse
    CompositeExplicitAutograd: indices_default
  device_check: NoCheck
  device_guard: False
@@ -7470,7 +7475,7 @@
- func: values(Tensor(a) self) -> Tensor(a)
  variants: method
  dispatch:
-    SparseCPU, SparseCUDA, SparseMeta: values_sparse
+    SparseCPU, SparseCUDA, SparseMPS, SparseMeta: values_sparse
    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: values_sparse_csr
    NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: values_nested
    CompositeExplicitAutograd: values_default
@ -196,9 +196,17 @@ C10_LAUNCH_BOUNDS_1(num_threads())
 | 
			
		||||
__global__ void coalesceValuesKernel(
 | 
			
		||||
  int64_t *segment_offsets, int64_t *value_indices,
 | 
			
		||||
  Dtype *values, Dtype *newValues,
 | 
			
		||||
  int64_t nnz, int64_t newNnz, int64_t stride) {
 | 
			
		||||
  int64_t nnz, int64_t newNnz,
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
  int64_t nsegments,
 | 
			
		||||
#endif
 | 
			
		||||
  int64_t stride) {
 | 
			
		||||
 | 
			
		||||
  int seg = blockIdx.x * 4 + threadIdx.y;
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
  int64_t seg = (blockIdx.x * gridDim.y + blockIdx.y) * 4 + threadIdx.y;
 | 
			
		||||
#else
 | 
			
		||||
  int64_t seg = blockIdx.x * 4 + threadIdx.y;
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
  // Number of values processed by each thread (grain size)
 | 
			
		||||
  const int SZ = 4;
 | 
			
		||||
@ -207,7 +215,11 @@ __global__ void coalesceValuesKernel(
 | 
			
		||||
    const int newValueRow = seg * stride;
 | 
			
		||||
    const int begin = segment_offsets[seg];
 | 
			
		||||
    const int end = (seg < newNnz - 1) ? segment_offsets[seg + 1] : nnz;
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
    const int startFeature = threadIdx.x + blockIdx.z * nsegments * SZ;
 | 
			
		||||
#else
 | 
			
		||||
    const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
 | 
			
		||||
#endif
 | 
			
		||||
    Acctype tmp[SZ];
 | 
			
		||||
    #pragma unroll
 | 
			
		||||
    for (int ii = 0; ii < SZ; ii++) {
 | 
			
		||||
@ -250,9 +262,17 @@ C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE*4)
 | 
			
		||||
__global__ void coalesceValuesKernel(
 | 
			
		||||
  int64_t *segment_offsets, int64_t *value_indices,
 | 
			
		||||
  bool *values, bool *newValues,
 | 
			
		||||
  int64_t nnz, int64_t newNnz, int64_t stride) {
 | 
			
		||||
  int64_t nnz, int64_t newNnz,
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
  int64_t nsegments,
 | 
			
		||||
#endif
 | 
			
		||||
  int64_t stride) {
 | 
			
		||||
 | 
			
		||||
  int seg = blockIdx.x * 4 + threadIdx.y;
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
  int64_t seg = (blockIdx.x * gridDim.y + blockIdx.y) * 4 + threadIdx.y;
 | 
			
		||||
#else
 | 
			
		||||
  int64_t seg = blockIdx.x * 4 + threadIdx.y;
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
  // Number of values processed by each thread (grain size)
 | 
			
		||||
  const int SZ = 4;
 | 
			
		||||
@ -261,7 +281,11 @@ __global__ void coalesceValuesKernel(
 | 
			
		||||
    const int newValueRow = seg * stride;
 | 
			
		||||
    const int begin = segment_offsets[seg];
 | 
			
		||||
    const int end = (seg < newNnz - 1) ? segment_offsets[seg + 1] : nnz;
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
    const int startFeature = threadIdx.x + blockIdx.z * nsegments * SZ;
 | 
			
		||||
#else
 | 
			
		||||
    const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
 | 
			
		||||
#endif
 | 
			
		||||
    bool tmp[SZ];
 | 
			
		||||
    #pragma unroll
 | 
			
		||||
    for (int ii = 0; ii < SZ; ii++) {
 | 
			
		||||
 | 
			
		||||
@ -106,8 +106,34 @@ SparseTensor _coalesce_sparse_cuda(const SparseTensor& self) {
 | 
			
		||||
    values = values.contiguous();
 | 
			
		||||
    int64_t stride = c10::multiply_integers(values.sizes().slice(1));
 | 
			
		||||
    int warp_size = at::cuda::warp_size();
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
    const int64_t BATCHING_SEGMENT = 4096;
 | 
			
		||||
    int64_t nsegments = ceil_div(newNnz, (int64_t) SZ);
 | 
			
		||||
    int64_t s_batch = ceil_div(nsegments, BATCHING_SEGMENT);
 | 
			
		||||
    dim3 grid(s_batch, (s_batch == 1) ? nsegments : BATCHING_SEGMENT, ceil_div(stride, (int64_t) warp_size*SZ));
 | 
			
		||||
#else
 | 
			
		||||
    dim3 grid(ceil_div(newNnz, (int64_t) SZ), ceil_div(stride, (int64_t) warp_size*SZ));
 | 
			
		||||
#endif
 | 
			
		||||
    dim3 block(warp_size, SZ);
 | 
			
		||||
#ifdef USE_ROCM
 | 
			
		||||
    // Must duplicate the whole section otherwise does not compile on Windows
 | 
			
		||||
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
 | 
			
		||||
      at::ScalarType::ComplexHalf, at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool,
 | 
			
		||||
      values.scalar_type(), "coalesce_sparse_cuda", [&] {
 | 
			
		||||
        using cuda_accscalar_t = acc_type<scalar_t, /* is_cuda */ true>;
 | 
			
		||||
        apply::coalesceValuesKernel<scalar_t, cuda_accscalar_t><<<grid, block, 0, stream>>>(
 | 
			
		||||
          uniqueOffsets.data_ptr<int64_t>(),
 | 
			
		||||
          origIndices.data_ptr<int64_t>(),
 | 
			
		||||
          values.data_ptr<scalar_t>(),
 | 
			
		||||
          newValues.data_ptr<scalar_t>(),
 | 
			
		||||
          nnz,
 | 
			
		||||
          newNnz,
 | 
			
		||||
          nsegments,
 | 
			
		||||
          stride
 | 
			
		||||
        );
 | 
			
		||||
        C10_CUDA_KERNEL_LAUNCH_CHECK();
 | 
			
		||||
      });
 | 
			
		||||
#else
 | 
			
		||||
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
 | 
			
		||||
      at::ScalarType::ComplexHalf, at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool,
 | 
			
		||||
      values.scalar_type(), "coalesce_sparse_cuda", [&] {
 | 
			
		||||
@@ -123,6 +149,7 @@ SparseTensor _coalesce_sparse_cuda(const SparseTensor& self) {
         );
         C10_CUDA_KERNEL_LAUNCH_CHECK();
       });
+#endif
   }

 // this grid-strided version is slower but probably more flexible
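Taken together, the ROCm hunks above stop packing all `ceil_div(newNnz, SZ)` segments into a single grid dimension: segments are split into batches of at most `BATCHING_SEGMENT = 4096` across `blockIdx.x`/`blockIdx.y`, and the kernel rebuilds the flat segment id as `(blockIdx.x * gridDim.y + blockIdx.y) * 4 + threadIdx.y`. Below is a minimal host-only C++ sketch of that launch arithmetic; the wavefront width and the tensor sizes are illustrative assumptions, not values from the diff.

```cpp
#include <cstdint>
#include <cstdio>

// Same helper the launch code uses.
static int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t SZ = 4;                  // values per thread (grain size)
  const int64_t warp_size = 64;          // assumed ROCm wavefront width
  const int64_t BATCHING_SEGMENT = 4096; // max segments per grid.y batch

  const int64_t newNnz = 50'000'000;     // hypothetical coalesced nnz
  const int64_t stride = 256;            // values per sparse element

  // Mirrors the ROCm branch of _coalesce_sparse_cuda above.
  const int64_t nsegments = ceil_div(newNnz, SZ);
  const int64_t s_batch = ceil_div(nsegments, BATCHING_SEGMENT);
  const int64_t grid_x = s_batch;
  const int64_t grid_y = (s_batch == 1) ? nsegments : BATCHING_SEGMENT;
  const int64_t grid_z = ceil_div(stride, warp_size * SZ);

  // A block (bx, by) with thread row ty recovers its flat segment id the
  // way the kernel does; no single grid dimension ever has to hold all
  // `nsegments` blocks.
  const int64_t bx = grid_x - 1, by = 0, ty = 0;  // sample block/thread
  const int64_t seg = (bx * grid_y + by) * 4 + ty;

  std::printf("grid = (%lld, %lld, %lld), nsegments = %lld, sample seg = %lld\n",
              (long long)grid_x, (long long)grid_y, (long long)grid_z,
              (long long)nsegments, (long long)seg);
  return 0;
}
```

Since `s_batch * grid_y` can exceed `nsegments` when the counts do not divide evenly, some trailing blocks map past the last segment; the kernel's bounds check on `seg` (not shown in this excerpt) is what makes that safe.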
@@ -132,10 +132,10 @@ def main():
             )

         new_entry = copy.deepcopy(entry)
-        # only change if abs(ratio) > entry.noise_margin /3.
+        # only change if abs(ratio) > entry.noise_margin /5.
         new_entry.expected_value = (
             replace_with_zeros(result)
-            if abs(ratio) > entry.noise_margin * 100 / 3
+            if abs(ratio) > entry.noise_margin * 100 / 5
             else entry.expected_value
         )
         new_expected[key] = new_entry
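Note the direction of this tweak: dividing by 5 instead of 3 lowers the update threshold (`noise_margin * 100 / 5` is smaller than `noise_margin * 100 / 3`), so expected benchmark values are now refreshed on smaller deviations than before.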
@@ -74,15 +74,15 @@ aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10650000000,0



-mm_loop_inductor_gpu,compile_time_instruction_count,4461000000,0.1
+mm_loop_inductor_gpu,compile_time_instruction_count,4820968837,0.1



-mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8417000000,0.1
+mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,8802129167,0.1



-basic_NestedModule_eager,compile_time_instruction_count,9199000000,0.1
+basic_NestedModule_eager,compile_time_instruction_count,9554000000,0.1
@@ -147,7 +147,7 @@ static inline StreamIdType streamIdType(StreamId s) {
   // rightmost bit
   int mask_for_type = (1 << kStreamTypeBits) - 1;
   auto val = (s >> 1) & mask_for_type;
-  TORCH_INTERNAL_ASSERT(val || !(s & 1), "invalid StreamId", s);
+  TORCH_CHECK(val || !(s & 1), "invalid StreamId", s);
   return StreamIdType(val);
 }
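The decoding checked here packs a flag into the rightmost bit of the `StreamId` and the stream type into the next `kStreamTypeBits` bits; the check rejects ids whose flag bit is set while the type field is zero. A standalone sketch of that bit math (the 4-bit type width is an assumption mirroring the mask, not taken from this excerpt):

```cpp
#include <cstdint>
#include <cstdio>

using StreamId = int64_t;
constexpr int kStreamTypeBits = 4;  // assumed field width for illustration

// Mirrors the condition the new TORCH_CHECK enforces above.
static bool isValidStreamId(StreamId s) {
  int mask_for_type = (1 << kStreamTypeBits) - 1;
  auto val = (s >> 1) & mask_for_type;
  return val || !(s & 1);
}

int main() {
  std::printf("%d\n", isValidStreamId(0x0));  // type 0, flag clear -> 1 (valid)
  std::printf("%d\n", isValidStreamId(0x6));  // type 3, flag clear -> 1 (valid)
  std::printf("%d\n", isValidStreamId(0x1));  // type 0, flag set   -> 0 (rejected)
  return 0;
}
```

The switch from TORCH_INTERNAL_ASSERT to TORCH_CHECK (repeated in the CUDAStream::stream() hunks below) turns the failure from an "internal invariant violated, please report a bug" report into an ordinary c10::Error, presumably because a malformed id can come from a user-supplied external stream rather than from PyTorch itself.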
@@ -276,7 +276,7 @@ cudaStream_t CUDAStream::stream() const {
   StreamIdType st = streamIdType(stream_id);
   size_t si = streamIdIndex(stream_id);
   if (st.isDefault()) {
-    TORCH_INTERNAL_ASSERT(
+    TORCH_CHECK(
         si == 0,
         "Unrecognized stream ",
         stream_,
@@ -291,7 +291,7 @@ cudaStream_t CUDAStream::stream() const {
     return reinterpret_cast<cudaStream_t>(stream_id);
   } else {
     auto streamType = st.getStreamType();
-    TORCH_INTERNAL_ASSERT(
+    TORCH_CHECK(
         streamType >= 1 && streamType <= max_stream_priorities,
         "Unrecognized stream ",
         stream_,
BIN docs/source/_static/img/aoti_debugging_guide/cuda_ima_cca.png (new file; binary not shown. After: 220 KiB)
@@ -504,6 +504,10 @@ The following set of APIs transform your model into a pipeline representation.
 .. autoclass:: ScheduleZBVZeroBubble
 ```

+```{eval-rst}
+.. autoclass:: ScheduleDualPipeV
+```
+
 ```{eval-rst}
 .. autoclass:: PipelineScheduleSingle
    :members:
@@ -3,8 +3,8 @@
 NUMA Binding Utilities
 ======================

-.. automodule:: torch.distributed.numa
+.. automodule:: torch.numa
    :members:

-.. automodule:: torch.distributed.numa.binding
+.. automodule:: torch.numa.binding
    :members:
@@ -23,17 +23,11 @@ The APIs and performance characteristics of these features may change.
 :glob:
 :maxdepth: 2

 Install PyTorch <https://pytorch.org/get-started/locally/>
 user_guide/index
 pytorch-api
 notes
 ```

 ```{toctree}
 :glob:
 :hidden:
 :maxdepth: 2

 community/index
 C++ <https://docs.pytorch.org/cppdocs/>
 ```

 ## Indices and tables
@@ -56,6 +56,7 @@ via PyTorch's C++ operator registration APIs).
 .. autofunction:: infer_schema
 .. autoclass:: torch._library.custom_ops.CustomOpDef
    :members: set_kernel_enabled
+.. autofunction:: get_kernel
 ```

 ## Low-level APIs
@@ -1,9 +1,16 @@
 (pytorch_api)=
-# Python API
+# Reference API

+```{toctree}
+:maxdepth: 1
+
+C++ <https://docs.pytorch.org/cppdocs/>
+```
+
 ```{toctree}
 :glob:
 :maxdepth: 1
+:caption: Python API

 torch
 nn
Some files were not shown because too many files have changed in this diff.