Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-14 06:07:55 +08:00)

Compare commits: ciflow/roc...cpp-docs-d (15 commits)
| SHA1 |
|---|
| 2913cdf29d |
| 0661a232a5 |
| 5db844dafa |
| 73efad99d7 |
| df1268c311 |
| 84f9f1541d |
| 27c0c126bf |
| 670873155a |
| 923737c510 |
| 13d5b14a73 |
| a35a42b21c |
| 15956bc1e8 |
| b319ea1111 |
| ce4c68a5f6 |
| c6da4a59a3 |
@@ -7,13 +7,13 @@ ENV LC_ALL en_US.UTF-8
 ENV LANG en_US.UTF-8
 ENV LANGUAGE en_US.UTF-8

-ARG DEVTOOLSET_VERSION=13
+ARG DEVTOOLSET_VERSION=11

 RUN yum -y update
 RUN yum -y install epel-release
 # install glibc-langpack-en make sure en_US.UTF-8 locale is available
 RUN yum -y install glibc-langpack-en
-RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
+RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-toolchain
 # Just add everything as a safe.directory for git since these will be used in multiple places with git
 RUN git config --global --add safe.directory '*'
 ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH

@@ -41,7 +41,6 @@ RUN bash ./install_conda.sh && rm install_conda.sh
 # Install CUDA
 FROM base as cuda
 ARG CUDA_VERSION=12.6
-ARG DEVTOOLSET_VERSION=13
 RUN rm -rf /usr/local/cuda-*
 ADD ./common/install_cuda.sh install_cuda.sh
 COPY ./common/install_nccl.sh install_nccl.sh

@@ -51,8 +50,7 @@ ENV CUDA_HOME=/usr/local/cuda-${CUDA_VERSION}
 # Preserve CUDA_VERSION for the builds
 ENV CUDA_VERSION=${CUDA_VERSION}
 # Make things in our path by default
-ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
+ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:$PATH

 FROM cuda as cuda12.6
 RUN bash ./install_cuda.sh 12.6

@@ -70,22 +68,8 @@ FROM cuda as cuda13.0
 RUN bash ./install_cuda.sh 13.0
 ENV DESIRED_CUDA=13.0

-FROM ${ROCM_IMAGE} as rocm_base
-ARG DEVTOOLSET_VERSION=13
-ENV LC_ALL en_US.UTF-8
-ENV LANG en_US.UTF-8
-ENV LANGUAGE en_US.UTF-8
-# Install devtoolset on ROCm base image
-RUN yum -y update && \
-yum -y install epel-release && \
-yum -y install glibc-langpack-en && \
-yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb
-RUN git config --global --add safe.directory '*'
-ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
-
-FROM rocm_base as rocm
+FROM ${ROCM_IMAGE} as rocm
 ARG PYTORCH_ROCM_ARCH
-ARG DEVTOOLSET_VERSION=13
 ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}
 ADD ./common/install_mkl.sh install_mkl.sh
 RUN bash ./install_mkl.sh && rm install_mkl.sh

@@ -104,7 +88,6 @@ COPY --from=cuda13.0 /usr/local/cuda-13.0 /usr/local/cuda-13.0

 # Final step
 FROM ${BASE_TARGET} as final
-ARG DEVTOOLSET_VERSION=13
 COPY --from=openssl /opt/openssl /opt/openssl
 COPY --from=patchelf /patchelf /usr/local/bin/patchelf
 COPY --from=conda /opt/conda /opt/conda
@@ -36,7 +36,11 @@ case ${DOCKER_TAG_PREFIX} in
 ;;
 rocm*)
 BASE_TARGET=rocm
-PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
+PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
+# add gfx950, gfx115x conditionally starting in ROCm 7.0
+if [[ "$ROCM_VERSION" == *"7.0"* ]]; then
+PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
+fi
 EXTRA_BUILD_ARGS="${EXTRA_BUILD_ARGS} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}"
 ;;
 *)

@@ -59,7 +63,7 @@ docker build \
 --target final \
 --progress plain \
 --build-arg "BASE_TARGET=${BASE_TARGET}" \
---build-arg "DEVTOOLSET_VERSION=13" \
+--build-arg "DEVTOOLSET_VERSION=11" \
 ${EXTRA_BUILD_ARGS} \
 -t ${tmp_tag} \
 $@ \
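To make the shell logic above easier to scan, here is a small illustrative sketch, written in Python rather than the script's own bash, of the ROCm 7.0 gating that the new side of the first hunk introduces; the version value is hypothetical and the variable names are not from the repository.

```python
# Illustrative sketch (not repository code) of the gfx950/gfx115x gating above.
archs = "gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
rocm_version = "7.0.2"  # hypothetical value; the script tests $ROCM_VERSION
if "7.0" in rocm_version:  # mirrors [[ "$ROCM_VERSION" == *"7.0"* ]]
    archs += ";gfx950;gfx1150;gfx1151"
print(archs.split(";"))
```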
@ -168,18 +168,6 @@ case "$tag" in
|
|||||||
VISION=yes
|
VISION=yes
|
||||||
TRITON=yes
|
TRITON=yes
|
||||||
;;
|
;;
|
||||||
pytorch-linux-jammy-py3.11-clang12)
|
|
||||||
ANACONDA_PYTHON_VERSION=3.11
|
|
||||||
CLANG_VERSION=12
|
|
||||||
VISION=no
|
|
||||||
TRITON=no
|
|
||||||
;;
|
|
||||||
pytorch-linux-jammy-py3.12-clang12)
|
|
||||||
ANACONDA_PYTHON_VERSION=3.12
|
|
||||||
CLANG_VERSION=12
|
|
||||||
VISION=no
|
|
||||||
TRITON=no
|
|
||||||
;;
|
|
||||||
pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-jammy-rocm-n-py3-benchmarks | pytorch-linux-noble-rocm-n-py3)
|
pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-jammy-rocm-n-py3-benchmarks | pytorch-linux-noble-rocm-n-py3)
|
||||||
if [[ $tag =~ "jammy" ]]; then
|
if [[ $tag =~ "jammy" ]]; then
|
||||||
ANACONDA_PYTHON_VERSION=3.10
|
ANACONDA_PYTHON_VERSION=3.10
|
||||||
@ -207,9 +195,9 @@ case "$tag" in
|
|||||||
NINJA_VERSION=1.9.0
|
NINJA_VERSION=1.9.0
|
||||||
TRITON=yes
|
TRITON=yes
|
||||||
;;
|
;;
|
||||||
pytorch-linux-noble-xpu-n-py3 | pytorch-linux-noble-xpu-n-py3-inductor-benchmarks)
|
pytorch-linux-jammy-xpu-n-py3 | pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks)
|
||||||
ANACONDA_PYTHON_VERSION=3.10
|
ANACONDA_PYTHON_VERSION=3.10
|
||||||
GCC_VERSION=13
|
GCC_VERSION=11
|
||||||
VISION=yes
|
VISION=yes
|
||||||
XPU_VERSION=2025.2
|
XPU_VERSION=2025.2
|
||||||
NINJA_VERSION=1.9.0
|
NINJA_VERSION=1.9.0
|
||||||
@ -260,12 +248,6 @@ case "$tag" in
|
|||||||
HALIDE=yes
|
HALIDE=yes
|
||||||
TRITON=yes
|
TRITON=yes
|
||||||
;;
|
;;
|
||||||
pytorch-linux-jammy-cuda12.8-py3.12-pallas)
|
|
||||||
CUDA_VERSION=12.8.1
|
|
||||||
ANACONDA_PYTHON_VERSION=3.12
|
|
||||||
GCC_VERSION=11
|
|
||||||
PALLAS=yes
|
|
||||||
;;
|
|
||||||
pytorch-linux-jammy-py3.12-triton-cpu)
|
pytorch-linux-jammy-py3.12-triton-cpu)
|
||||||
CUDA_VERSION=12.6
|
CUDA_VERSION=12.6
|
||||||
ANACONDA_PYTHON_VERSION=3.12
|
ANACONDA_PYTHON_VERSION=3.12
|
||||||
@ -279,9 +261,9 @@ case "$tag" in
|
|||||||
PYTHON_VERSION=3.10
|
PYTHON_VERSION=3.10
|
||||||
CUDA_VERSION=12.8.1
|
CUDA_VERSION=12.8.1
|
||||||
;;
|
;;
|
||||||
pytorch-linux-jammy-aarch64-py3.10-gcc13)
|
pytorch-linux-jammy-aarch64-py3.10-gcc11)
|
||||||
ANACONDA_PYTHON_VERSION=3.10
|
ANACONDA_PYTHON_VERSION=3.10
|
||||||
GCC_VERSION=13
|
GCC_VERSION=11
|
||||||
ACL=yes
|
ACL=yes
|
||||||
VISION=yes
|
VISION=yes
|
||||||
OPENBLAS=yes
|
OPENBLAS=yes
|
||||||
@ -289,19 +271,9 @@ case "$tag" in
|
|||||||
# from pytorch/llvm:9.0.1 is x86 specific
|
# from pytorch/llvm:9.0.1 is x86 specific
|
||||||
SKIP_LLVM_SRC_BUILD_INSTALL=yes
|
SKIP_LLVM_SRC_BUILD_INSTALL=yes
|
||||||
;;
|
;;
|
||||||
pytorch-linux-jammy-aarch64-py3.10-clang21)
|
pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks)
|
||||||
ANACONDA_PYTHON_VERSION=3.10
|
ANACONDA_PYTHON_VERSION=3.10
|
||||||
CLANG_VERSION=21
|
GCC_VERSION=11
|
||||||
ACL=yes
|
|
||||||
VISION=yes
|
|
||||||
OPENBLAS=yes
|
|
||||||
# snadampal: skipping llvm src build install because the current version
|
|
||||||
# from pytorch/llvm:9.0.1 is x86 specific
|
|
||||||
SKIP_LLVM_SRC_BUILD_INSTALL=yes
|
|
||||||
;;
|
|
||||||
pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks)
|
|
||||||
ANACONDA_PYTHON_VERSION=3.10
|
|
||||||
GCC_VERSION=13
|
|
||||||
ACL=yes
|
ACL=yes
|
||||||
VISION=yes
|
VISION=yes
|
||||||
OPENBLAS=yes
|
OPENBLAS=yes
|
||||||
@ -387,7 +359,6 @@ docker build \
|
|||||||
--build-arg "INDUCTOR_BENCHMARKS=${INDUCTOR_BENCHMARKS}" \
|
--build-arg "INDUCTOR_BENCHMARKS=${INDUCTOR_BENCHMARKS}" \
|
||||||
--build-arg "EXECUTORCH=${EXECUTORCH}" \
|
--build-arg "EXECUTORCH=${EXECUTORCH}" \
|
||||||
--build-arg "HALIDE=${HALIDE}" \
|
--build-arg "HALIDE=${HALIDE}" \
|
||||||
--build-arg "PALLAS=${PALLAS}" \
|
|
||||||
--build-arg "XPU_VERSION=${XPU_VERSION}" \
|
--build-arg "XPU_VERSION=${XPU_VERSION}" \
|
||||||
--build-arg "UNINSTALL_DILL=${UNINSTALL_DILL}" \
|
--build-arg "UNINSTALL_DILL=${UNINSTALL_DILL}" \
|
||||||
--build-arg "ACL=${ACL:-}" \
|
--build-arg "ACL=${ACL:-}" \
|
||||||
|
|||||||
@@ -1 +0,0 @@
-0.8.0

@@ -1 +1 @@
-0add68262ab0a2e33b84524346cb27cbb2787356
+7416ffcb92cdbe98d9f97e4e6f95247e46dfc9fd
@@ -8,8 +8,8 @@ if [ -n "$CLANG_VERSION" ]; then
 # work around ubuntu apt-get conflicts
 sudo apt-get -y -f install
 wget --no-check-certificate -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
-if [[ $CLANG_VERSION -ge 18 ]]; then
-apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VERSION} main"
+if [[ $CLANG_VERSION == 18 ]]; then
+apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-18 main"
 fi
 fi

@@ -7,11 +7,11 @@ if [ -n "$GCC_VERSION" ]; then
 # Need the official toolchain repo to get alternate packages
 add-apt-repository ppa:ubuntu-toolchain-r/test
 apt-get update
-apt-get install -y g++-$GCC_VERSION gfortran-$GCC_VERSION
+apt-get install -y g++-$GCC_VERSION
 update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-"$GCC_VERSION" 50
 update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-"$GCC_VERSION" 50
 update-alternatives --install /usr/bin/gcov gcov /usr/bin/gcov-"$GCC_VERSION" 50
-update-alternatives --install /usr/bin/gfortran gfortran /usr/bin/gfortran-"$GCC_VERSION" 50

 # Cleanup package manager
 apt-get autoclean && apt-get clean
@@ -1,40 +0,0 @@
-#!/bin/bash
-
-set -ex
-
-source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh"
-
-# Get the pinned JAX version (same for all CUDA versions)
-JAX_VERSION=$(get_pinned_commit /ci_commit_pins/jax)
-
-function install_jax_12() {
-echo "Installing JAX ${JAX_VERSION} with CUDA 12 support"
-pip_install "jax[cuda12]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-
-# Verify installation
-python -c "import jax" # check for errors
-echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 12"
-}
-
-function install_jax_13() {
-echo "Installing JAX ${JAX_VERSION} with CUDA 13 support"
-pip_install "jax[cuda13]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-
-# Verify installation
-python -c "import jax" # check for errors
-echo "JAX ${JAX_VERSION} installation completed successfully for CUDA 13"
-}
-
-# idiomatic parameter and option handling in sh
-while test $# -gt 0
-do
-case "$1" in
-12.4|12.6|12.6.*|12.8|12.8.*|12.9|12.9.*) install_jax_12;
-;;
-13.0|13.0.*) install_jax_13;
-;;
-*) echo "bad argument $1"; exit 1
-;;
-esac
-shift
-done
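For reference, a rough Python rendering (not repository code) of the CUDA-version dispatch the removed script performed; it assumes a purely major-version mapping, whereas the shell case above pinned an explicit list of 12.x releases and 13.0.

```python
# Hypothetical sketch of the removed script's version dispatch.
def jax_cuda_extra(cuda_version: str) -> str:
    major = cuda_version.split(".")[0]
    if major == "12":
        return "cuda12"  # the script called install_jax_12 for 12.4/12.6/12.8/12.9
    if major == "13":
        return "cuda13"  # the script called install_jax_13 for 13.0
    raise SystemExit(f"bad argument {cuda_version}")

for version in ("12.6", "12.8.1", "13.0"):
    print(version, "->", f"jax[{jax_cuda_extra(version)}]")
```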
@@ -1,56 +0,0 @@
-#!/bin/bash
-# Script used only in CD pipeline
-
-set -ex
-
-# install dependencies
-dnf -y install gmp-devel libmpc-devel texinfo flex bison
-
-cd /usr/local/src
-# fetch source for gcc 13
-git clone --depth 1 --single-branch -b releases/gcc-13.3.0 https://github.com/gcc-mirror/gcc.git gcc-13.3.0
-
-mkdir -p gcc-13.3.0/build-gomp
-cd gcc-13.3.0/build-gomp
-
-# configure gcc build
-# I got these flags by:
-# 1. downloading the source rpm for gcc-11 on AlmaLinux 8 container
-# dnf install -y dnf-plugins-core rpmdevtools
-# dnf download --source libgomp
-# 2. extracting the gcc.spec from the source.
-# rpmdev-extract gcc-xx.src.rpm
-# 3. extracting optflags and ld_flags from gcc.spec:
-# rpm --eval '%{optflags}'
-# rpm --eval '%{build_ldflags}'
-#
-# I had to remove the following flags because they didn't compile for this version of libgomp:
-# -Werror=format-security
-# -specs=/usr/lib/rpm/redhat/redhat-hardened-cc1
-# -specs=/usr/lib/rpm/redhat/redhat-annobin-cc1
-#
-# I added -march=armv8-a -mtune=generic to make them explicit. I don't think they're strictly needed.
-
-OPT_FLAGS='-O2 -march=armv8-a -mtune=generic'\
-' -fexceptions -g -grecord-gcc-switches -pipe -Wall'\
-' -Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS'\
-' -fstack-protector-strong -fasynchronous-unwind-tables'\
-' -fstack-clash-protection'
-
-LDFLAGS='-Wl,-z,relro -Wl,--as-needed -Wl,-z,now'
-
-CFLAGS="$OPT_FLAGS" \
-CXXFLAGS="$OPT_FLAGS" \
-LDFLAGS="$LDFLAGS" \
-../configure \
---prefix=/usr \
---libdir=/usr/lib64 \
---enable-languages=c,c++ \
---disable-multilib \
---disable-bootstrap \
---enable-libgomp
-
-# only build libgomp
-make -j$(nproc) all-target-libgomp
-
-make install-target-libgomp
@@ -10,7 +10,6 @@ git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION}" -

 OPENBLAS_CHECKOUT_DIR="OpenBLAS"
 OPENBLAS_BUILD_FLAGS="
-CC=gcc
 NUM_THREADS=128
 USE_OPENMP=1
 NO_SHARED=0
@@ -9,7 +9,7 @@ set -xe

 function install_ubuntu() {
 . /etc/os-release
-if [[ ! " jammy noble " =~ " ${VERSION_CODENAME} " ]]; then
+if [[ ! " jammy " =~ " ${VERSION_CODENAME} " ]]; then
 echo "Ubuntu version ${VERSION_CODENAME} not supported"
 exit
 fi

@@ -35,24 +35,25 @@ function install_ubuntu() {
 # The xpu-smi packages
 apt-get install -y flex bison xpu-smi

-# Compute and Media Runtimes
+if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then
-if [[ " ${VERSION_CODENAME} " =~ " noble " ]]; then
+# Compute and Media Runtimes
 apt-get install -y \
-intel-opencl-icd libze-intel-gpu1 libze1 \
+intel-opencl-icd intel-level-zero-gpu level-zero \
-intel-media-va-driver-non-free libmfx-gen1 libvpl2 \
+intel-media-va-driver-non-free libmfx1 libmfxgen1 libvpl2 \
-libegl-mesa0 libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
+libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
 libglapi-mesa libgles2-mesa-dev libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \
-mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo intel-ocloc
+mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo
-else # jammy
+# Development Packages
+apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev
+else # rolling driver
 apt-get install -y \
 intel-opencl-icd libze-intel-gpu1 libze1 \
 intel-media-va-driver-non-free libmfx-gen1 libvpl2 \
 libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \
 libglapi-mesa libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \
 mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo intel-ocloc
+apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev
 fi
-# Development Packages
-apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev libze-dev

 # Install Intel Support Packages
 apt-get install -y ${XPU_PACKAGES}

@@ -65,7 +66,7 @@ function install_ubuntu() {
 function install_rhel() {
 . /etc/os-release
 if [[ "${ID}" == "rhel" ]]; then
-if [[ ! " 8.8 8.10 9.0 9.2 9.3 " =~ " ${VERSION_ID} " ]]; then
+if [[ ! " 8.8 8.9 9.0 9.2 9.3 " =~ " ${VERSION_ID} " ]]; then
 echo "RHEL version ${VERSION_ID} not supported"
 exit
 fi

@@ -146,7 +147,7 @@ function install_sles() {
 XPU_DRIVER_VERSION=""
 if [[ "${XPU_DRIVER_TYPE,,}" == "lts" ]]; then
 # Use GPU driver LTS releases
-XPU_DRIVER_VERSION="/lts/2523"
+XPU_DRIVER_VERSION="/lts/2350"
 fi

 # Default use Intel® oneAPI Deep Learning Essentials 2025.1
@@ -49,7 +49,11 @@ case ${DOCKER_TAG_PREFIX} in
 fi
 BASE_TARGET=rocm
 GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete
-PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
+PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
+# add gfx950, gfx115x conditionally starting in ROCm 7.0
+if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
+PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
+fi
 DOCKER_GPU_BUILD_ARG="--build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg ROCM_VERSION=${GPU_ARCH_VERSION}"
 ;;
 *)
@@ -50,10 +50,6 @@ RUN rm install_ninja.sh
 ENV PATH=/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/bin:$PATH
 ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/lib64:/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/lib:$LD_LIBRARY_PATH

-# Build a newer version of libgomp than that supported in in Almalinux 8.
-COPY ./common/install_libgomp.sh install_libgomp.sh
-RUN bash ./install_libgomp.sh && rm install_libgomp.sh
-
 # git236+ would refuse to run git commands in repos owned by other users
 # Which causes version check to fail, as pytorch repo is bind-mounted into the image
 # Override this behaviour by treating every folder as safe
@@ -87,7 +87,11 @@ case ${image} in
 MANY_LINUX_VERSION="2_28"
 DEVTOOLSET_VERSION="11"
 GPU_IMAGE=rocm/dev-almalinux-8:${GPU_ARCH_VERSION}-complete
-PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151"
+PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
+# add gfx950, gfx115x conditionally starting in ROCm 7.0
+if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then
+PYTORCH_ROCM_ARCH="${PYTORCH_ROCM_ARCH};gfx950;gfx1150;gfx1151"
+fi
 DOCKER_GPU_BUILD_ARG="--build-arg ROCM_VERSION=${GPU_ARCH_VERSION} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg DEVTOOLSET_VERSION=${DEVTOOLSET_VERSION}"
 ;;
 manylinux2_28-builder:xpu)
@@ -1 +1 @@
-3.5.1
+3.5.0
@@ -143,15 +143,6 @@ COPY ci_commit_pins/halide.txt halide.txt
 RUN if [ -n "${HALIDE}" ]; then bash ./install_halide.sh; fi
 RUN rm install_halide.sh common_utils.sh halide.txt

-ARG PALLAS
-ARG CUDA_VERSION
-# Install JAX with CUDA support (for Pallas)
-COPY ./common/install_jax.sh install_jax.sh
-COPY ./common/common_utils.sh common_utils.sh
-COPY ./ci_commit_pins/jax.txt /ci_commit_pins/jax.txt
-RUN if [ -n "${PALLAS}" ]; then bash ./install_jax.sh ${CUDA_VERSION}; fi
-RUN rm -f install_jax.sh common_utils.sh /ci_commit_pins/jax.txt
-
 ARG ONNX
 # Install ONNX dependencies
 COPY ./common/install_onnx.sh ./common/common_utils.sh ./
@@ -8,11 +8,9 @@ from abc import ABC, abstractmethod


 try:
-    from collections.abc import Callable  # Python 3.11+
-    from typing import Any, Required, TypedDict
+    from typing import Any, Callable, Required, TypedDict  # Python 3.11+
 except ImportError:
-    from collections.abc import Callable
-    from typing import Any, TypedDict
+    from typing import Any, Callable, TypedDict

     from typing_extensions import Required  # Fallback for Python <3.11

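For context on the import hunk above: `Required` is only available from `typing` on Python 3.11+, which is why both sides keep a `try`/`except ImportError` fallback to `typing_extensions`. A minimal runnable sketch of the pattern as it reads on the new (+) side:

```python
# Guarded import: stdlib location on Python 3.11+, typing_extensions backport otherwise.
try:
    from typing import Any, Callable, Required, TypedDict  # Python 3.11+
except ImportError:
    from typing import Any, Callable, TypedDict

    from typing_extensions import Required  # Fallback for Python <3.11
```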
@@ -6,8 +6,8 @@ set -eou pipefail
 # The script expects DESIRED_CUDA and PACKAGE_NAME to be set
 ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"

-# https://github.com/icl-utk-edu/magma/pull/65
-MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec
+# post merge of https://github.com/icl-utk-edu/magma/pull/65
+MAGMA_VERSION=c0792ae825fb36872784892ea643dd6f3456bc5f

 # Folders for the build
 PACKAGE_FILES=${ROOT_DIR}/magma-rocm/package_files # metadata

@@ -20,7 +20,7 @@ mkdir -p ${PACKAGE_DIR} ${PACKAGE_OUTPUT}/linux-64 ${PACKAGE_BUILD} ${PACKAGE_RE

 # Fetch magma sources and verify checksum
 pushd ${PACKAGE_DIR}
-git clone https://github.com/jeffdaily/magma
+git clone https://github.com/icl-utk-edu/magma
 pushd magma
 git checkout ${MAGMA_VERSION}
 popd
@@ -168,16 +168,14 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then
 # shellcheck disable=SC1091
 source /opt/intel/oneapi/compiler/latest/env/vars.sh
 # shellcheck disable=SC1091
-source /opt/intel/oneapi/umf/latest/env/vars.sh
-# shellcheck disable=SC1091
 source /opt/intel/oneapi/ccl/latest/env/vars.sh
 # shellcheck disable=SC1091
 source /opt/intel/oneapi/mpi/latest/env/vars.sh
-# shellcheck disable=SC1091
-source /opt/intel/oneapi/pti/latest/env/vars.sh
 # Enable XCCL build
 export USE_XCCL=1
 export USE_MPI=0
+# XPU kineto feature dependencies are not fully ready, disable kineto build as temp WA
+export USE_KINETO=0
 export TORCH_XPU_ARCH_LIST=pvc
 fi

@@ -208,8 +208,6 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then
 source /opt/intel/oneapi/ccl/latest/env/vars.sh
 # shellcheck disable=SC1091
 source /opt/intel/oneapi/mpi/latest/env/vars.sh
-# shellcheck disable=SC1091
-source /opt/intel/oneapi/pti/latest/env/vars.sh
 # Check XPU status before testing
 timeout 30 xpu-smi discovery || true
 fi
@@ -826,11 +824,6 @@ test_inductor_halide() {
 assert_git_not_dirty
 }

-test_inductor_pallas() {
-python test/run_test.py --include inductor/test_pallas.py --verbose
-assert_git_not_dirty
-}
-
 test_inductor_triton_cpu() {
 python test/run_test.py --include inductor/test_triton_cpu_backend.py inductor/test_torchinductor_strided_blocks.py --verbose
 assert_git_not_dirty

@@ -1731,8 +1724,6 @@ elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
 test_inductor_distributed
 elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then
 test_inductor_halide
-elif [[ "${TEST_CONFIG}" == *inductor-pallas* ]]; then
-test_inductor_pallas
 elif [[ "${TEST_CONFIG}" == *inductor-triton-cpu* ]]; then
 test_inductor_triton_cpu
 elif [[ "${TEST_CONFIG}" == *inductor-micro-benchmark* ]]; then
@@ -70,7 +70,7 @@ sccache --zero-stats
 sccache --show-stats

 # Build the wheel
-python -m build --wheel --no-isolation
+python -m build --wheel --no-build-isolation
 if ($LASTEXITCODE -ne 0) { exit 1 }

 # Install the wheel locally
@@ -1,11 +1,11 @@
-name: 🚀 New Feature for Release
+name: 🚀 Release highlight for proposed Feature
 description: Submit a Release highlight for proposed Feature
 labels: ["release-feature-request"]

 body:
 - type: textarea
 attributes:
-label: New Feature for Release
+label: Release highlight for proposed Feature
 description: >
 Example: “A torch.special module, analogous to SciPy's special module.”
 - type: input
.github/actions/pytest-cache-download/action.yml (12 lines changed)

@@ -38,9 +38,9 @@ runs:
 run: |
 python3 .github/scripts/pytest_cache.py \
 --download \
---cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \
---pr_identifier "$GITHUB_REF" \
---job_identifier "$JOB_IDENTIFIER" \
---temp_dir "$RUNNER_TEMP" \
---repo "$REPO" \
---bucket "$BUCKET" \
+--cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \
+--pr_identifier $GITHUB_REF \
+--job_identifier $JOB_IDENTIFIER \
+--temp_dir $RUNNER_TEMP \
+--repo $REPO \
+--bucket $BUCKET \
.github/actions/pytest-cache-upload/action.yml (16 lines changed)

@@ -47,11 +47,11 @@ runs:
 run: |
 python3 .github/scripts/pytest_cache.py \
 --upload \
---cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \
---pr_identifier "$GITHUB_REF" \
---job_identifier "$JOB_IDENTIFIER" \
---sha "$SHA" \
---test_config "$TEST_CONFIG" \
---shard "$SHARD" \
---repo "$REPO" \
---temp_dir "$RUNNER_TEMP" \
+--cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \
+--pr_identifier $GITHUB_REF \
+--job_identifier $JOB_IDENTIFIER \
+--sha $SHA \
+--test_config $TEST_CONFIG \
+--shard $SHARD \
+--repo $REPO \
+--temp_dir $RUNNER_TEMP \
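The only difference between the two sides of the hunks above is whether the shell variables passed to `pytest_cache.py` are quoted. A small illustration, using Python's `shlex` and a hypothetical value, of what quoting changes when a value contains spaces:

```python
# Hypothetical job identifier with spaces; not taken from the workflow.
import shlex

job = "linux-jammy / test (default, 1, 5)"
print(shlex.split(f"--job_identifier {job}"))    # unquoted: splits into several argv entries
print(shlex.split(f'--job_identifier "{job}"'))  # quoted: stays one argv entry after the flag
```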
.github/ci_commit_pins/audio.txt (2 lines changed)

@@ -1 +1 @@
-ad5816f0eee1c873df1b7d371c69f1f811a89387
+3b0e7a6f192ca2715e7e6cbe5db007aea7165fe2

.github/ci_commit_pins/vision.txt (2 lines changed)

@@ -1 +1 @@
-ccb801b88af136454798b945175c4c87e636ac33
+cfbc5c2f1c798991715a6b06bb3ce46478c4487c

.github/ci_commit_pins/xla.txt (2 lines changed)

@@ -1 +1 @@
-e4d25697f9dc5eedaf8f0a5bf085c62c5455a53a
+c8b09f5f77d6bf6fb7ed7a9aa83e5d8156b3a5e9
.github/copilot-instructions.md (125 lines changed)

@@ -1,125 +0,0 @@
-# PyTorch Copilot Instructions
-
-This is the PyTorch machine learning framework codebase. These instructions help AI agents navigate and contribute effectively.
-
-## Architecture Overview
-
-### Core Components
-
-- **c10/** - Core library (C++-10 compatible) for essential, binary-size-conscious functionality
-- **aten/** - ATen tensor library (C++), PyTorch's foundation without autograd
-  - `aten/src/ATen/native/` - Modern operator implementations (CPU/CUDA/MPS/sparse)
-  - `aten/src/ATen/native/native_functions.yaml` - **Critical**: Declarative operator registry
-- **torch/** - Python bindings and public API
-  - `torch/csrc/` - C++ Python bindings (hand-written and generated)
-  - `torch/csrc/autograd/` - Reverse-mode automatic differentiation
-  - `torch/csrc/jit/` - TorchScript JIT compiler
-- **torchgen/** - Code generation tooling that reads `native_functions.yaml`
-- **tools/** - Build scripts, autograd derivatives, code generation
-
-### The Code Generation Workflow
-
-**Most operator changes require editing `native_functions.yaml`**, not direct C++ files. This YAML file:
-1. Declares operator signatures, variants (function/method), and dispatch behavior
-2. Gets processed by `torchgen/` to generate C++/Python bindings
-3. Produces headers in `build/aten/src/ATen/` during compilation
-
-Example entry structure:
-```yaml
-- func: my_op(Tensor self, Scalar alpha=1) -> Tensor
-  variants: function, method
-  dispatch:
-    CPU: my_op_cpu
-    CUDA: my_op_cuda
-```
-
-After editing `native_functions.yaml`, implement kernels in `aten/src/ATen/native/` (see `aten/src/ATen/native/README.md`).
-
-## Development Workflows
-
-### Building from Source
-
-**Never run `setup.py` directly** - use pip with editable install:
-```bash
-python -m pip install --no-build-isolation -v -e .
-```
-
-Speed up builds:
-- `DEBUG=1` - Debug symbols with `-g -O0`
-- `USE_CUDA=0` - Skip CUDA compilation
-- `BUILD_TEST=0` - Skip C++ test binaries
-- Install `ninja` (`pip install ninja`) for faster builds
-- Use `ccache` for incremental compilation caching
-
-Rebuild specific targets: `(cd build && ninja <target>)`
-
-### Testing
-
-**Critical**: DO NOT run entire test suites. Run specific tests only:
-```bash
-python test/test_torch.py TestTorch.test_specific_case
-```
-
-**Test structure**: All tests use `torch.testing._internal.common_utils`:
-```python
-from torch.testing._internal.common_utils import run_tests, TestCase
-
-class TestFeature(TestCase):
-    def test_something(self):
-        # Use self.assertEqual for tensor comparisons
-        pass
-
-if __name__ == "__main__":
-    run_tests()
-```
-
-**For bug fixes**: Create a standalone reproduction script first, verify it fails, then fix and add to appropriate test file.
-
-### Linting
-
-Run linter (not pre-commit): `lintrunner -a` (auto-applies fixes)
-
-## Project-Specific Conventions
-
-### Memory and Storage
-- **Storage is never nullptr** (but `StorageImpl.data` may be nullptr for unallocated outputs)
-- CUDA device info lives in storage objects
-
-### Python-C++ Integration (`torch/csrc/`)
-- Always include `Python.h` **first** to avoid `_XOPEN_SOURCE` redefinition errors
-- Use `pybind11::gil_scoped_acquire` before calling Python API or using `THPObjectPtr`
-- Wrap entry points with `HANDLE_TH_ERRORS` / `END_HANDLE_TH_ERRORS` for exception conversion
-
-### Dispatch System
-- PyTorch uses operator dispatch to route calls to backend-specific kernels
-- Prefer `CompositeExplicitAutograd` dispatch when writing device-agnostic compound ops
-- See `aten/src/ATen/native/README.md` for dispatch keyword guidance
-
-## Git Workflow (AI Agent Specific)
-
-When preparing PRs from this environment:
-```bash
-git stash -u
-git reset --hard $(cat /tmp/orig_work.txt)  # Reset to LOCAL branch
-git stash pop
-# Resolve conflicts if necessary
-```
-
-## Common Gotchas
-
-1. **Editing generated files** - If it's in `build/`, don't edit it. Edit the source template or `native_functions.yaml`
-2. **NVCC template compilation** - NVCC is stricter about C++ than gcc/clang; code working on Linux may fail Windows CI
-3. **Windows symbol visibility** - Use `TORCH_API` macros for exported symbols (required on Windows, optional on Linux)
-4. **No internet access** - DO NOT attempt to install dependencies during development
-
-## Key Files Reference
-
-- `AGENTS.md` - Instructions specific to AI coding agents
-- `CONTRIBUTING.md` - Comprehensive human contributor guide
-- `GLOSSARY.md` - Terminology (ATen, kernels, operations, JIT, TorchScript)
-- `aten/src/ATen/native/README.md` - Operator implementation guide
-- `tools/autograd/derivatives.yaml` - Gradient definitions for autograd
-
-## Performance Debugging
-
-Use `TORCH_SHOW_CPP_STACKTRACES=1` for C++ traces in Python errors. For profiling, prefer `py-spy` over manual instrumentation.
.github/labeler.yml (22 lines changed)

@@ -138,8 +138,7 @@
 - test/test_matmul_cuda.py
 - test/test_scaled_matmul_cuda.py
 - test/inductor/test_fp8.py
-- aten/src/ATen/native/cuda/*Blas.cpp
-- aten/src/ATen/cuda/CUDA*Blas.*
+- aten/src/ATen/native/cuda/Blas.cpp
 - torch/**/*cublas*
 - torch/_inductor/kernel/mm.py
 - test/inductor/test_max_autotune.py

@@ -149,8 +148,7 @@
 - test/test_matmul_cuda.py
 - test/test_scaled_matmul_cuda.py
 - test/inductor/test_fp8.py
-- aten/src/ATen/native/cuda/*Blas.cpp
-- aten/src/ATen/cuda/CUDA*Blas.*
+- aten/src/ATen/native/cuda/Blas.cpp
 - torch/**/*cublas*
 - torch/_inductor/kernel/mm.py
 - test/inductor/test_max_autotune.py

@@ -160,21 +158,7 @@
 - test/test_matmul_cuda.py
 - test/test_scaled_matmul_cuda.py
 - test/inductor/test_fp8.py
-- aten/src/ATen/native/cuda/*Blas.cpp
-- aten/src/ATen/cuda/CUDA*Blas.*
+- aten/src/ATen/native/cuda/Blas.cpp
 - torch/_inductor/kernel/mm.py
 - test/inductor/test_max_autotune.py
 - third_party/fbgemm
-
-"ciflow/mps":
-- aten/src/ATen/mps/**
-- aten/src/ATen/native/mps/**
-- torch/_inductor/codegen/mps.py
-- test/test_mps.py
-- test/inductor/test_mps_basic.py
-
-"ciflow/h100-symm-mem":
-- torch/csrc/distributed/c10d/symm_mem/**
-- torch/distributed/_symmetric_memory/**
-- test/distributed/**/*mem*
-- test/distributed/**/*mem*/**
.github/nitpicks.yml (1 line changed)

@@ -10,4 +10,3 @@
 pathFilter:
 - 'torch/csrc/inductor/aoti_torch/c/*'
 - 'torch/csrc/inductor/aoti_torch/generated/*'
-- 'torch/csrc/stable/c/*'
.github/pytorch-probot.yml (6 lines changed)

@@ -2,8 +2,8 @@ tracking_issue: 24422
 ciflow_tracking_issue: 64124
 ciflow_push_tags:
 - ciflow/b200
-- ciflow/b200-distributed
 - ciflow/b200-symm-mem
+- ciflow/b200-distributed
 - ciflow/binaries
 - ciflow/binaries_libtorch
 - ciflow/binaries_wheel

@@ -22,8 +22,6 @@ ciflow_push_tags:
 - ciflow/inductor-perf-test-nightly-xpu
 - ciflow/inductor-periodic
 - ciflow/inductor-rocm
-- ciflow/inductor-rocm-mi200
-- ciflow/inductor-rocm-mi300
 - ciflow/linux-aarch64
 - ciflow/mps
 - ciflow/nightly

@@ -35,13 +33,11 @@ ciflow_push_tags:
 - ciflow/quantization-periodic
 - ciflow/riscv64
 - ciflow/rocm
-- ciflow/rocm-mi200
 - ciflow/rocm-mi300
 - ciflow/rocm-mi355
 - ciflow/rocm-navi31
 - ciflow/s390
 - ciflow/slow
-- ciflow/slow-rocm-mi200
 - ciflow/torchbench
 - ciflow/triton_binaries
 - ciflow/trunk
.github/scripts/delete_old_branches.py (3 lines changed)

@@ -1,11 +1,10 @@
 # Delete old branches
 import os
 import re
-from collections.abc import Callable
 from datetime import datetime
 from functools import lru_cache
 from pathlib import Path
-from typing import Any
+from typing import Any, Callable

 from github_utils import gh_fetch_json_dict, gh_graphql
 from gitutils import GitRepo
.github/scripts/filter_test_configs.py (3 lines changed)

@@ -8,11 +8,10 @@ import re
 import subprocess
 import sys
 import warnings
-from collections.abc import Callable
 from enum import Enum
 from functools import cache
 from logging import info
-from typing import Any, Optional
+from typing import Any, Callable, Optional
 from urllib.request import Request, urlopen

 import yaml
.github/scripts/get_workflow_job_id.py (3 lines changed)

@@ -11,8 +11,7 @@ import sys
 import time
 import urllib
 import urllib.parse
-from collections.abc import Callable
-from typing import Any, Optional
+from typing import Any, Callable, Optional
 from urllib.request import Request, urlopen

.github/scripts/github_utils.py (3 lines changed)

@@ -3,9 +3,8 @@
 import json
 import os
 import warnings
-from collections.abc import Callable
 from dataclasses import dataclass
-from typing import Any, cast, Optional, Union
+from typing import Any, Callable, cast, Optional, Union
 from urllib.error import HTTPError
 from urllib.parse import quote
 from urllib.request import Request, urlopen
.github/scripts/gitutils.py (4 lines changed)

@@ -4,10 +4,10 @@ import os
 import re
 import tempfile
 from collections import defaultdict
-from collections.abc import Callable, Iterator
+from collections.abc import Iterator
 from datetime import datetime
 from functools import wraps
-from typing import Any, cast, Optional, TypeVar, Union
+from typing import Any, Callable, cast, Optional, TypeVar, Union


 T = TypeVar("T")
.github/scripts/lintrunner.sh (3 lines changed)

@@ -34,9 +34,6 @@ python3 torch/utils/data/datapipes/gen_pyi.py
 # Also check generated pyi files
 find torch -name '*.pyi' -exec git add --force -- "{}" +

-# Print current environment
-python3 -m pip freeze
-
 RC=0
 # Run lintrunner on all files
 if ! lintrunner --force-color --tee-json=lint.json ${ADDITIONAL_LINTRUNNER_ARGS} 2> /dev/null; then
.github/scripts/trymerge.py (4 lines changed)

@@ -17,12 +17,12 @@ import re
 import time
 import urllib.parse
 from collections import defaultdict
-from collections.abc import Callable, Iterable
+from collections.abc import Iterable
 from dataclasses import dataclass
 from functools import cache
 from pathlib import Path
 from re import Pattern
-from typing import Any, cast, NamedTuple, Optional
+from typing import Any, Callable, cast, NamedTuple, Optional
 from warnings import warn

 import yaml
.github/workflows/_rocm-test.yml (4 lines changed)

@@ -97,8 +97,8 @@ jobs:
 shell: bash
 run: |
 ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx')
-if [[ $ngpu -lt 2 ]]; then #We are temporarily reducing this down to 2 from 4 so that we can run tests on nodes with less gpus.
-echo "Error: only $ngpu GPU(s) detected, at least 2 GPUs are needed for distributed jobs"
+if [[ $ngpu -lt 4 ]]; then
+echo "Error: only $ngpu GPU(s) detected, at least 4 GPUs are needed for distributed jobs"
 exit 1
 fi

.github/workflows/_xpu-test.yml (16 lines changed)

@@ -344,21 +344,5 @@ jobs:
 if-no-files-found: ignore
 path: ./**/core.[1-9]*
-
-- name: Authenticate with AWS
-uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
-with:
-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_upload-benchmark-results
-# The max duration enforced by the server side
-role-duration-seconds: 18000
-aws-region: us-east-1
-
-- name: Upload the benchmark results
-uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main
-with:
-benchmark-results-dir: test/test-reports
-dry-run: false
-schema-version: v3
-github-token: ${{ secrets.GITHUB_TOKEN }}

 - name: Teardown XPU
 uses: ./.github/actions/teardown-xpu
.github/workflows/b200-distributed.yml (1 line changed)

@@ -37,6 +37,7 @@ jobs:
 needs: get-label-type
 with:
 runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
+runner: linux.12xlarge.memory
 build-environment: linux-jammy-cuda12.8-py3.10-gcc11-distributed-b200
 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
 cuda-arch-list: '10.0'
.github/workflows/b200-symm-mem.yml (1 line changed)

@@ -37,6 +37,7 @@ jobs:
 needs: get-label-type
 with:
 runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
+runner: linux.12xlarge.memory
 build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100-symm
 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
 cuda-arch-list: '10.0'
.github/workflows/build-triton-wheel.yml (9 lines changed)

@@ -51,15 +51,12 @@ jobs:
 fail-fast: false
 matrix:
 py_vers: [ "3.10", "3.11", "3.12", "3.13", "3.13t", "3.14", "3.14t" ]
-device: ["cuda", "rocm-n", "rocm-n-1", "xpu", "aarch64"]
+device: ["cuda", "rocm", "xpu", "aarch64"]
 docker-image: ["pytorch/manylinux2_28-builder:cpu"]
 include:
-- device: "rocm-n"
+- device: "rocm"
 rocm_version: "7.1"
 runs_on: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge"
-- device: "rocm-n-1"
-rocm_version: "7.0"
-runs_on: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge"
 - device: "cuda"
 rocm_version: ""
 runs_on: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge"

@@ -174,7 +171,7 @@ jobs:

 - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0
 with:
-name: pytorch-triton-wheel-${{ matrix.py_vers }}-${{ matrix.device }}${{ matrix.rocm_version != '' && format('-{0}', matrix.rocm_version) || '' }}-${{ env.PLATFORM }}
+name: pytorch-triton-wheel-${{ matrix.py_vers }}-${{ matrix.device }}-${{ env.PLATFORM }}
 if-no-files-found: error
 path: ${{ runner.temp }}/artifacts/wheelhouse/*
.github/workflows/docker-builds.yml (13 lines changed)

@@ -56,8 +56,6 @@ jobs:
 pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc9,
 pytorch-linux-jammy-cuda12.4-cudnn9-py3-gcc11,
 pytorch-linux-jammy-py3.10-clang12,
-pytorch-linux-jammy-py3.11-clang12,
-pytorch-linux-jammy-py3.12-clang12,
 pytorch-linux-jammy-py3.13-clang12,
 pytorch-linux-jammy-py3.14-clang12,
 pytorch-linux-jammy-rocm-n-py3,

@@ -67,10 +65,9 @@ jobs:
 pytorch-linux-jammy-py3.10-gcc11,
 pytorch-linux-jammy-py3-gcc11-inductor-benchmarks,
 pytorch-linux-jammy-py3.12-halide,
-pytorch-linux-jammy-cuda12.8-py3.12-pallas,
 pytorch-linux-jammy-xpu-n-1-py3,
-pytorch-linux-noble-xpu-n-py3,
-pytorch-linux-noble-xpu-n-py3-inductor-benchmarks,
+pytorch-linux-jammy-xpu-n-py3,
+pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks,
 pytorch-linux-jammy-py3-clang18-asan,
 pytorch-linux-jammy-py3-clang12-onnx,
 pytorch-linux-jammy-linter,

@@ -80,11 +77,9 @@ jobs:
 pytorch-linux-noble-riscv64-py3.12-gcc14
 ]
 include:
-- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc13
+- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11
 runner: linux.arm64.m7g.4xlarge
-- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-clang21
-runner: linux.arm64.m7g.4xlarge
-- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks
+- docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks
 runner: linux.arm64.m7g.4xlarge
 timeout-minutes: 600
 # Docker uploads fail from LF runners, see https://github.com/pytorch/pytorch/pull/137358
.github/workflows/h100-distributed.yml (1 changed line)
@@ -37,6 +37,7 @@ jobs:
    needs: get-label-type
    with:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
+     runner: "linux.c7i.12xlarge"
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90-dist
      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
      cuda-arch-list: '9.0'
@@ -72,7 +72,7 @@ jobs:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
      runner: linux.arm64.m7g.4xlarge
      build-environment: linux-jammy-aarch64-py3.10
-     docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks
+     docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks
      test-matrix: |
        { include: [
          { config: "inductor_huggingface_perf_cpu_aarch64", shard: 1, num_shards: 9, runner: "linux.arm64.m7g.metal" },
@@ -83,8 +83,8 @@ jobs:
    needs: get-label-type
    with:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
-     build-environment: linux-noble-xpu-n-py3.10
+     build-environment: linux-jammy-xpu-n-py3.10
-     docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3-inductor-benchmarks
+     docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3-inductor-benchmarks
      runner: linux.c7i.12xlarge
      test-matrix: |
        { include: [
@@ -117,7 +117,7 @@ jobs:
    uses: ./.github/workflows/_xpu-test.yml
    needs: xpu-n-py3_10-inductor-benchmark-build
    with:
-     build-environment: linux-noble-xpu-n-py3.10
+     build-environment: linux-jammy-xpu-n-py3.10
      dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-false-cppwrapper-true-aotinductor-true-freezing_cudagraphs-false-cudagraphs_low_precision-false
      docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
      test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
@@ -137,7 +137,7 @@ jobs:
    uses: ./.github/workflows/_xpu-test.yml
    needs: xpu-n-py3_10-inductor-benchmark-build
    with:
-     build-environment: linux-noble-xpu-n-py3.10
+     build-environment: linux-jammy-xpu-n-py3.10
      dashboard-tag: training-${{ inputs.training }}-inference-${{ inputs.inference }}-default-${{ inputs.default }}-dynamic-${{ inputs.dynamic }}-cudagraphs-${{ inputs.cudagraphs }}-cppwrapper-${{ inputs.cppwrapper }}-aotinductor-${{ inputs.aotinductor }}-maxautotune-${{ inputs.maxautotune }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs }}-cudagraphs_low_precision-${{ inputs.cudagraphs }}
      docker-image: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.docker-image }}
      test-matrix: ${{ needs.xpu-n-py3_10-inductor-benchmark-build.outputs.test-matrix }}
.github/workflows/inductor-rocm-mi300.yml (1 changed line)
@@ -7,7 +7,6 @@ on:
      - release/*
    tags:
      - ciflow/inductor-rocm/*
-     - ciflow/inductor-rocm-mi300/*
  workflow_dispatch:

concurrency:
@@ -2,12 +2,12 @@ name: inductor-rocm

on:
  schedule:
-   - cron: 0 */3 * * *
+   - cron: 0 * * * *
  push:
    branches:
      - release/*
    tags:
-     - ciflow/inductor-rocm-mi200/*
+     - ciflow/inductor-rocm/*
  workflow_dispatch:

concurrency:
.github/workflows/inductor-unittest.yml (26 changed lines)
@@ -81,32 +81,6 @@ jobs:
      test-matrix: ${{ needs.inductor-halide-build.outputs.test-matrix }}
    secrets: inherit

- inductor-pallas-build:
-   name: inductor-pallas-build
-   uses: ./.github/workflows/_linux-build.yml
-   needs: get-label-type
-   with:
-     build-environment: linux-jammy-cuda12.8-py3.12-gcc11
-     docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-py3.12-pallas
-     cuda-arch-list: '8.9'
-     runner: linux.8xlarge.memory
-     runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
-     test-matrix: |
-       { include: [
-         { config: "inductor-pallas", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu" },
-       ]}
-   secrets: inherit
-
- inductor-pallas-test:
-   name: inductor-pallas-test
-   uses: ./.github/workflows/_linux-test.yml
-   needs: inductor-pallas-build
-   with:
-     build-environment: linux-jammy-py3.12-gcc11
-     docker-image: ${{ needs.inductor-pallas-build.outputs.docker-image }}
-     test-matrix: ${{ needs.inductor-pallas-build.outputs.test-matrix }}
-   secrets: inherit
-
  inductor-triton-cpu-build:
    name: inductor-triton-cpu-build
    uses: ./.github/workflows/_linux-build.yml
.github/workflows/linux-aarch64.yml (2 changed lines)
@@ -33,7 +33,7 @@ jobs:
    with:
      runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
      build-environment: linux-jammy-aarch64-py3.10
-     docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13
+     docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11
      runner: linux.arm64.m7g.4xlarge
      test-matrix: |
        { include: [
.github/workflows/nightly.yml (8 changed lines)
@@ -5,11 +5,9 @@ on:
    - cron: 0 0 * * *
  push:
    tags:
-     # NOTE: Doc build pipelines should only get triggered on:
-     # Major or minor release candidates builds
-     - v[0-9]+.[0-9]+.0+-rc[0-9]+
-     # Final RC for major, minor and patch releases
-     - v[0-9]+.[0-9]+.[0-9]+
+     # NOTE: Doc build pipelines should only get triggered on release candidate builds
+     # Release candidate tags look like: v1.11.0-rc1
+     - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
      - ciflow/nightly/*
  workflow_dispatch:
.github/workflows/operator_benchmark.yml (2 changed lines)
@@ -60,7 +60,7 @@ jobs:
    with:
      build-environment: linux-jammy-aarch64-py3.10
      runner: linux.arm64.m7g.4xlarge
-     docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13
+     docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11
      test-matrix: |
        { include: [
          { config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" },
.github/workflows/periodic-rocm-mi200.yml (1 changed line)
@@ -11,6 +11,7 @@ on:
    - cron: 29 8 * * *  # about 1:29am PDT, for mem leak check and rerun disabled tests
  push:
    tags:
+     - ciflow/periodic/*
      - ciflow/periodic-rocm-mi200/*
    branches:
      - release/*
.github/workflows/periodic-rocm-mi300.yml (1 changed line)
@@ -11,7 +11,6 @@ on:
    - cron: 29 8 * * *  # about 1:29am PDT, for mem leak check and rerun disabled tests
  push:
    tags:
-     - ciflow/periodic/*
      - ciflow/periodic-rocm-mi300/*
    branches:
      - release/*
.github/workflows/pull.yml (8 changed lines)
@@ -342,16 +342,16 @@ jobs:
      test-matrix: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc9-inductor-build.outputs.test-matrix }}
    secrets: inherit

- linux-noble-xpu-n-py3_10-build:
-   name: linux-noble-xpu-n-py3.10
+ linux-jammy-xpu-n-py3_10-build:
+   name: linux-jammy-xpu-n-py3.10
    uses: ./.github/workflows/_linux-build.yml
    needs: get-label-type
    with:
      # This should sync with the build in xpu.yml but xpu uses a larger runner
      # sync-tag: linux-xpu-n-build
      runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
-     build-environment: linux-noble-xpu-n-py3.10
-     docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3
+     build-environment: linux-jammy-xpu-n-py3.10
+     docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
      test-matrix: |
        { include: [
          { config: "default", shard: 1, num_shards: 4, runner: "linux.idc.xpu" },
.github/workflows/rocm-mi300.yml (1 changed line)
@@ -6,7 +6,6 @@ on:
      - main
      - release/*
    tags:
-     - ciflow/rocm/*
      - ciflow/rocm-mi300/*
  workflow_dispatch:
  schedule:
@@ -5,12 +5,11 @@ on:
    branches:
      - release/*
    tags:
-     - ciflow/rocm-mi200/*
+     - ciflow/rocm/*
  workflow_dispatch:
  schedule:
    - cron: 29 8 * * *  # about 1:29am PDT
-   - cron: 0 */3 * * *
+   - cron: 0 * * * *
-

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
.github/workflows/slow-rocm-mi200.yml (81 changed lines)
@@ -1,81 +0,0 @@
- # This workflow is dedicated to host slow jobs that are run only periodically because
- # they are too slow to run in every commit. The list of slow tests can be found in
- # https://github.com/pytorch/test-infra/blob/generated-stats/stats/slow-tests.json
- name: slow-rocm-mi200
-
- on:
-   push:
-     branches:
-       - release/*
-     tags:
-       - ciflow/slow/*
-       - ciflow/slow-rocm-mi200/*
-   schedule:
-     - cron: 0 */3 * * *
-   workflow_dispatch:
-
- concurrency:
-   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}-${{ github.event.schedule }}
-   cancel-in-progress: true
-
- permissions:
-   id-token: write
-   contents: read
-
- jobs:
-   llm-td:
-     if: github.repository_owner == 'pytorch'
-     name: before-test
-     uses: ./.github/workflows/llm_td_retrieval.yml
-     permissions:
-       id-token: write
-       contents: read
-
-   target-determination:
-     name: before-test
-     uses: ./.github/workflows/target_determination.yml
-     needs: llm-td
-     permissions:
-       id-token: write
-       contents: read
-
-   get-label-type:
-     name: get-label-type
-     uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
-     if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
-     with:
-       triggering_actor: ${{ github.triggering_actor }}
-       issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
-       curr_branch: ${{ github.head_ref || github.ref_name }}
-       curr_ref_type: ${{ github.ref_type }}
-
-   linux-jammy-rocm-py3_10-build:
-     name: linux-jammy-rocm-py3.10
-     uses: ./.github/workflows/_linux-build.yml
-     needs: get-label-type
-     with:
-       runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
-       build-environment: linux-jammy-rocm-py3.10
-       docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
-       sync-tag: rocm-build
-       test-matrix: |
-         { include: [
-           { config: "slow", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
-           { config: "slow", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
-         ]}
-     secrets: inherit
-
-   linux-jammy-rocm-py3_10-test:
-     permissions:
-       id-token: write
-       contents: read
-     name: linux-jammy-rocm-py3.10
-     uses: ./.github/workflows/_rocm-test.yml
-     needs:
-       - linux-jammy-rocm-py3_10-build
-       - target-determination
-     with:
-       build-environment: linux-jammy-rocm-py3.10
-       docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
-       test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
-     secrets: inherit
.github/workflows/slow.yml (30 changed lines)
@@ -105,6 +105,36 @@ jobs:
      test-matrix: ${{ needs.linux-jammy-py3_10-clang12-build.outputs.test-matrix }}
    secrets: inherit

+ linux-jammy-rocm-py3_10-build:
+   name: linux-jammy-rocm-py3.10
+   uses: ./.github/workflows/_linux-build.yml
+   needs: get-label-type
+   with:
+     runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
+     build-environment: linux-jammy-rocm-py3.10
+     docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3
+     test-matrix: |
+       { include: [
+         { config: "slow", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
+         { config: "slow", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.2", owners: ["module:rocm"] },
+       ]}
+   secrets: inherit
+
+ linux-jammy-rocm-py3_10-test:
+   permissions:
+     id-token: write
+     contents: read
+   name: linux-jammy-rocm-py3.10
+   uses: ./.github/workflows/_rocm-test.yml
+   needs:
+     - linux-jammy-rocm-py3_10-build
+     - target-determination
+   with:
+     build-environment: linux-jammy-rocm-py3.10
+     docker-image: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.docker-image }}
+     test-matrix: ${{ needs.linux-jammy-rocm-py3_10-build.outputs.test-matrix }}
+   secrets: inherit
+
  linux-jammy-py3_10-clang18-asan-build:
    name: linux-jammy-py3.10-clang18-asan
    uses: ./.github/workflows/_linux-build.yml
.github/workflows/test-b200.yml (1 changed line)
@@ -52,6 +52,7 @@ jobs:
    needs: get-label-type
    with:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
+     runner: linux.12xlarge.memory
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100
      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
      cuda-arch-list: '10.0'
.github/workflows/test-h100.yml (1 changed line)
@@ -41,6 +41,7 @@ jobs:
    needs: get-label-type
    with:
      runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
+     runner: linux.12xlarge.memory
      build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm90
      docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11
      cuda-arch-list: '9.0'
.github/workflows/upload-test-stats.yml (5 changed lines)
@@ -11,16 +11,15 @@ on:
      - inductor
      - unstable
      - slow
-     - slow-rocm-mi200
      - unstable-periodic
      - inductor-periodic
-     - rocm-mi200
+     - rocm
      - rocm-mi300
      - rocm-mi355
      - inductor-micro-benchmark
      - inductor-micro-benchmark-x86
      - inductor-cu124
-     - inductor-rocm-mi200
+     - inductor-rocm
      - inductor-rocm-mi300
      - mac-mps
      - linux-aarch64
.github/workflows/xpu.yml (20 changed lines)
@@ -47,15 +47,15 @@ jobs:
      ]}
    secrets: inherit

- linux-noble-xpu-n-py3_10-build:
-   name: linux-noble-xpu-n-py3.10
+ linux-jammy-xpu-n-py3_10-build:
+   name: linux-jammy-xpu-n-py3.10
    uses: ./.github/workflows/_linux-build.yml
    needs: get-label-type
    with:
      sync-tag: linux-xpu-n-build
      runner_prefix: ${{ needs.get-label-type.outputs.label-type }}
-     build-environment: linux-noble-xpu-n-py3.10
-     docker-image-name: ci-image:pytorch-linux-noble-xpu-n-py3
+     build-environment: linux-jammy-xpu-n-py3.10
+     docker-image-name: ci-image:pytorch-linux-jammy-xpu-n-py3
      runner: linux.c7i.12xlarge
      test-matrix: |
        { include: [
@@ -74,17 +74,17 @@ jobs:
      ]}
    secrets: inherit

- linux-noble-xpu-n-py3_10-test:
-   name: linux-noble-xpu-n-py3.10
+ linux-jammy-xpu-n-py3_10-test:
+   name: linux-jammy-xpu-n-py3.10
    uses: ./.github/workflows/_xpu-test.yml
-   needs: linux-noble-xpu-n-py3_10-build
+   needs: linux-jammy-xpu-n-py3_10-build
    permissions:
      id-token: write
      contents: read
    with:
-     build-environment: linux-noble-xpu-n-py3.10
-     docker-image: ${{ needs.linux-noble-xpu-n-py3_10-build.outputs.docker-image }}
-     test-matrix: ${{ needs.linux-noble-xpu-n-py3_10-build.outputs.test-matrix }}
+     build-environment: linux-jammy-xpu-n-py3.10
+     docker-image: ${{ needs.linux-jammy-xpu-n-py3_10-build.outputs.docker-image }}
+     test-matrix: ${{ needs.linux-jammy-xpu-n-py3_10-build.outputs.test-matrix }}
    secrets: inherit

  windows-xpu-n-1-build:
@@ -143,8 +143,7 @@ init_command = [
    'tools/linter/adapters/pip_init.py',
    '--dry-run={{DRYRUN}}',
    'numpy==1.26.4 ; python_version >= "3.10" and python_version <= "3.11"',
-   'numpy==2.1.0 ; python_version >= "3.12" and python_version <= "3.13"',
-   'numpy==2.3.4 ; python_version >= "3.14"',
+   'numpy==2.1.0 ; python_version >= "3.12"',
    'expecttest==0.3.0',
    'pyrefly==0.36.2',
    'sympy==1.13.3',
@@ -186,8 +185,6 @@ include_patterns = [
    'aten/src/ATen/native/nested/cuda/*.h',
    'aten/src/ATen/native/nested/*.cpp',
    'aten/src/ATen/native/nested/*.h',
-   'aten/src/ATen/xpu/**/*.h',
-   'aten/src/ATen/xpu/**/*.cpp',
    'c10/**/*.cpp',
    'c10/**/*.h',
    'torch/*.h',
@@ -1404,7 +1401,7 @@ init_command = [
    '--dry-run={{DRYRUN}}',
    'usort==1.0.8.post1',
    'isort==6.0.1',
-   'ruff==0.14.4',  # sync with RUFF
+   'ruff==0.13.1',  # sync with RUFF
]
is_formatter = true

@@ -1539,7 +1536,7 @@ init_command = [
    'python3',
    'tools/linter/adapters/pip_init.py',
    '--dry-run={{DRYRUN}}',
-   'ruff==0.14.4',  # sync with PYFMT
+   'ruff==0.13.1',  # sync with PYFMT
]
is_formatter = true
CMakeLists.txt (106 changed lines)
@@ -234,17 +234,7 @@ option(USE_COLORIZE_OUTPUT "Colorize output during compilation" ON)
option(USE_ASAN "Use Address+Undefined Sanitizers" OFF)
option(USE_LSAN "Use Leak Sanitizer" OFF)
option(USE_TSAN "Use Thread Sanitizer" OFF)
-
- # Track whether USE_CUDA was explicitly set by the user (before option() is called)
- # If USE_CUDA is already defined in cache, it means user explicitly set it
- if(DEFINED CACHE{USE_CUDA})
-   set(_USE_CUDA_EXPLICITLY_SET TRUE)
- else()
-   set(_USE_CUDA_EXPLICITLY_SET FALSE)
- endif()
-
option(USE_CUDA "Use CUDA" ON)
-
option(USE_XPU "Use XPU" ON)
cmake_dependent_option(
  BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON
@@ -736,44 +726,6 @@ if(NOT DEFINED USE_BLAS)
  set(USE_BLAS ON)
endif()

- # Prioritized Text Linker Optimization
- if(USE_PRIORITIZED_TEXT_FOR_LD)
-
-   set(LINKER_SCRIPT_FILE_IN "${CMAKE_SOURCE_DIR}/cmake/prioritized_text.txt")
-   set(LINKER_SCRIPT_FILE_OUT "${CMAKE_SOURCE_DIR}/cmake/linker_script.ld")
-
-   execute_process(
-     COMMAND ${Python_EXECUTABLE}
-       ${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py
-       --filein "${LINKER_SCRIPT_FILE_IN}"
-       --fout "${LINKER_SCRIPT_FILE_OUT}"
-     RESULT_VARIABLE _gen_result
-     OUTPUT_VARIABLE _gen_output
-     ERROR_VARIABLE _gen_error
-   )
-
-   if(NOT _gen_result EQUAL 0)
-     message(FATAL_ERROR
-       "Failed to generate linker script:\n${_gen_output}\n${_gen_error}")
-   endif()
-
-   append_cxx_flag_if_supported("-ffunction-sections" CMAKE_CXX_FLAGS)
-   append_cxx_flag_if_supported("-fdata-sections" CMAKE_CXX_FLAGS)
-   append_c_flag_if_supported("-ffunction-sections" CMAKE_C_FLAGS)
-   append_c_flag_if_supported("-fdata-sections" CMAKE_C_FLAGS)
-
-   set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -T${LINKER_SCRIPT_FILE_OUT}")
-   set(CMAKE_MODULE_LINKER_FLAGS "${CMAKE_MODULE_LINKER_FLAGS} -T${LINKER_SCRIPT_FILE_OUT}")
-
- else()
-   if(LINUX AND CPU_AARCH64)
-     message(WARNING [[
-       It is strongly recommend to enable linker script optimization for all AArch64 Linux builds.
-       To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1
-     ]])
-   endif()
- endif()
-
# Build libtorch mobile library, which contains ATen/TH ops and native support
# for TorchScript model, but doesn't contain not-yet-unified caffe2 ops;
if(INTERN_BUILD_MOBILE)
@@ -1440,6 +1392,9 @@ if(BUILD_JNI)
  add_subdirectory(android/pytorch_android)
endif()

+ include(cmake/Summary.cmake)
+ caffe2_print_configuration_summary()
+
# Parse custom debug info
if(DEFINED USE_CUSTOM_DEBINFO)
  string(REPLACE ";" " " SOURCE_FILES "${USE_CUSTOM_DEBINFO}")
@@ -1479,5 +1434,56 @@ if(BUILD_BUNDLE_PTXAS AND USE_CUDA)
          DESTINATION "${CMAKE_INSTALL_BINDIR}")
endif()

- include(cmake/Summary.cmake)
- caffe2_print_configuration_summary()
+ if(USE_PRIORITIZED_TEXT_FOR_LD)
+   add_compile_options(
+     $<$<COMPILE_LANGUAGE:C,CXX>:-ffunction-sections>
+     $<$<COMPILE_LANGUAGE:C,CXX>:-fdata-sections>
+   )
+   set(LINKER_SCRIPT_FILE_OUT "${CMAKE_SOURCE_DIR}/cmake/linker_script.ld")
+   set(LINKER_SCRIPT_FILE_IN "${CMAKE_SOURCE_DIR}/cmake/prioritized_text.txt")
+
+   add_custom_command(
+     OUTPUT "${LINKER_SCRIPT_FILE_OUT}"
+     COMMAND ${Python_EXECUTABLE} ${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py --filein "${LINKER_SCRIPT_FILE_IN}" --fout "${LINKER_SCRIPT_FILE_OUT}"
+     DEPENDS ${CMAKE_SOURCE_DIR}/tools/setup_helpers/generate_linker_script.py "${LINKER_SCRIPT_FILE_IN}"
+     COMMENT "Generating prioritized text linker files"
+     VERBATIM
+   )
+
+   add_custom_target(generate_linker_script DEPENDS "${LINKER_SCRIPT_FILE_OUT}")
+
+   if(BUILD_PYTHON)
+     set(LINKER_OPT_TARGETS torch_python)
+   endif()
+
+   if(NOT BUILD_LIBTORCHLESS)
+     list(APPEND LINKER_OPT_TARGETS torch_cpu c10)
+     if(USE_CUDA)
+       list(APPEND LINKER_OPT_TARGETS torch_cuda c10_cuda)
+     endif()
+     if(USE_XPU)
+       list(APPEND LINKER_OPT_TARGETS torch_xpu c10_xpu)
+     endif()
+     if(USE_ROCM)
+       list(APPEND LINKER_OPT_TARGETS torch_hip c10_hip)
+     endif()
+   endif()
+
+   foreach(tgt IN LISTS LINKER_OPT_TARGETS)
+     if(TARGET ${tgt})
+       add_dependencies("${tgt}" generate_linker_script)
+       target_link_options_if_supported(${tgt} "-T,${LINKER_SCRIPT_FILE_OUT}")
+       set_property(TARGET ${tgt} APPEND PROPERTY LINK_DEPENDS "${LINKER_SCRIPT_FILE_OUT}")
+     else()
+       message(WARNING "Requested target '${tgt}' for linker script optimization was not found.")
+     endif()
+   endforeach()
+
+ else()
+   if(LINUX AND CPU_AARCH64)
+     message(WARNING [[
+       It is strongly recommend to enable linker script optimization for all AArch64 Linux builds.
+       To do so please export USE_PRIORITIZED_TEXT_FOR_LD=1
+     ]])
+   endif()
+ endif()
@@ -210,12 +210,8 @@ torch/backends/cudnn/ @eqy @syed-ahmed @Aidyn-A
/test/inductor/test_flex_attention.py @drisspg
/test/inductor/test_flex_decoding.py @drisspg

- # Low Precision & Grouped GEMMs
+ # Low Precision GEMMs
/aten/src/ATen/native/cuda/Blas.cpp @drisspg @slayton58
- /aten/src/ATen/native/cuda/GroupedBlas.cpp @drisspg @slayton58
- /aten/src/ATen/native/cuda/ScaledBlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.cpp @drisspg @slayton58
/aten/src/ATen/cuda/CUDABlas.h @drisspg @slayton58
- /aten/src/ATen/cuda/CUDAScaledBlas.cpp @drisspg @slayton58
- /aten/src/ATen/cuda/CUDAScaledBlas.h @drisspg @slayton58
/test/test_scaled_matmul_cuda.py @drisspg @slayton58
@@ -18,7 +18,7 @@ aspects of contributing to PyTorch.
- [Python Unit Testing](#python-unit-testing)
- [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest)
- [Local linting](#local-linting)
-   - [Running `pyrefly`](#running-pyrefly)
+   - [Running `mypy`](#running-mypy)
- [C++ Unit Testing](#c-unit-testing)
- [Run Specific CI Jobs](#run-specific-ci-jobs)
- [Merging your Change](#merging-your-change)
@@ -281,7 +281,7 @@ dependencies as well as the nightly binaries into the repo directory.
**Prerequisites**:
The following packages should be installed with `pip`:
- `expecttest` and `hypothesis` - required to run tests
-  - `pyrefly` - recommended for type checking. [Pyrefly](https://pyrefly.org/)
+  - `mypy` - recommended for linting
- `pytest` - recommended to run tests more selectively
Running
```
@@ -350,32 +350,15 @@ make lint

Learn more about the linter on the [lintrunner wiki page](https://github.com/pytorch/pytorch/wiki/lintrunner)

- #### Running `pyrefly`
+ #### Running `mypy`

- [Pyrefly](https://pyrefly.org/) is a high-performance static type checker for Python. It provides fast type checking along with IDE features like autocomplete and instant error feedback.
-
- PyTorch uses Pyrefly for type checking across the codebase. The configuration is managed in `pyrefly.toml` at the root of the repository.
-
- **Getting Started with Pyrefly:**
-
- To run type checking on the PyTorch codebase:
- ```bash
- pyrefly check
- ```
-
- For more detailed error information with summaries:
- ```bash
- pyrefly check --summarize-errors
- ```
-
- **Learn More:**
- - [Pyrefly Configuration](https://pyrefly.org/en/docs/configuration/) - Detailed configuration options
- - [Pyrefly IDE Features](https://pyrefly.org/en/docs/IDE-features/) - Set up Pyrefly in your editor for real-time type checking
- - [Python Typing Tutorial](https://pyrefly.org/en/docs/typing-for-python-developers/) - Learn about Python type annotations
+ `mypy` is an optional static type checker for Python. We have multiple `mypy`
+ configs for the PyTorch codebase that are automatically validated against whenever the linter is run.

See [Guide for adding type annotations to
PyTorch](https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch)
- for PyTorch-specific guidance on how to set up `pyrefly` and tackle type annotation tasks in this codebase.
+ for more information on how to set up `mypy` and tackle type annotation
+ tasks.

### C++ Unit Testing
LICENSE (2 changed lines)
@@ -37,7 +37,7 @@ Copyright (c) 2024 Tri Dao.
All rights reserved.

All contributions by Arm:
- Copyright (c) 2021, 2023-2025 Arm Limited and/or its affiliates
+ Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates

All contributions from Caffe:
Copyright(c) 2013, 2014, 2015, the respective contributors
@@ -174,12 +174,6 @@ class TORCH_API Context {
  static long versionCuDNN() {
    return detail::getCUDAHooks().versionCuDNN();
  }
-   static long versionRuntimeCuDNN() {
-     return detail::getCUDAHooks().versionRuntimeCuDNN();
-   }
-   static long versionCuDNNFrontend() {
-     return detail::getCUDAHooks().versionCuDNNFrontend();
-   }
  static bool hasCuSOLVER() {
    return detail::getCUDAHooks().hasCuSOLVER();
  }
@@ -94,11 +94,6 @@ TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) {
  at::getDeviceAllocator(device_type)->resetPeakStats(device_index);
}

- TORCH_API inline std::pair<size_t, size_t> getMemoryInfo(
-     c10::DeviceIndex device_index) {
-   const auto device_type = getAccelerator(true).value();
-   return at::getDeviceAllocator(device_type)->getMemoryInfo(device_index);
- }
} // namespace at::accelerator

namespace at {
@@ -6,7 +6,6 @@
#include <c10/util/Half.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/complex.h>
- #include <torch/headeronly/core/Dispatch.h>

#ifdef __CUDACC__
#include <cuda.h>  // For CUDA_VERSION
@@ -62,9 +61,12 @@ TORCH_API void record_kernel_function_dtype(std::string name);
    } \
  } while (0)

#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \
-   THO_PRIVATE_CASE_TYPE_USING_HINT_TMPL( \
-       AT_PRIVATE_CHECK_SELECTIVE_BUILD, enum_type, HINT, __VA_ARGS__)
+   case enum_type: { \
+     AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \
+     using HINT [[maybe_unused]] = c10::impl::ScalarTypeToCPPTypeT<enum_type>; \
+     return __VA_ARGS__(); \
+   }

#define AT_DISPATCH_CASE(enum_type, ...) \
  AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__)
@@ -93,6 +95,14 @@ TORCH_API void record_kernel_function_dtype(std::string name);
    return __VA_ARGS__(); \
  }

+ namespace detail {
+
+ inline at::ScalarType scalar_type(at::ScalarType s) {
+   return s;
+ }
+
+ } // namespace detail
+
// The AT_DISPATCH_* family of macros provides the ability to
// conveniently generate specializations of a kernel over all of the
// dtypes we care about in PyTorch. We call it "dispatch" because
@@ -180,13 +190,27 @@ TORCH_API void record_kernel_function_dtype(std::string name);
// but we're just being safe (and it doesn't hurt.) Note we must
// use it to shut up warnings about unused store.

#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \
-   THO_DISPATCH_SWITCH_TMPL( \
-       RECORD_KERNEL_FUNCTION_DTYPE, \
-       TORCH_CHECK_NOT_IMPLEMENTED, \
-       TYPE, \
-       NAME, \
-       __VA_ARGS__)
+   [&] { \
+     const auto& the_type = TYPE; \
+     constexpr const char* at_dispatch_name = NAME; \
+     /* don't use TYPE again in case it is an expensive or side-effect op */ \
+     at::ScalarType _st = ::detail::scalar_type(the_type); \
+     RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st); \
+     C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wswitch-enum") \
+     switch (_st) { \
+       __VA_ARGS__ \
+       default: \
+         TORCH_CHECK_NOT_IMPLEMENTED( \
+             false, \
+             '"', \
+             at_dispatch_name, \
+             "\" not implemented for '", \
+             toString(_st), \
+             "'"); \
+     } \
+     C10_DIAGNOSTIC_POP() \
+   }()

#define AT_DISPATCH_CASE_FLOATING_TYPES(...) \
  AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
@@ -1,8 +1,3 @@
- #pragma once
-
- #include <torch/headeronly/core/Dispatch_v2.h>
-
- // Get AT_DISPATCH_SWITCH and AT_DISPATCH_CASE:
#include <ATen/Dispatch.h>

// This is a new implementation of the AT_DISPATCH macro family from
@@ -79,19 +74,41 @@
// macro expansion occurs, mediated with AT_EXPAND and AT_GUARD. I mostly
// relied on GPT4 to help me get it right.

+ // Public API macros
+
// See documentation above
#define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \
-   THO_DISPATCH_V2_TMPL( \
-       AT_DISPATCH_SWITCH, \
-       AT_DISPATCH_CASE, \
-       TYPE, \
-       NAME, \
-       AT_WRAP(BODY), \
-       __VA_ARGS__)
+   AT_DISPATCH_SWITCH(TYPE, NAME, AT_AP_VAR(AT_WRAP(BODY), TYPE, __VA_ARGS__))
+
+ // This macro lets you pass an arbitrary expression that may contain internal
+ // commas to another macro without having the commas causing the expression
+ // to be interpreted as being multiple arguments
+ #define AT_WRAP(...) __VA_ARGS__
+
+ #define AT_FLOAT8_TYPES \
+   c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \
+       c10::kFloat8_e4m3fnuz, c10::kFloat8_e8m0fnu
+
+ #define AT_INTEGRAL_TYPES \
+   c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort
+ #define AT_FLOATING_TYPES c10::kDouble, c10::kFloat
+ #define AT_BAREBONES_UNSIGNED_TYPES c10::kUInt16, c10::kUInt32, c10::kUInt64
+ #define AT_INTEGRAL_TYPES_V2 \
+   AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)
+ #define AT_COMPLEX_TYPES c10::kComplexDouble, c10::kComplexFloat
+ #define AT_QINT_TYPES c10::kQInt8, c10::kQUInt8, c10::kQInt32
+ // NB: not *actually* all types
+ #define AT_ALL_TYPES AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES)
+ #define AT_ALL_TYPES_AND_COMPLEX \
+   AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES)
+
+ // Helper macros

- // Unused helper macros, kept for BC:
#define AT_AP_VAR(N, T, ...) \
  AT_EXPAND(AT_CONCAT(AT_AP, AT_NUM_ARGS(__VA_ARGS__))(AT_WRAP(N), __VA_ARGS__))
+ #define AT_CONCAT(a, b) AT_CONCAT_AUX(a, b)
+ #define AT_CONCAT_AUX(a, b) a##b
+ #define AT_EXPAND(X) X

// Ensure we never have too many scalar types for the expansion here to
// support. To bump this, you must regenerate the macros below.
@@ -102,6 +119,12 @@ static_assert(static_cast<int>(c10::ScalarType::NumOptions) < 60);

num_args = 60

+ nums = ', '.join(str(i) for i in reversed(range(num_args+1)))
+ args = ', '.join(f'_{i}' for i in range(1, num_args+1))
+
+ print(f'#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, {nums}))')
+ print(f'#define AT_NUM_ARGS_AUX({args}, N, ...) N')
+
for i in range(1, num_args+1):
    args = ', '.join(f'_{i}' for i in range(1, i+1))
    cases = ' '.join([f'AT_DISPATCH_CASE(_{j}, N)' for j in range(1, i+1)])
@@ -112,6 +135,8 @@ for i in range(1, num_args+1):
// Begin generated code
// clang-format off

+ #define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0))
+ #define AT_NUM_ARGS_AUX(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, N, ...) N
#define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N)
#define AT_AP2(N, _1, _2) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N)
#define AT_AP3(N, _1, _2, _3) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N)
@@ -226,8 +226,8 @@ template <
    typename B = HostBlock<S>>
struct CachingHostAllocatorImpl {
  virtual ~CachingHostAllocatorImpl() {
-     if (active_) {
-       active_ = false;
+     active_ = false;
+     if (pinned_use_background_threads()) {
        getBackgroundThreadPool()->waitWorkComplete();
      }
  }
@@ -260,7 +260,6 @@ struct CachingHostAllocatorImpl {
    if (pinned_use_background_threads()) {
      // Launch the background thread and process events in a loop.
      static bool background_thread_flag [[maybe_unused]] = [this] {
-         active_ = true;
        getBackgroundThreadPool()->run([&]() {
          while (active_) {
            process_events();
@@ -684,9 +683,9 @@ struct CachingHostAllocatorImpl {
  alignas(hardware_destructive_interference_size) std::mutex events_mutex_;
  std::deque<std::pair<E, B*>> events_; // event queue paired with block

-   // Indicates whether the event-processing thread pool is active.
+   // Indicates whether the object is active.
  // Set to false in the destructor to signal background threads to stop.
-   std::atomic<bool> active_{false};
+   std::atomic<bool> active_{true};
 protected:
  alignas(hardware_destructive_interference_size) HostStatsStaged stats_;
};
@@ -191,7 +191,7 @@ class Vectorized<BFloat16> {
    auto vals = svreinterpret_u16_bf16(values);
    vals = sveor_u16_x(ptrue, vals, mask);
    return svreinterpret_bf16_u16(vals);
-   }
+   };
  Vectorized<BFloat16> round() const;
  Vectorized<BFloat16> tan() const;
  Vectorized<BFloat16> tanh() const;
@@ -349,47 +349,47 @@ Vectorized<BFloat16> inline Vectorized<BFloat16>::frac() const {
    return convert_float_bfloat16(v1, v2); \
  }

- DEFINE_BF16_FUNC_VIA_FLOAT(isnan)
+ DEFINE_BF16_FUNC_VIA_FLOAT(isnan);
- DEFINE_BF16_FUNC_VIA_FLOAT(angle)
+ DEFINE_BF16_FUNC_VIA_FLOAT(angle);
- DEFINE_BF16_FUNC_VIA_FLOAT(acos)
+ DEFINE_BF16_FUNC_VIA_FLOAT(acos);
- DEFINE_BF16_FUNC_VIA_FLOAT(acosh)
+ DEFINE_BF16_FUNC_VIA_FLOAT(acosh);
- DEFINE_BF16_FUNC_VIA_FLOAT(asin)
+ DEFINE_BF16_FUNC_VIA_FLOAT(asin);
- DEFINE_BF16_FUNC_VIA_FLOAT(atan)
+ DEFINE_BF16_FUNC_VIA_FLOAT(atan);
- DEFINE_BF16_FUNC_VIA_FLOAT(atanh)
+ DEFINE_BF16_FUNC_VIA_FLOAT(atanh);
- DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2)
+ DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2);
- DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign)
+ DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign);
- DEFINE_BF16_FUNC_VIA_FLOAT(erf)
+ DEFINE_BF16_FUNC_VIA_FLOAT(erf);
- DEFINE_BF16_FUNC_VIA_FLOAT(erfc)
+ DEFINE_BF16_FUNC_VIA_FLOAT(erfc);
- DEFINE_BF16_FUNC_VIA_FLOAT(exp)
+ DEFINE_BF16_FUNC_VIA_FLOAT(exp);
- DEFINE_BF16_FUNC_VIA_FLOAT(exp2)
+ DEFINE_BF16_FUNC_VIA_FLOAT(exp2);
- DEFINE_BF16_FUNC_VIA_FLOAT(expm1)
+ DEFINE_BF16_FUNC_VIA_FLOAT(expm1);
- DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod)
+ DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod);
- DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot)
+ DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot);
- DEFINE_BF16_FUNC_VIA_FLOAT(i0)
+ DEFINE_BF16_FUNC_VIA_FLOAT(i0);
- DEFINE_BF16_FUNC_VIA_FLOAT(i0e)
+ DEFINE_BF16_FUNC_VIA_FLOAT(i0e);
- DEFINE_BF16_FUNC_VIA_FLOAT(digamma)
+ DEFINE_BF16_FUNC_VIA_FLOAT(digamma);
- DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma)
+ DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma);
- DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac)
+ DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac);
- DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter)
+ DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter);
- DEFINE_BF16_FUNC_VIA_FLOAT(log)
+ DEFINE_BF16_FUNC_VIA_FLOAT(log);
- DEFINE_BF16_FUNC_VIA_FLOAT(log2)
+ DEFINE_BF16_FUNC_VIA_FLOAT(log2);
- DEFINE_BF16_FUNC_VIA_FLOAT(log10)
+ DEFINE_BF16_FUNC_VIA_FLOAT(log10);
- DEFINE_BF16_FUNC_VIA_FLOAT(log1p)
+ DEFINE_BF16_FUNC_VIA_FLOAT(log1p);
- DEFINE_BF16_FUNC_VIA_FLOAT(sin)
+ DEFINE_BF16_FUNC_VIA_FLOAT(sin);
- DEFINE_BF16_FUNC_VIA_FLOAT(sinh)
+ DEFINE_BF16_FUNC_VIA_FLOAT(sinh);
- DEFINE_BF16_FUNC_VIA_FLOAT(cos)
+ DEFINE_BF16_FUNC_VIA_FLOAT(cos);
- DEFINE_BF16_FUNC_VIA_FLOAT(cosh)
+ DEFINE_BF16_FUNC_VIA_FLOAT(cosh);
- DEFINE_BF16_FUNC_VIA_FLOAT(ceil)
+ DEFINE_BF16_FUNC_VIA_FLOAT(ceil);
- DEFINE_BF16_FUNC_VIA_FLOAT(floor)
+ DEFINE_BF16_FUNC_VIA_FLOAT(floor);
- DEFINE_BF16_FUNC_VIA_FLOAT(round)
+ DEFINE_BF16_FUNC_VIA_FLOAT(round);
- DEFINE_BF16_FUNC_VIA_FLOAT(tan)
+ DEFINE_BF16_FUNC_VIA_FLOAT(tan);
- DEFINE_BF16_FUNC_VIA_FLOAT(tanh)
+ DEFINE_BF16_FUNC_VIA_FLOAT(tanh);
- DEFINE_BF16_FUNC_VIA_FLOAT(trunc)
+ DEFINE_BF16_FUNC_VIA_FLOAT(trunc);
- DEFINE_BF16_FUNC_VIA_FLOAT(lgamma)
+ DEFINE_BF16_FUNC_VIA_FLOAT(lgamma);
- DEFINE_BF16_FUNC_VIA_FLOAT(sqrt)
+ DEFINE_BF16_FUNC_VIA_FLOAT(sqrt);
- DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal)
+ DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal);
- DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt)
+ DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt);
- DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow)
+ DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow);

Vectorized<BFloat16> inline Vectorized<BFloat16>::operator==(
    const Vectorized<BFloat16>& other) const {
@@ -388,7 +388,6 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
 #ifndef USE_ROCM
   at::Half halpha;
   at::Half hbeta;
-  uint32_t mask = -1;
 #endif
   void * alpha_ptr = &alpha;
   void * beta_ptr = &beta;
@@ -428,7 +427,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
     auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS();
     if (fp16_reduction !=
         at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
-      mask =
+      uint32_t mask =
           fp16_reduction ==
                   at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
               ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
@@ -445,7 +444,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
     auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS();
     if (bf16_reduction !=
         at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) {
-      mask =
+      uint32_t mask =
           bf16_reduction ==
                   at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK
               ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE |
@@ -512,41 +511,17 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D
   cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS;
   cublasLtMatmulHeuristicResult_t heuristicResult = {};
   int returnedResult = 0;
-  // on Blackwell+, we fake a n > 1 matmul when querying heuristics
-  // to prevent cuBLASLt from dispatching to a GEMV kernel for batch-invariance
-#ifndef USE_ROCM
-  const bool lie_to_cublaslt = mask == CUBLASLT_REDUCTION_SCHEME_NONE && n == 1 && at::cuda::getCurrentDeviceProperties()->major >= 10;
-#else
-  const bool lie_to_cublaslt = false;
-#endif
-  if (lie_to_cublaslt) {
-    CuBlasLtMatrixLayout FakeBdesc(abType, k, 2, ldb, opb == CUBLAS_OP_T);
-    CuBlasLtMatrixLayout FakeCdesc(cType, m, 2, ldc);
-
-    TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
-        ltHandle,
-        computeDesc.descriptor(),
-        Adesc.descriptor(),
-        FakeBdesc.descriptor(),
-        FakeCdesc.descriptor(),
-        FakeCdesc.descriptor(),
-        preference.descriptor(),
-        1,
-        &heuristicResult,
-        &returnedResult));
-  } else {
-    TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
-        ltHandle,
-        computeDesc.descriptor(),
-        Adesc.descriptor(),
-        Bdesc.descriptor(),
-        Cdesc.descriptor(),
-        Cdesc.descriptor(),
-        preference.descriptor(),
-        1,
-        &heuristicResult,
-        &returnedResult));
-  }
+  TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
+      ltHandle,
+      computeDesc.descriptor(),
+      Adesc.descriptor(),
+      Bdesc.descriptor(),
+      Cdesc.descriptor(),
+      Cdesc.descriptor(),
+      preference.descriptor(),
+      1,
+      &heuristicResult,
+      &returnedResult));
   if (returnedResult == 0) {
     cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED;
   }
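The two `uint32_t mask =` hunks above derive a cuBLASLt reduction-scheme mask from the global reduced-precision-reduction setting before the heuristic query. A minimal sketch of that mapping, using only the public cublasLt enums; the option enum below is illustrative and merely mirrors the names in the diff:

    #include <cublasLt.h>
    #include <cstdint>

    // Stand-in for at::CuBLASReductionOption; names follow the diff above.
    enum class ReductionOption {
      AllowReducedPrecisionWithSplitK,
      DisallowReducedPrecisionAllowSplitK,
      DisallowReducedPrecisionDisallowSplitK,
    };

    // Translate the policy into the set of reduction schemes cuBLASLt may pick.
    inline uint32_t reduction_scheme_mask(ReductionOption opt) {
      if (opt == ReductionOption::AllowReducedPrecisionWithSplitK) {
        return CUBLASLT_REDUCTION_SCHEME_MASK;  // everything allowed (the -1 default above)
      }
      return opt == ReductionOption::DisallowReducedPrecisionAllowSplitK
          ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | CUBLASLT_REDUCTION_SCHEME_NONE)
          : CUBLASLT_REDUCTION_SCHEME_NONE;
    }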
@@ -1597,7 +1572,7 @@ bool gemm_and_bias(
   }
 
   using opmath_t = at::opmath_type<Dtype>;
-  opmath_t beta_val = bias ? 0 : 1; // bias is added in epilogue unless nullptr
+  opmath_t beta_val = 0; // bias is added in epilogue
 
   cudaDataType_t abType = CUDA_R_32F;
   cudaDataType_t cType = CUDA_R_32F;
@@ -1686,22 +1661,15 @@ bool gemm_and_bias(
     _syncCurrentWithCarveoutStream(stream, true);
   }
 #endif
-  const auto epilogue = [&]() -> cublasLtEpilogue_t {
-    // The cuBLAS documentation indicates that
-    // *_<ACTIVATION>_BIAS = *_<ACTIVATION>,
-    // but we keep it verbose here for clarity.
-    switch (activation) {
-      case GEMMAndBiasActivationEpilogue::RELU:
-        return bias ? CUBLASLT_EPILOGUE_RELU_BIAS : CUBLASLT_EPILOGUE_RELU;
-      case GEMMAndBiasActivationEpilogue::GELU:
-        return bias ? CUBLASLT_EPILOGUE_GELU_BIAS : CUBLASLT_EPILOGUE_GELU;
-      default:
-        return bias ? CUBLASLT_EPILOGUE_BIAS : CUBLASLT_EPILOGUE_DEFAULT;
-    }
-  }();
-  computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, epilogue);
+  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
+  if (activation == GEMMAndBiasActivationEpilogue::RELU) {
+    epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
+  } else if (activation == GEMMAndBiasActivationEpilogue::GELU) {
+    epilogue = CUBLASLT_EPILOGUE_GELU_BIAS;
+  }
 
-  if (bias) {
+  if (bias != nullptr) {
+    computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, epilogue);
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_POINTER, bias);
   }
 
@@ -24,13 +24,7 @@ namespace detail {
 // radix_sort_pairs doesn't interact with value_t other than to copy
 // the data, so we can save template instantiations by reinterpreting
 // it as an opaque type.
-// We use native integer types for 1/2/4/8-byte values to reduce
-// register usage in CUDA kernels. For sizes > 8 fall back to char array.
 template <int N> struct alignas(N) OpaqueType { char data[N]; };
-template <> struct alignas(1) OpaqueType<1> { uint8_t data; };
-template <> struct alignas(2) OpaqueType<2> { uint16_t data; };
-template <> struct alignas(4) OpaqueType<4> { uint32_t data; };
-template <> struct alignas(8) OpaqueType<8> { uint64_t data; };
 
 template<typename key_t, int value_size>
 void radix_sort_pairs_impl(
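For context, the `OpaqueType` above lets `radix_sort_pairs` be instantiated once per value size rather than once per value type: the value array is reinterpreted as an opaque blob of equal size and alignment. A rough sketch of the idea, with a hypothetical caller:

    #include <cstdint>

    // One instantiation per size: every 8-byte value type (int64_t, double, ...)
    // funnels into OpaqueType<8>.
    template <int N> struct alignas(N) OpaqueType { char data[N]; };

    template <typename key_t, int value_size>
    void radix_sort_pairs_impl(const key_t* keys, const OpaqueType<value_size>* values, int n);

    template <typename key_t, typename value_t>
    void radix_sort_pairs(const key_t* keys, const value_t* values, int n) {
      radix_sort_pairs_impl<key_t, sizeof(value_t)>(
          keys, reinterpret_cast<const OpaqueType<sizeof(value_t)>*>(values), n);
    }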
@@ -21,7 +21,6 @@
 
 #if AT_CUDNN_ENABLED()
 #include <ATen/cudnn/cudnn-wrapper.h>
-#include <cudnn_frontend.h>
 #endif
 
 #if AT_MAGMA_ENABLED()
@@ -352,26 +351,6 @@ long CUDAHooks::versionCuDNN() const {
 #endif
 }
 
-long CUDAHooks::versionRuntimeCuDNN() const {
-#if AT_CUDNN_ENABLED()
-#ifndef USE_STATIC_CUDNN
-  return cudnnGetVersion();
-#else
-  return CUDNN_VERSION;
-#endif
-#else
-  TORCH_CHECK(false, "Cannot query CuDNN version if ATen_cuda is not built with CuDNN");
-#endif
-}
-
-long CUDAHooks::versionCuDNNFrontend() const {
-#if AT_CUDNN_ENABLED()
-  return CUDNN_FRONTEND_VERSION;
-#else
-  TORCH_CHECK(false, "Cannot query CuDNN Frontend version if ATen_cuda is not built with CuDNN");
-#endif
-}
-
 long CUDAHooks::versionMIOpen() const {
 #if AT_ROCM_ENABLED()
   return MIOPEN_VERSION_MAJOR * 10000 +
@@ -49,8 +49,6 @@ struct CUDAHooks : public at::CUDAHooksInterface {
   bool hasCUDART() const override;
   long versionCUDART() const override;
   long versionCuDNN() const override;
-  long versionRuntimeCuDNN() const override;
-  long versionCuDNNFrontend() const override;
   long versionMIOpen() const override;
   std::string showConfig() const override;
   double batchnormMinEpsilonCuDNN() const override;
@@ -174,14 +174,6 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
     TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
   }
 
-  virtual long versionRuntimeCuDNN() const {
-    TORCH_CHECK(false, "Cannot query cuDNN version without ATen_cuda library. ", CUDA_HELP);
-  }
-
-  virtual long versionCuDNNFrontend() const {
-    TORCH_CHECK(false, "Cannot query cuDNN Frontend version without ATen_cuda library. ", CUDA_HELP);
-  }
-
   virtual long versionMIOpen() const {
     TORCH_CHECK(false, "Cannot query MIOpen version without ATen_cuda library. ", CUDA_HELP);
   }
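The removed `versionRuntimeCuDNN()` differs from `versionCuDNN()` in what it reports: the version of the cuDNN library actually loaded at runtime versus the headers the build was compiled against. A minimal sketch of the distinction, assuming a dynamically linked cuDNN:

    #include <cudnn.h>

    // Compile-time version: baked in from the cudnn.h seen at build time.
    long compiled_cudnn_version() {
      return CUDNN_VERSION;
    }

    // Runtime version: whatever libcudnn the dynamic loader actually resolved.
    long runtime_cudnn_version() {
      return static_cast<long>(cudnnGetVersion());
    }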
@@ -157,8 +157,6 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
   DispatchKey::Negative,
   DispatchKey::Conjugate,
   DispatchKey::XLA,
-  DispatchKey::XPU,
-  DispatchKey::HPU,
   DispatchKey::CUDA,
   DispatchKey::CPU,
   DispatchKey::PrivateUse1,
@@ -440,7 +440,7 @@ bool MPSHeapAllocatorImpl::release_cached_buffers() {
   // we need to release the lock temporarily as synchronizing may cause deadlock with completion handlers.
   m_mutex.unlock();
   auto stream = getDefaultMPSStream();
-  dispatch_sync_with_rethrow(stream->queue(), ^() {
+  dispatch_sync(stream->queue(), ^() {
     stream->synchronize(SyncType::COMMIT_AND_WAIT);
   });
   m_mutex.lock();
@@ -110,9 +110,6 @@ class TORCH_API MPSStream {
     return _stream;
   }
 
-  MTLBuffer_t getErrorBuffer();
-  void checkLastError();
-
  private:
   Stream _stream;
   MTLCommandQueue_t _commandQueue = nil;
@@ -124,8 +121,6 @@ class TORCH_API MPSStream {
   dispatch_queue_t _serialQueue = nullptr;
   // CommitAndContinue is enabled by default
   bool _enableCommitAndContinue = true;
-  // Buffer that contains last raised error
-  MTLBuffer_t _errorBuffer = nil;
 
   // use synchronize() to access any of these commit functions outside MPSStream
   void commit();
@@ -160,7 +155,4 @@ class TORCH_API MPSStreamImpl {
   MPSStreamImpl();
 };
 
-#ifdef __OBJC__
-void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());
-#endif
 } // namespace at::mps
@@ -3,13 +3,13 @@
 #include <ATen/mps/MPSAllocatorInterface.h>
 #include <ATen/mps/MPSProfiler.h>
 #include <ATen/mps/MPSStream.h>
-#include <c10/metal/error.h>
 
 @interface MPSGraphExecutionDescriptor ()
 @property(readwrite, atomic) BOOL enableCommitAndContinue;
 @end
 
 namespace at::mps {
 
 //-----------------------------------------------------------------
 // MPSStream
 //-----------------------------------------------------------------
@@ -30,10 +30,6 @@ MPSStream::MPSStream(Stream stream) : _stream(stream) {
   // Choose level which optimizes for GPU
   _compilationDescriptor.optimizationLevel = MPSGraphOptimizationLevel0;
   _executionDescriptor.compilationDescriptor = _compilationDescriptor;
-
-  _errorBuffer = [MPSDevice::getInstance()->device() newBufferWithLength:sizeof(c10::metal::ErrorMessages)
-                                                                 options:MTLResourceStorageModeShared];
-  std::memset([_errorBuffer contents], 0, 1024);
 }
 
 MPSStream::~MPSStream() {
@@ -42,8 +38,6 @@ MPSStream::~MPSStream() {
   [_executionDescriptor release];
   [_compilationDescriptor release];
   _executionDescriptor = nil;
-  [_errorBuffer release];
-  _errorBuffer = nil;
   _compilationDescriptor = nil;
 
   assert(_commandBuffer == nil);
@@ -110,7 +104,6 @@ void MPSStream::commitAndWait() {
     [_prevCommandBuffer waitUntilCompleted];
     [_prevCommandBuffer release];
    _prevCommandBuffer = nil;
-    checkLastError();
  }
 
   if (_commandBuffer) {
@@ -118,7 +111,6 @@ void MPSStream::commitAndWait() {
     [_commandBuffer waitUntilCompleted];
     [_commandBuffer release];
     _commandBuffer = nil;
-    checkLastError();
   }
 }
 
@@ -161,7 +153,7 @@ void MPSStream::fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t
   if (length == 0) {
     return;
   }
-  dispatch_sync_with_rethrow(_serialQueue, ^() {
+  dispatch_sync(_serialQueue, ^() {
     @autoreleasepool {
       endKernelCoalescing();
      id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];
@@ -191,7 +183,7 @@ void MPSStream::copy(id<MTLBuffer> srcBuffer,
                      size_t dstOffset,
                      uint64_t profileId,
                      SyncType syncType) {
-  dispatch_sync_with_rethrow(_serialQueue, ^() {
+  dispatch_sync(_serialQueue, ^() {
     @autoreleasepool {
       endKernelCoalescing();
       id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];
@@ -244,7 +236,7 @@ void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDicti
   auto& profiler = getMPSProfiler();
   const bool isGraphProfilingEnabled = profiler.isOperationProfilingEnabled();
 
-  dispatch_sync_with_rethrow(_serialQueue, ^() {
+  dispatch_sync(_serialQueue, ^() {
     endKernelCoalescing();
     if (isGraphProfilingEnabled) {
       // this function call is only relevant for interval-based Signposts
@@ -274,24 +266,6 @@ void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDicti
   });
 }
 
-id<MTLBuffer> MPSStream::getErrorBuffer() {
-  return _errorBuffer;
-}
-
-void MPSStream::checkLastError() {
-  auto msgs = reinterpret_cast<c10::metal::ErrorMessages*>([_errorBuffer contents]);
-  const auto& msg = msgs->msg[0];
-  if (!msgs) {
-    return;
-  }
-  unsigned int count = 0;
-  std::swap(count, msgs->count);
-  if (!count) {
-    return;
-  }
-  throw c10::AcceleratorError({msg.func, msg.file, msg.line}, 1, msg.message);
-}
-
 //-----------------------------------------------------------------
 // MPSStreamImpl
 //-----------------------------------------------------------------
@@ -315,19 +289,4 @@ MPSStream* getDefaultMPSStream() {
   return MPSStreamImpl::getInstance();
 }
 
-// Helper methods
-void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) {
-  __block std::optional<std::exception_ptr> block_exception;
-  dispatch_sync(queue, ^() {
-    try {
-      block();
-    } catch (...) {
-      block_exception = std::current_exception();
-    }
-  });
-  if (block_exception) {
-    std::rethrow_exception(*block_exception);
-  }
-}
-
 } // namespace at::mps
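The helper deleted in the last hunk exists because a C++ exception thrown inside a block handed to plain `dispatch_sync` unwinds through libdispatch frames instead of reaching the caller; the wrapper captures it in a `std::exception_ptr` and rethrows it on the calling thread. A hypothetical caller would look roughly like this (sketch only, assuming the helper shown above and the Clang blocks extension):

    #include <dispatch/dispatch.h>
    #include <stdexcept>

    // Hypothetical usage sketch (not from the tree): any throw inside the block
    // resurfaces on the calling thread instead of unwinding through libdispatch.
    void example_caller() {
      dispatch_queue_t q = dispatch_queue_create("org.example.serial", DISPATCH_QUEUE_SERIAL);
      try {
        dispatch_sync_with_rethrow(q, ^() {
          throw std::runtime_error("kernel reported an error");
        });
      } catch (const std::runtime_error& e) {
        // handled here, on the calling thread
      }
      dispatch_release(q);
    }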
@@ -1009,25 +1009,12 @@ static Device correct_out_device(const Tensor& self, const Tensor& other) {
   }
 }
 
-static Tensor send_to_meta(const Tensor& self, const Device& device) {
-  Tensor out_meta;
-  if (self._is_zerotensor() && self.unsafeGetTensorImpl()->is_wrapped_number()) {
-    out_meta = at::_efficientzerotensor(self.sizes(), self.options().device(device));
-    out_meta.unsafeGetTensorImpl()->set_wrapped_number(true);
-  } else {
-    out_meta = self.to(device);
-  }
-  return out_meta;
-}
-
 Tensor mul_zerotensor(const Tensor& self, const Tensor& other) {
   auto out_device = correct_out_device(self, other);
   // hack to use the TensorIterator to get the correct broadcasting and type promotion logic
   auto device_ = Device(DeviceType::Meta);
   constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta);
-  auto self_meta = send_to_meta(self, device_);
-  auto other_meta = send_to_meta(other, device_);
-  auto meta_out = at::_ops::mul_Tensor::redispatch(meta_dks, self_meta, other_meta);
+  auto meta_out = at::_ops::mul_Tensor::redispatch(meta_dks, self.to(device_), other.to(device_));
   return at::_efficientzerotensor(meta_out.sizes(), meta_out.options().device(out_device));
 }
 
@@ -1036,9 +1023,7 @@ Tensor div_zerotensor(const Tensor& self, const Tensor& other) {
   // hack to use the TensorIterator to get the correct broadcasting and type promotion logic
   auto device_ = Device(DeviceType::Meta);
   constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta);
-  auto self_meta = send_to_meta(self, device_);
-  auto other_meta = send_to_meta(other, device_);
-  auto meta_out = at::_ops::div_Tensor::redispatch(meta_dks, self_meta, other_meta);
+  auto meta_out = at::_ops::div_Tensor::redispatch(meta_dks, self.to(device_), other.to(device_));
 
   if (self._is_zerotensor()) {
     if (other._is_zerotensor()) {
@@ -1067,9 +1052,8 @@ static Tensor maybe_add_maybe_sub(const Tensor& self, const Tensor& other, const
   // hack to use the TensorIterator to get the correct broadcasting and type promotion logic
   auto device_ = Device(DeviceType::Meta);
   constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta);
-  auto self_meta = send_to_meta(self, device_);
-  auto other_meta = send_to_meta(other, device_);
-  auto meta_out = at::_ops::add_Tensor::redispatch(meta_dks, self_meta, other_meta, alpha);
+  auto meta_out = at::_ops::add_Tensor::redispatch(
+      meta_dks, self.to(device_), other.to(device_), alpha);
 
   auto get_out_like = [&] (const Tensor& tensor)
   {
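All three zerotensor paths above rely on the same trick: redispatch the op to the Meta backend so TensorIterator works out broadcasting, dtype promotion and the output shape without touching any data, then materialize the result as an efficient zero tensor. A stripped-down sketch of the pattern for a hypothetical wrapper:

    #include <ATen/ATen.h>
    #include <ATen/ops/mul_ops.h>

    // Sketch: compute only the output metadata of self * other on the Meta
    // device, then return an all-zero placeholder on the real device.
    at::Tensor mul_as_zerotensor(const at::Tensor& self, const at::Tensor& other) {
      const auto meta = at::Device(at::DeviceType::Meta);
      constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta);
      auto meta_out = at::_ops::mul_Tensor::redispatch(
          meta_dks, self.to(meta), other.to(meta));  // no real compute happens
      return at::_efficientzerotensor(
          meta_out.sizes(), meta_out.options().device(self.device()));
    }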
@@ -409,7 +409,7 @@ struct ConvParams {
     if (!detail::getCUDAHooks().compiledWithCuDNN() || !input.is_cuda() || !cudnn_enabled) {
       return false;
     }
-    static long cudnn_version = detail::getCUDAHooks().versionRuntimeCuDNN();
+    static long cudnn_version = detail::getCUDAHooks().versionCuDNN();
     // broken on cuDNN 9.8 - 9.14
     if (cudnn_version >= 90800 && cudnn_version < 91500) {
       if (cudnn_conv_suggest_memory_format(input, weight) == at::MemoryFormat::Contiguous &&
@@ -453,7 +453,7 @@ struct ConvParams {
     }
     // native kernel doesn't support 64-bit non-splittable case
     if (!(canUse32BitIndexMath(input) && canUse32BitIndexMath(weight))) {
-      static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionRuntimeCuDNN() : -1;
+      static long cudnn_version = detail::getCUDAHooks().compiledWithCuDNN() ? detail::getCUDAHooks().versionCuDNN() : -1;
       // TODO(eqy): remove this once cuDNN fixes 64-bit depthwise support, first broken in 9.11x
       if (cudnn_conv_suggest_memory_format(input, weight) != at::MemoryFormat::Contiguous) {
         if (cudnn_version < 0 || cudnn_version > 91000) {
@@ -50,35 +50,18 @@ static inline bool parseLinearFlatten3d() {
 // `_flatten_nd_linear` flattens all but the last dimension of the input tensor
 // before passing it to linear operation
 static inline Tensor _flatten_nd_linear(const Tensor& input, const Tensor& weight, const Tensor& bias) {
     const auto input_sizes = input.sym_sizes();
-    const auto result_flattened = [&]() -> Tensor {
-      const auto input_ncols = input_sizes.back();
-      const auto input_flattened_nrows = [&]() -> c10::SymInt {
-        // can't use -1 in reshape because it errors when a dimension is 0
-        auto flattened_nrows = c10::SymInt{1};
-        for (const auto& size : input_sizes.slice(0, input_sizes.size() - 1)) {
-          flattened_nrows *= size;
-        }
-        return flattened_nrows;
-      }();
-
-      const auto input_flattened = input.view_symint({input_flattened_nrows, input_ncols});
-      if (weight.layout() == c10::kStrided) {
-        return at::addmm(bias, input_flattened, weight.t());
-      } else {
-        // weight is sparse, and addmm for sparse expects matmul lhs to be sparse,
-        // so we transpose the problem.
-        // NOTE: at::matmul handles (dense @ sparse) similarly.
-        const auto bias_t = (bias.dim() >= 2) ? bias.mT() : bias.unsqueeze(-1);
-        return at::addmm(bias_t, weight, input_flattened.t()).t();
-      }
-    }();
-    // Unflatten flattened row dims
-    auto result_sizes = c10::SymDimVector{input_sizes.begin(), input_sizes.end()};
-    result_sizes.back() = result_flattened.sym_size(1);
-    return result_flattened.view_symint(result_sizes);
+    // can't use -1 in reshape because it errors when a dimension is 0
+    c10::SymInt flattened_dim = 1;
+    for (int64_t i = 0, ndim = input_sizes.size(); i < ndim - 1; ++i) {
+      flattened_dim = flattened_dim * input_sizes[i];
+    }
+    auto inp_reshape = input.reshape_symint({flattened_dim, input_sizes.at(input_sizes.size() -1)});
+    const auto result = at::addmm(bias, inp_reshape, weight.t());
+    auto new_size = input_sizes.slice(0, input_sizes.size() - 1);
+    c10::SymDimVector sizes_vec(new_size.begin(), new_size.end());
+    sizes_vec.push_back(result.sym_size(1));
+    return result.view_symint(sizes_vec);
 }
 
 
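Both versions of `_flatten_nd_linear` perform the same shape gymnastics: collapse an input of shape [d0, ..., dk, C] to a 2-D matrix so a single addmm call can be used, then restore the leading dimensions on the result. A concrete illustration with hypothetical shapes:

    #include <ATen/ATen.h>

    // Sketch: linear on a [2, 3, 8] input with a [5, 8] weight and [5] bias.
    at::Tensor linear_via_flatten(const at::Tensor& input,    // [2, 3, 8]
                                  const at::Tensor& weight,   // [5, 8]
                                  const at::Tensor& bias) {   // [5]
      auto flat = input.reshape({2 * 3, 8});                  // [6, 8]
      auto out = at::addmm(bias, flat, weight.t());           // [6, 5]
      return out.view({2, 3, 5});                             // restore leading dims
    }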
@@ -107,23 +90,15 @@ Tensor linear(const Tensor& input, const Tensor& weight, const std::optional<Ten
     // Fused op is marginally faster.
     return at::addmm(*bias, input, weight.t());
   }
-  const auto is_bias_likely_fusable = (
-      bias->defined() &&
-      // cuBLASLt: will fuse in the epilogue without copies
-      // when input/weight/bias are all strided.
-      // When weight is not strided, bias will not be fused,
-      // but we can still dispatch here to avoid at::matmul
-      // path which will probably use a very similar
-      // flattening optimization.
-      ((bias->dim() == 1 || bias->squeeze().dim() == 1) && bias->is_contiguous_or_false())
-  );
-  if (is_bias_likely_fusable && !input.is_xla()) {
-    // Also hit the fused path for contiguous nD input, if not using xla
+  if (bias->defined() && !input.is_xla()) {
+    // Also hit the fused path for contiguous 3D input, if not using xla
     // backend. Reshaping/flattening has some performance implications on xla.
-    if (input.is_contiguous_or_false()) {
+    bool is_contiguous = input.is_contiguous_or_false();
+    if (is_contiguous && input_dim == 3) {
       return _flatten_nd_linear(input, weight, *bias);
-    } else if (parseLinearFlatten3d()) {
+    } else if (is_contiguous && input.layout() == c10::kStrided && weight.layout() == c10::kStrided && bias->dim() == 1) {
+      return _flatten_nd_linear(input, weight, *bias);
+    } else if (parseLinearFlatten3d() && input_dim == 3) {
       // If user forces flattening via env var
       const Tensor input_cont = input.contiguous();
       return _flatten_nd_linear(input_cont, weight, *bias);
@@ -23,7 +23,6 @@
 #include <ATen/ops/_aminmax_native.h>
 #include <ATen/ops/_assert_async_native.h>
 #include <ATen/ops/_assert_scalar_native.h>
-#include <ATen/ops/_async_error_native.h>
 #include <ATen/ops/_functional_assert_async_native.h>
 #include <ATen/ops/_functional_assert_scalar_native.h>
 #include <ATen/ops/_make_per_tensor_quantized_tensor.h>
@@ -480,14 +479,6 @@ Tensor isfinite(const Tensor& self) {
   });
 }
 
-void _async_error(std::string_view msg) {
-  TORCH_CHECK(0, msg);
-}
-
-void _async_error_meta(std::string_view msg) {
-  // Do NOT error, it's an async error!
-}
-
 void _assert_async_cpu(const Tensor& self) {
   TORCH_CHECK(
       native::is_nonzero(self),
@@ -1,6 +1,5 @@
 #include <ATen/core/ATen_fwd.h>
 #include <c10/core/ScalarType.h>
-#include <c10/core/SymInt.h>
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/AccumulateType.h>
 #include <ATen/Dispatch.h>
@@ -1711,37 +1710,11 @@ Tensor narrow_symint(
       "], but got ",
       start,
       ")")
-  auto cond1 = TORCH_GUARD_OR_FALSE(start.sym_lt(0));
-  auto cond2 = TORCH_GUARD_OR_FALSE(start.sym_ge(0));
-
-  if (cond1 || cond2) {
-    if (cond1) {
-      start = start + cur_size;
-    }
-
-    TORCH_SYM_CHECK(
-        start.sym_le(cur_size - length),
-        "start (",
-        start,
-        ") + length (",
-        length,
-        ") exceeds dimension size (",
-        cur_size,
-        ").");
-    return at::slice_symint(self, dim, start, start + length, 1);
+  if (start < 0) {
+    start = start + cur_size;
   }
-
-  // Unbacked start handling!
-
-  // Bounds check without converting start:
-  // - If start < 0: need (start + cur_size) + length <= cur_size, i.e., start +
-  // length <= 0
-  // - If start >= 0: need start + length <= cur_size
-  auto end = start + length;
   TORCH_SYM_CHECK(
-      (start.sym_lt(0).sym_and((end).sym_le(0)))
-          .sym_or(start.sym_ge(0).sym_and((end).sym_le(cur_size))),
+      start.sym_le(cur_size - length),
       "start (",
       start,
       ") + length (",
@@ -1749,28 +1722,7 @@ Tensor narrow_symint(
       ") exceeds dimension size (",
       cur_size,
       ").");
-
-  if (TORCH_GUARD_OR_FALSE(end.sym_ne(0))) {
-    return at::slice_symint(self, dim, start, end, 1);
-  } else {
-    // Cannot statically determine the condition due to unbacked.
-    // This is an interesting situation; when start is negative and
-    // start + length == 0, slice and narrow do different things.
-    // i.e., x.narrow(0, -2, 2) != x[-2:0]; in that case, we want to
-    // pass curr_size instead of 0. Otherwise, they would do the same thing.
-    // This says at runtime: if start < 0 and end == 0, then pass curr_size
-    // instead of 0.
-
-    auto use_different = start.sym_lt(0).sym_and(end.sym_eq(0)).toSymInt();
-    auto result =
-        at::slice_symint(self, dim, start, end + use_different * cur_size, 1);
-
-    // Ensure slice allocated unbacked size is specialized to length.
-    SymInt new_size = result.sym_size(dim);
-    TORCH_SYM_CHECK(new_size.sym_eq(length), "")
-
-    return result;
-  }
+  return at::slice_symint(self, dim, start, start + length, 1);
 }
 
 // This overload exists purely for XLA, because they wanted to pass in
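The comments deleted above pin down the corner case behind the unbacked handling: for a negative start with start + length == 0, `narrow` and `slice` disagree, because slice treats an end of 0 literally while narrow means "up to the end of the dimension". A small illustration with concrete values:

    #include <ATen/ATen.h>

    void narrow_vs_slice() {
      auto x = at::arange(5);        // [0, 1, 2, 3, 4]
      auto n = x.narrow(0, -2, 2);   // last two elements -> [3, 4]
      auto s = x.slice(0, -2, 0);    // start = -2, end = 0 -> empty
      // n.numel() == 2 but s.numel() == 0: the APIs differ exactly when
      // start < 0 and start + length == 0, which is why the removed code
      // substituted cur_size for 0 at runtime.
    }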
@@ -1784,8 +1736,8 @@ Tensor narrow_tensor_symint(
       start.dim() == 0 &&
           isIntegralType(start.scalar_type(), /*includeBool=*/false),
       "start must be an 0-dim integral Tensor.");
-  c10::SymInt st = start.item().toSymInt();
-  return at::narrow_symint(self, dim, std::move(st), std::move(length));
+  int64_t st = start.item<int64_t>();
+  return at::narrow_symint(self, dim, c10::SymInt(st), std::move(length));
 }
 
 std::
@@ -293,7 +293,7 @@ struct ComputeLocationBase<scalar_t, /*align_corners=*/false> {
     , empty(size <= 0) {}
 
   inline Vec unnormalize(const Vec &in) const {
-    return (in + Vec(static_cast<scalar_t>(1))) * Vec(scaling_factor) - Vec(static_cast<scalar_t>(0.5));
+    return (in + Vec(1)) * Vec(scaling_factor) - Vec(0.5);
   }
 
   inline Vec clip_coordinates(const Vec &in) const {
@@ -831,7 +831,7 @@ struct ApplyGridSample<scalar_t, 2, GridSamplerInterpolation::Bicubic,
 
   // constant used in cubic convolution
   // could be -0.5 or -0.75, use the same value in UpSampleBicubic2d.h
-  const Vec A = Vec(static_cast<scalar_t>(-0.75));
+  const Vec A = Vec(-0.75);
 
   ApplyGridSample(const TensorAccessor<const scalar_t, 4>& input)
       : inp_H(input.size(2))
@@ -147,24 +147,14 @@ static bool isGloballyDisabledAddmmCudaLt(const at::Device& device) {
 /*
  * Check whether for the given input we want to enable the Lt interface
  */
-static bool isInputCompliesAddmmCudaLt(
-    Tensor& result,
-    const Tensor& self,
-    const Tensor& mat1,
-    const Tensor& mat2,
-    const Scalar& beta,
-    const Scalar& alpha,
-    Activation activation
-) {
-#ifdef USE_ROCM
+static bool isInputCompliesAddmmCudaLt(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
   // Implies 2D bias which we currently not send through Lt.
   // TODO: this check is done pre col-major input preparation,
   // so, this condition can be ralexed in cases when a col-major
   // copy of result is needed.
-  if (self.is_same(result) || self.dim() == 2) {
+  if (result.is_same(self)) {
     return false;
   }
-#endif
 
 #if defined(USE_ROCM) && ROCM_VERSION == 60400
   // hipblaslt TT fp32 regression on ROCm 6.4, cannot use
@@ -179,33 +169,13 @@ static bool isInputCompliesAddmmCudaLt(
 #if defined(CUDA_VERSION) || defined(USE_ROCM)
   const auto scalar_type = mat1.scalar_type();
   return (beta.toComplexDouble() == 1.0
-    // NOTE: row-major result is important when bias is 1D.
-    // This is because Lt broadcasts 1D bias over the columns
-    // while the aten::addmm API broadcasts it over the rows,
-    // and this is in conjuction with the data preparation
-    // procedure that does not transpose arguments with
-    // col-major result. For col-major result we need
-    // to explicitly transpose the problem so that bias is
-    // correctly applied.
-    // TODO: enable col-major result if needed.
-    // TODO: no need to check result's layout when
-    // !result.is_same(self) and self.dim() == 2, because
-    // self needs to be copied into result and the bias ptr
-    // will be ignored.
     && result.dim() == 2 && result.is_contiguous()
+    // Conditions for bias to be fusable
     && (
-      ( // Conditions for bias to be fusable -- implies direct Lt path without copies.
-        self.is_contiguous() &&
-        // NOTE: fine to have 1-len dims to the left from the right-most one
-        (self.dim() == 1 || self.squeeze().dim() == 1) &&
-        self.sizes().back() == mat2_sizes[1]
-      )
-      || ( // 2D bias restrictions. self.is_contiguous() is implicit when result.is_same(self),
-           // and we need to copy self into result otherwise, so the self's layout becomes irrelevant.
-           // See also TODO from above.
-        activation != Activation::None && // Lt is faster when activation is fused
-        (self.dim() == 2 && at::is_expandable_to(self.sizes(), {mat1_sizes[0], mat2_sizes[1]}))
-      )
+      self.is_contiguous() &&
+      // NOTE: fine to have 1-len dims to the left from the right-most one
+      (self.dim() == 1 || self.squeeze().dim() == 1) &&
+      self.sizes().back() == mat2_sizes[1]
     )
     && ( // some dtype restrictions
 #ifndef USE_ROCM
@@ -300,16 +270,7 @@ bool launchGemmAndBiasCublasLt(
     const Scalar& alpha,
     Activation activation = Activation::None
 ) {
-  // We apply bias in the epilogue only when it is 1D,
-  // or when it can be squeezed to 1D.
-  // self_ptr == nullptr implies ignore bias epilogue
-  // and use standard gemm-like API.
-  const auto* self_ptr = [&]() -> auto {
-    if (self.dim() == 1 || self.squeeze().dim() == 1) {
-      return self.const_data_ptr<scalar_t>();
-    }
-    return static_cast<const scalar_t*>(nullptr);
-  }();
+  const auto* self_ptr = self.const_data_ptr<scalar_t>();
 
   const auto tuning_ctx = at::cuda::tunable::getTuningContext();
   if (tuning_ctx->IsTunableOpEnabled()) {
@@ -395,7 +356,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   disable_addmm_cuda_lt = isGloballyDisabledAddmmCudaLt(self.device()) || disable_addmm_cuda_lt;
 #endif
   // Condition on the input
-  disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha, activation) || disable_addmm_cuda_lt;
+  disable_addmm_cuda_lt = !isInputCompliesAddmmCudaLt(result, self, mat1, mat2, beta, alpha) || disable_addmm_cuda_lt;
   // }
 
   at::ScalarType scalar_type = mat1.scalar_type();
@@ -405,20 +366,19 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   if (!result.is_same(self)) {
     at::native::resize_output(result, {mat1.sizes()[0], mat2.sizes()[1]});
 
-    // We use bias ptr in the Lt path only when bias is 1D
-    const auto use_bias_ptr_lt = (self.dim() == 1) && !disable_addmm_cuda_lt;
     const auto self_maybe_expanded = [&]() -> c10::MaybeOwned<Tensor> {
-      if (!use_bias_ptr_lt) {
-        // We do expand self even before
+      if (disable_addmm_cuda_lt) {
+        // When in non-Lt path we do expand self even before
         // check for beta != 0.0 to make sure that
         // test_sparse_csr.py::TestSparseCSRCUDA::test_addmm_errors_*
         // runs green.
         return expand_size(self, result.sizes(), "addmm");
       }
+      // copy next, should broadcast
       return c10::MaybeOwned<Tensor>::borrowed(self);
     }();
-    // We do not copy bias only when we need the bias ptr
-    if (beta.toComplexDouble() != 0.0 && !use_bias_ptr_lt) {
+    // We copy bias when in the non-Lt path
+    if (beta.toComplexDouble() != 0.0 && disable_addmm_cuda_lt) {
       // NOTE: self should broadcast over result
       at::native::copy_(result, *self_maybe_expanded);
     }
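The recurring question in these hunks is when addmm may hand its bias to the cuBLASLt epilogue instead of copying `self` into `result` first: the left-hand version accepts a contiguous bias that is 1-D (or squeezes to 1-D) and matches the result's column count. A condensed sketch of that predicate, with a hypothetical helper name:

    #include <ATen/ATen.h>

    // Sketch: can `self` act as a fusable epilogue bias for an [M, N] result?
    // Assumption for illustration: cuBLASLt broadcasts a 1-D bias of length N
    // over the rows, so anything else has to go through copy + plain GEMM.
    bool bias_is_epilogue_fusable(const at::Tensor& self, int64_t n_cols) {
      return self.is_contiguous() &&
             (self.dim() == 1 || self.squeeze().dim() == 1) &&
             self.sizes().back() == n_cols;
    }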
@@ -884,69 +884,6 @@ struct type_specialized_kernel_launcher {
   }
 };
 
-template <int arg_index>
-struct type_specialized_broadcast_kernel_launcher {
-  template <
-      typename func_t,
-      typename array_t,
-      typename dtypes_t,
-      typename calc_t>
-  static void apply(
-      int64_t numel,
-      func_t f,
-      array_t data,
-      dtypes_t dtypes,
-      calc_t offset_calc) {
-    using traits = function_traits<func_t>;
-    using ret_t = typename traits::result_type;
-    using arg0_t = typename traits::template arg<0>::type;
-    using arg1_t = typename traits::template arg<1>::type;
-    if (dtypes[0] == rt_binary_specializations[arg_index][0] &&
-        dtypes[1] == rt_binary_specializations[arg_index][1] &&
-        dtypes[2] == rt_binary_specializations[arg_index][2]) {
-      using ret_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][0]>;
-      using arg0_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][1]>;
-      using arg1_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][2]>;
-      constexpr int grp_sz = 128;
-      launch_legacy_kernel_manual_unroll<grp_sz, 4>(numel, [=] GPU_LAMBDA(int idx, bool unrl) {
-        if (unrl) {
-          auto offsets0 = offset_calc.get(idx);
-          auto offsets1 = offset_calc.get(idx + grp_sz);
-          auto offsets2 = offset_calc.get(idx + grp_sz * 2);
-          auto offsets3 = offset_calc.get(idx + grp_sz * 3);
-          void* out0 = data[0] + offsets0[0];
-          void* out1 = data[0] + offsets1[0];
-          void* out2 = data[0] + offsets2[0];
-          void* out3 = data[0] + offsets3[0];
-          auto u = c10::load<arg0_cpp_t>(data[1] + offsets0[1]);
-          auto v = c10::load<arg1_cpp_t>(data[2] + offsets0[2]);
-          ret_t result0 = f(c10::convert<arg0_t>(u), c10::convert<arg1_t>(v));
-          auto u1 = c10::load<arg0_cpp_t>(data[1] + offsets1[1]);
-          auto v1 = c10::load<arg1_cpp_t>(data[2]+ offsets1[2]);
-          ret_t result1 = f(c10::convert<arg0_t>(u1), c10::convert<arg1_t>(v1));
-          auto u2 = c10::load<arg0_cpp_t>(data[1] + offsets2[1]);
-          auto v2 = c10::load<arg1_cpp_t>(data[2] + offsets2[2]);
-          ret_t result2 = f(c10::convert<arg0_t>(u2), c10::convert<arg1_t>(v2));
-          auto u3 = c10::load<arg0_cpp_t>(data[1] + offsets3[1]);
-          auto v3 = c10::load<arg1_cpp_t>(data[2] + offsets3[2]);
-          ret_t result3 = f(c10::convert<arg0_t>(u3), c10::convert<arg1_t>(v3));
-          *(ret_cpp_t*)out0 = c10::convert<ret_cpp_t>(result0);
-          *(ret_cpp_t*)out1 = c10::convert<ret_cpp_t>(result1);
-          *(ret_cpp_t*)out2 = c10::convert<ret_cpp_t>(result2);
-          *(ret_cpp_t*)out3 = c10::convert<ret_cpp_t>(result3);
-        } else {
-          auto offsets = offset_calc.get(idx);
-          void* out = data[0] + offsets[0];
-          auto u = c10::load<arg0_cpp_t>(data[1] + offsets[1]);
-          auto v = c10::load<arg1_cpp_t>(data[2] + offsets[2]);
-          ret_t result = f(c10::convert<arg0_t>(u), c10::convert<arg1_t>(v));
-          *(ret_cpp_t*)out = c10::convert<ret_cpp_t>(result);
-        }
-      });
-    }
-  }
-};
-
 } // namespace
 #endif
 
@@ -1065,32 +1002,6 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
   }
   auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
 #ifdef USE_ROCM
-  if (check_binary_rt_types_for_specialization(iter)) {
-    // constexpr to reduce the amount of kernels generated for
-    // broadcast elementwise with mexed dtypes and limit which functors are actually
-    // applied to the load and store at compile time.
-    using func_tuple = typename traits::ArgsTuple;
-    if constexpr (
-        std::is_same_v<float, arg0_t> && traits::arity == 2 &&
-        check_binary_functor_types_for_specialization<
-            func_tuple,
-            float,
-            float,
-            traits::arity,
-            /*arg_num=*/0>::check()) {
-      memory::detail::static_unroll<
-          type_specialized_broadcast_kernel_launcher,
-          rt_binary_specializations.size()>::with_args(
-          numel,
-          f,
-          data,
-          dtypes,
-          offset_calc
-      );
-      return;
-    }
-  }
-
   constexpr int grp_sz = 128;
   launch_legacy_kernel_manual_unroll<grp_sz, 4>(numel, [=] GPU_LAMBDA(int idx, bool unrl) {
     if (unrl) {
@@ -22,9 +22,6 @@
 #include <ATen/native/cuda/RowwiseScaledMM.h>
 #include <ATen/native/cuda/ScaledGroupMM.h>
 #include <ATen/native/cuda/GroupMM.h>
-#ifdef USE_ROCM
-#include <ATen/native/hip/ck_group_gemm.h>
-#endif
 #include <ATen/ceil_div.h>
 
 #ifdef USE_FBGEMM_GENAI
@@ -669,19 +666,12 @@ std::optional<c10::ScalarType> out_dtype) {
   // _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used.
   // the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm
   bool use_fast_path = false;
-  if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) {
-    use_fast_path = true;
-  }
 #endif
   const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
   Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
   if (use_fast_path) {
     // fast path, no d2h sync needed
-#ifndef USE_ROCM
     at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
-#else
-    at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
-#endif
   } else {
     _grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
   }
@@ -5,6 +5,7 @@
 #include <array>
 #include <type_traits>
 #include <ATen/core/TensorBase.h>
+#include <ATen/ceil_div.h>
 #include <ATen/Dispatch.h>
 #include <ATen/Dispatch_v2.h>
 #include <ATen/cuda/CUDAContext.h>
@@ -73,6 +74,7 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co
 
   char* const out_ptr = static_cast<char*>(iter.data_ptr(0));
   char* const in_ptr = static_cast<char*>(iter.data_ptr(1));
 
   if (is_gather_like && num_indices==1) {
     const size_t element_size = iter.element_size(0);
     constexpr size_t alignment = 16;
@@ -82,9 +84,16 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co
     auto ind_dim_size = index_size[0];
     auto inp_stride_bytes = index_stride[0];
     auto out_stride_bytes = iter.strides(0)[1];
-    at::native::vectorized_gather_kernel_launch<alignment, int64_t>(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind,
-                                                                    slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes, /*allow_neg_indices*/true);
-    return;
+    // avoid grid overflow in the fast kernel
+    const int64_t vec_chunks = ceil_div(slice_size, alignment);
+    const int64_t blocks_per_slice_upper = ceil_div(vec_chunks, (int64_t)launch_size_nd);
+    const int max_grid_y = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
+    // if it's an eligible grid we use the fast path, otherwise default to slower path
+    if (blocks_per_slice_upper <= max_grid_y) {
+      at::native::vectorized_gather_kernel_launch<alignment, int64_t>(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind,
+                                                                      slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes, /*allow_neg_indices*/true);
+      return;
+    }
   }
 }
 
@@ -13,12 +13,11 @@ __global__ void vectorized_gather_kernel(char * out, char * inp, index_t * idx,
   if (allow_neg_indices) {
     ind = (ind < 0) ? ind + ind_dim_size : ind;
   }
-  CUDA_KERNEL_ASSERT_VERBOSE(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds");
-  // off is guaranteed to be within int32 limits
-  for (int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; off < slice_size; off += blockDim.x * gridDim.y * Alignment) {
+  CUDA_KERNEL_ASSERT_VERBOSE(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds", "Expected 0 <= index < ind_dim_size(%ld), but got index = %ld", ind_dim_size, ind);
+  int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; // off is guaranteed to be within int32 limits
+  if (off >= slice_size) return;
   auto vec = at::native::memory::ld_vec<Alignment>(inp + ind * inp_stride + off);
   at::native::memory::st_vec<Alignment>(out + blockIdx.x * (int32_t)out_stride + off, vec); // out offset is guaranteed to be within int32 limits
-  }
 }
 
 
@@ -31,9 +30,7 @@ void vectorized_gather_kernel_launch(char * out, char * inp, index_t * idx, int
   auto num_threads = at::round_up(
       at::ceil_div(slice_size_in_bytes, Alignment),
       static_cast<int64_t>(C10_WARP_SIZE));
-  uint32_t grid_y = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
-  grid_y = std::min(static_cast<uint32_t>(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), grid_y);
-  dim3 grid = {static_cast<uint32_t>(num_ind), grid_y, 1};
+  dim3 grid = {static_cast<uint32_t>(num_ind), static_cast<uint32_t>(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), 1};
   auto block = std::min(max_num_threads, num_threads);
   vectorized_gather_kernel<Alignment, index_t><<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(out, inp, idx, num_ind, slice_size_in_bytes,
       ind_dim_size, inp_stride_bytes, out_stride_bytes, allow_neg_indices);
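The left-hand kernel turns the single vectorized access into a grid-stride loop so that, once gridDim.y is clamped to the device's maxGridSize[1], every byte of an arbitrarily large slice is still visited. A stripped-down CUDA sketch of that loop shape (hypothetical byte-copy kernel, not the gather itself):

    #include <cstdint>

    template <int Alignment>
    __global__ void strided_copy_slice(char* out, const char* in, int64_t slice_bytes) {
      // Each thread starts at its aligned offset and advances by the total number
      // of bytes covered per grid "row", so a clamped gridDim.y still covers the
      // whole slice.
      for (int64_t off = (int64_t)(blockDim.x * blockIdx.y + threadIdx.x) * Alignment;
           off < slice_bytes;
           off += (int64_t)blockDim.x * gridDim.y * Alignment) {
        for (int i = 0; i < Alignment; ++i) {
          out[off + i] = in[off + i];   // a real kernel would use vectorized ld/st here
        }
      }
    }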
@ -1,19 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include <ATen/Tensor.h>
|
|
||||||
#include <c10/core/ScalarType.h>
|
|
||||||
#include <optional>
|
|
||||||
|
|
||||||
namespace at {
|
|
||||||
namespace hip {
|
|
||||||
namespace detail {
|
|
||||||
void group_gemm_ck(
|
|
||||||
const at::Tensor& mat_a,
|
|
||||||
const at::Tensor& mat_b,
|
|
||||||
const std::optional<at::Tensor>& offs,
|
|
||||||
const std::optional<at::Tensor>& bias,
|
|
||||||
at::Tensor& out);
|
|
||||||
|
|
||||||
} // namespace detail
|
|
||||||
} // namespace hip
|
|
||||||
} // namespace at
|
|
||||||
@@ -1,462 +0,0 @@
#undef __HIP_NO_HALF_CONVERSIONS__

#include <ATen/hip/HIPContext.h>
#include <ATen/Tensor.h>
#include <ATen/TensorAccessor.h>
#include <c10/hip/HIPStream.h>
#include <iostream>
#include <vector>
#include <optional>
#include <type_traits>

#include <ck/ck.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
#include <ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/utility/tuple.hpp>

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

namespace at {
namespace hip {
namespace detail {

namespace CkTypes {
  using BF16 = ck::bhalf_t;
  using F16 = ck::half_t;
  using F32 = float;
  using PassThrough = ck::tensor_operation::element_wise::PassThrough;
}

template <typename ALayout, typename BLayout, typename DataType>
using GroupedGemmKernel = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<
    ALayout, BLayout, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor,
    DataType, DataType, CkTypes::F32, DataType, ck::Tuple<>, DataType,
    CkTypes::PassThrough, CkTypes::PassThrough, CkTypes::PassThrough,
    ck::tensor_operation::device::GemmSpecialization::MNKPadding,
    1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2,
    S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
    3, 8, 8, 1,
    S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>,
    3, 8, 8, 1,
    1, 1,
    S<1,32,1,8>, 4
>;

template <typename ALayout, typename BLayout, typename DataType>
void launch_grouped_bgemm_ck_impl_dispatch(
    const at::Tensor& mat_a,
    const at::Tensor& mat_b,
    const std::optional<at::Tensor>& offs,
    at::Tensor& out)
{
  using DeviceOp = GroupedGemmKernel<ALayout, BLayout, DataType>;
  using PassThrough = CkTypes::PassThrough;

  std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
  std::vector<const void*> p_a_ptrs, p_b_ptrs;
  std::vector<void*> p_e_ptrs;
  // Note: d_ptrs will be resized after we populate the other vectors

  const int mat_a_dim = mat_a.dim();
  const int mat_b_dim = mat_b.dim();

  const char* a_ptr_base = reinterpret_cast<const char*>(mat_a.data_ptr());
  const char* b_ptr_base = reinterpret_cast<const char*>(mat_b.data_ptr());
  char* out_ptr_base = reinterpret_cast<char*>(out.data_ptr());
  const size_t a_element_size = mat_a.element_size();
  const size_t b_element_size = mat_b.element_size();
  const size_t out_element_size = out.element_size();

  // for each group, calculate m,n,k,lda,ldb,ldc and A,B,out pointer base addresses.
  if (mat_a_dim == 2 && mat_b_dim == 2) {
    // 2D*2D case requires offset tensor
    auto offs_accessor = offs->accessor<int, 1>();
    int num_groups = offs_accessor.size(0);
    const int M = mat_a.size(0); // number of rows in A
    const int N = mat_b.size(1); // number of columns in B
    const int K = mat_a.size(1); // columns in A == rows in B
    // for 2d*2d input, output is 3d.
    // for each group, A columns (K) are sliced. M and N dimensions are not sliced.
    for (int i = 0; i < num_groups; ++i) {
      int start_k = (i == 0) ? 0 : offs_accessor[i-1];
      int end_k = offs_accessor[i];
      int k = end_k - start_k;

      //K dimension are sliced, hence select stride(1) always.
      //K dimension is always dimension 1, regardless of memory layout (row/column major)
      const void* group_a_ptr = a_ptr_base + start_k * mat_a.stride(1) * a_element_size;
      const void* group_b_ptr;
      int ldb;

      if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
        // Row-major B [K,N]: K values are horizontally adjacent, use stride(1) for K offset
        group_b_ptr = b_ptr_base + start_k * mat_b.stride(1) * b_element_size;
        // Leading dimension = distance between rows = stride(0)
        ldb = mat_b.stride(0);
      } else {
        // Column-major B [K,N]: K values are vertically adjacent, use stride(0) for K offset
        group_b_ptr = b_ptr_base + start_k * mat_b.stride(0) * b_element_size;
        // Leading dimension = distance between columns = stride(1)
        ldb = mat_b.stride(1);
      }

      // Calculate output pointer for group i in 3D tensor [num_groups, M, N]
      // stride(0) = M*N elements between groups, so skip i*stride(0) elements to reach group i
      void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;
      int lda, ldc;
      if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
        // Row-major A [M,K]: leading dimension = distance between rows = stride(0)
        lda = mat_a.stride(0);
      } else {
        // Column-major A [M,K]: leading dimension = distance between columns = stride(1)
        lda = mat_a.stride(1);
      }
      // Output is always row-major in 3D tensor [num_groups, M, N]
      // Leading dimension for each group's [M,N] slice = stride(1) = N
      ldc = out.stride(1);
      size_t output_group_bytes = M * N * out_element_size;
      void* group_e_ptr_end = (char*)group_e_ptr + output_group_bytes;

      gemm_descs.push_back({
        static_cast<ck::index_t>(M),
        static_cast<ck::index_t>(N),
        static_cast<ck::index_t>(k),
        static_cast<ck::index_t>(lda),
        static_cast<ck::index_t>(ldb),
        static_cast<ck::index_t>(ldc),
        {} // --> stride_Ds_
      });
      p_a_ptrs.push_back(group_a_ptr);
      p_b_ptrs.push_back(group_b_ptr);
      p_e_ptrs.push_back(group_e_ptr);
    }
  } else if (mat_a_dim == 2 && mat_b_dim == 3) {
    // 2D*3D case requires offset tensor
    auto offs_accessor = offs->accessor<int, 1>();
    int num_groups = offs_accessor.size(0);

    // 2d*3d input, output is 2d.
    // A: [m * n_groups, k], B: [n_groups, n, k] or [n_groups, k, n], Output: [m * n_groups, n]
    // Offset divides M dimension (rows of A), each group gets different rows of A and different batch of B
    const int K = mat_a.size(1); // columns in A
    // For 2D-3D case: The output determines N (result width)
    const int N = out.size(1); // N is the width of the output tensor

    for (int i = 0; i < num_groups; ++i) {
      int start_m = (i == 0) ? 0 : offs_accessor[i - 1];
      int end_m = offs_accessor[i];
      int m = end_m - start_m;

      // Skip zero-sized groups but continue processing subsequent groups
      if (m <= 0) {
        continue;
      }

      // Select A rows for group i: skip start_m rows
      const void* group_a_ptr;
      int lda;
      if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
        // Row-major A [total_m, K]: skip start_m rows, each row is stride(0) elements apart
        group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;
        lda = mat_a.stride(0); // distance between rows
      } else {
        // Column-major A [total_m, K]: skip start_m elements in the first dimension (stride(0) is between rows)
        group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size;

        // Detect stride pattern for A tensor to determine appropriate lda calculation
        bool a_is_strided_tensor = (mat_a.stride(0) > mat_a.size(0));

        if (a_is_strided_tensor) {
          // For strided A tensors: stride(0) gives the actual leading dimension
          lda = mat_a.stride(0);
        } else {
          // For non-strided A tensors: use the M dimension (total rows)
          lda = mat_a.size(0); // Total M dimension for column-major layout
        }
      }

      // Select B batch for group i: B[i, :, :]
      const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;
      int ldb;

      if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
        // Row-major GEMM: expecting B as [K, N] but we have [N, K], so transpose needed
        ldb = mat_b.stride(2); // Leading dimension for accessing as [K, N]
      } else {
        // Detect stride pattern to determine appropriate ldb calculation
        bool is_strided_tensor = (mat_b.stride(2) > mat_b.size(2));

        if (is_strided_tensor) {
          // For strided tensors: stride(2) gives the actual leading dimension
          ldb = mat_b.stride(2);
        } else {
          // For non-strided tensors: use the N dimension
          ldb = mat_b.size(1);
        }
      }

      // Output for this group: rows [start_m:end_m, :] in 2D output [total_m, N]
      void* group_e_ptr = out_ptr_base + start_m * out.stride(0) * out_element_size;
      int ldc = out.stride(0); // distance between rows in output (should be N for 2D case)

      gemm_descs.push_back({
        static_cast<ck::index_t>(m),
        static_cast<ck::index_t>(N),
        static_cast<ck::index_t>(K),
        static_cast<ck::index_t>(lda),
        static_cast<ck::index_t>(ldb),
        static_cast<ck::index_t>(ldc),
        {} // --> stride_Ds_
      });
      p_a_ptrs.push_back(group_a_ptr);
      p_b_ptrs.push_back(group_b_ptr);
      p_e_ptrs.push_back(group_e_ptr);
    }
  } else if (mat_a_dim == 3 && mat_b_dim == 3) {
    // 3d*3d input, output is 3d - batched matrix multiplication
    // A: [batch, m, k], B: [batch, k, n] or [batch, n, k] (depending on transpose), Output: [batch, m, n]
    // Each batch is processed as a separate GEMM operation
    const int batch_size = mat_a.size(0);
    const int M = mat_a.size(1); // rows in each A matrix
    const int K = mat_a.size(2); // columns in A == rows in B (or columns if B is transposed)

    // Determine N from B tensor - it could be B.size(1) or B.size(2) depending on layout
    int N;
    if (mat_b.size(1) == K) {
      // B is [batch, k, n] - normal layout
      N = mat_b.size(2);
    } else if (mat_b.size(2) == K) {
      // B is [batch, n, k] - transposed layout
      N = mat_b.size(1);
    } else {
      TORCH_CHECK(false, "CK Group GEMM 3D-3D: B tensor dimensions incompatible with A. A=[",
                  batch_size, ",", M, ",", K, "], B=[", mat_b.size(0), ",", mat_b.size(1), ",", mat_b.size(2), "]");
    }

    for (int i = 0; i < batch_size; ++i) {
      // Select A batch for group i: A[i, :, :]
      const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;

      // Select B batch for group i: B[i, :, :]
      const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size;

      // Select output batch for group i: Output[i, :, :]
      void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size;

      int lda, ldb, ldc;

      if (std::is_same<ALayout, ck::tensor_layout::gemm::RowMajor>::value) {
        // Row-major A: leading dimension = distance between rows = stride(1)
        lda = mat_a.stride(1);
      } else {
        // Column-major A: leading dimension = distance between columns = stride(2)
        lda = mat_a.stride(2);
      }

      if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
        // Row-major B: leading dimension = distance between rows
        if (mat_b.size(1) == K) {
          // B is [batch, k, n] - normal layout
          ldb = mat_b.stride(1); // stride between K rows
        } else {
          // B is [batch, n, k] - transposed layout, treat as [k, n] for GEMM
          ldb = mat_b.stride(2); // stride between N rows (since we're accessing as [k,n])
        }
      } else {
        // Column-major B: leading dimension = distance between columns
        if (mat_b.size(1) == K) {
          // B is [batch, k, n] - normal layout
          ldb = mat_b.stride(2); // stride between N columns
        } else {
          // B is [batch, n, k] - transposed layout
          ldb = mat_b.stride(1); // stride between K columns (since we're accessing as [n,k]→[k,n])
        }
      }

      // Output is typically row-major: leading dimension = distance between rows = stride(1)
      ldc = out.stride(1);

      gemm_descs.push_back({
        static_cast<ck::index_t>(M),
        static_cast<ck::index_t>(N),
        static_cast<ck::index_t>(K),
        static_cast<ck::index_t>(lda),
        static_cast<ck::index_t>(ldb),
        static_cast<ck::index_t>(ldc),
        {} // --> stride_Ds_
      });
      p_a_ptrs.push_back(group_a_ptr);
      p_b_ptrs.push_back(group_b_ptr);
      p_e_ptrs.push_back(group_e_ptr);
    }
  } else if (mat_a_dim == 3 && mat_b_dim == 2) {
    // 3D*2D case requires offset tensor
    auto offs_accessor = offs->accessor<int, 1>();
    int num_groups = offs_accessor.size(0);
    // 3d*2d input, output is 3d.
    // A: [n_groups, m, k], B: [k, total_n] (assuming row-major for both)
    // Offset divides N dimension of B, each group gets different slice of B and different batch of A
    const int batch_size = mat_a.size(0); // n_groups
    const int M = mat_a.size(1); // rows in each A matrix
    const int K = mat_a.size(2); // columns in A

    // For row-major A and B case: B should be [K, total_N]
    const int total_N = mat_b.size(1); // B is [K, total_N] for row-major

    for (int i = 0; i < num_groups; ++i) {
      int start_n = (i == 0) ? 0 : offs_accessor[i - 1];
      int end_n = offs_accessor[i];
      int n = end_n - start_n;

      // Skip zero-sized groups but continue processing subsequent groups
      if (n <= 0) {
        continue;
      }

      // Select A batch for group i: A[i, :, :]
      const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size;

      // Select B slice for group i: B[:, start_n:end_n] (B[K, total_N])
      const void* group_b_ptr;
      int ldb;

      // Check if B is row-major or column-major
      if (std::is_same<BLayout, ck::tensor_layout::gemm::RowMajor>::value) {
        // Row-major B [K, total_N]: slice columns [start_n:end_n]
        group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
        ldb = mat_b.stride(0); // distance between rows (should be total_N)
      } else {
        // Column-major B [K, total_N]: slice columns [start_n:end_n]
        group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size;
        ldb = mat_b.stride(1); // distance between columns (should be K)
      }

      // Select output slice for group i: Output[:, start_n:end_n]
      void* group_e_ptr = out_ptr_base + start_n * out.stride(1) * out_element_size;

      int lda, ldc;

      // Row-major A: leading dimension = distance between rows = stride(1)
      lda = mat_a.stride(1);
      // Output is row-major: leading dimension = distance between rows = stride(0)
      ldc = out.stride(0);

      gemm_descs.push_back({
        static_cast<ck::index_t>(M),
        static_cast<ck::index_t>(n),
        static_cast<ck::index_t>(K),
        static_cast<ck::index_t>(lda),
        static_cast<ck::index_t>(ldb),
        static_cast<ck::index_t>(ldc),
        {} // --> stride_Ds_
      });
      p_a_ptrs.push_back(group_a_ptr);
      p_b_ptrs.push_back(group_b_ptr);
      p_e_ptrs.push_back(group_e_ptr);
    }
  } else {
    TORCH_CHECK(false, "CK Group GEMM: Unsupported dimensions, mat A dim is ", mat_a_dim, ", mat B dim is ", mat_b_dim);
  }

  TORCH_INTERNAL_ASSERT(p_a_ptrs.size() > 0, "CK Group GEMM: No valid groups");

  // Initialize d_ptrs with the correct size
  std::vector<std::array<const void*, 0>> d_ptrs(p_a_ptrs.size());

  static DeviceOp gemm_instance;
  auto argument = gemm_instance.MakeArgument(
      p_a_ptrs, p_b_ptrs, d_ptrs, p_e_ptrs,
      gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}
  );
  TORCH_INTERNAL_ASSERT(gemm_instance.IsSupportedArgument(argument),
      "CK Group GEMM: argument unsupported (shape/strides/type config)");
  size_t arg_buf_size = gemm_instance.GetDeviceKernelArgSize(&argument);
  size_t ws_size = gemm_instance.GetWorkSpaceSize(&argument);

  void* gemm_arg_buf = nullptr;
  void* ws_buf = nullptr;

  hipMalloc(&gemm_arg_buf, arg_buf_size);
  hipMalloc(&ws_buf, ws_size);

  gemm_instance.SetDeviceKernelArgs(&argument, gemm_arg_buf);
  gemm_instance.SetWorkSpacePointer(&argument, ws_buf);

  auto invoker = gemm_instance.MakeInvoker();
  hipStream_t stream = c10::hip::getCurrentHIPStream();
  invoker.Run(argument, {stream});
  hipFree(gemm_arg_buf);
  hipFree(ws_buf);
}

void group_gemm_ck(
    const at::Tensor& input_a,
    const at::Tensor& input_b_colmajor,
    const std::optional<at::Tensor>& offs,
    const std::optional<at::Tensor>& /*bias*/,
    at::Tensor& out)
{
  // Detect if input_a is row-major based on stride pattern
  bool a_row_major = (input_a.dim() == 3) ? (input_a.stride(2) == 1) : (input_a.stride(1) == 1);
  bool b_col_major = (input_b_colmajor.dim() == 3) ? (input_b_colmajor.stride(1) == 1) : (input_b_colmajor.stride(0) == 1);
  // Ensure tensor A is row-major and contiguous if not already
  at::Tensor mat_a = input_a;
  if (!a_row_major) {
    // If A is not row-major, make it contiguous (row-major)
    mat_a = input_a.contiguous();
  }
  // Force tensor B to be column-major using double transpose trick
  // This guarantees stride(0) == 1 and stride(1) == K for [K, N] shape
  at::Tensor mat_b = input_b_colmajor;
  if (!b_col_major) {
    mat_b = input_b_colmajor.transpose(-2, -1).contiguous().transpose(-2, -1);
  }

  // For 3D tensors, check the last dimension stride for row-major detection
  a_row_major = (mat_a.dim() == 3) ? (mat_a.stride(2) == 1) : (mat_a.stride(1) == 1);
  bool b_row_major = (mat_b.dim() == 3) ? (mat_b.stride(2) == 1) : (mat_b.stride(1) == 1);

  if (mat_a.dtype() == at::kBFloat16) {
    // bf16 path
    if (a_row_major && b_row_major) {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
    } else if (a_row_major && !b_row_major) {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
    } else if (!a_row_major && b_row_major) {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
    } else {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
    }
  } else if (mat_a.dtype() == at::kHalf) {
    // fp16 path
    if (a_row_major && b_row_major) {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
    } else if (a_row_major && !b_row_major) {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
    } else if (!a_row_major && b_row_major) {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
    } else {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F16>(mat_a, mat_b, offs, out);
    }
  } else if (mat_a.dtype() == at::kFloat) {
    // fp32 path
    if (a_row_major && b_row_major) {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
    } else if (a_row_major && !b_row_major) {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
    } else if (!a_row_major && b_row_major) {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
    } else {
      launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::ColumnMajor, ck::tensor_layout::gemm::ColumnMajor, CkTypes::F32>(mat_a, mat_b, offs, out);
    }
  } else {
    TORCH_CHECK(false, "CK Group GEMM: Unsupported mat_a dtype");
  }
}

} // namespace detail
} // namespace hip
} // namespace at
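The implementation above derives each group's extent from consecutive entries of the offs tensor: K-slices for the 2D*2D case, M-slices for 2D*3D, and N-slices for 3D*2D, skipping empty groups. A minimal host-side sketch of that bookkeeping, assuming a plain std::vector<int> in place of the tensor accessor:

// --- illustrative sketch, not part of the diff ---
#include <cstdio>
#include <vector>

int main() {
  // Cumulative offsets -> per-group (start, length), mirroring the start/end/extent
  // computation in the dispatch loops above; the repeated 192 produces one empty group.
  std::vector<int> offs = {64, 192, 192, 256};
  for (size_t i = 0; i < offs.size(); ++i) {
    int start = (i == 0) ? 0 : offs[i - 1];
    int len = offs[i] - start;
    if (len <= 0) continue;  // zero-sized group: skipped, later groups are still processed
    std::printf("group %zu: start=%d len=%d\n", i, start, len);
  }
  return 0;
}
// --- end sketch ---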
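The group_gemm_ck wrapper forces B into column-major storage with a double transpose before dispatching. A small ATen sketch of why that works; the [K, N] = [4, 3] shape is an arbitrary choice for illustration:

// --- illustrative sketch, not part of the diff ---
#include <ATen/ATen.h>
#include <iostream>

void column_major_via_double_transpose() {
  at::Tensor b = at::arange(12, at::kFloat).reshape({4, 3});  // row-major [K=4, N=3], strides (3, 1)
  // transpose -> contiguous rewrites the data in N-major order; the second transpose
  // restores the [K, N] view, leaving strides (1, K) as the column-major path expects.
  at::Tensor b_col = b.transpose(-2, -1).contiguous().transpose(-2, -1);
  std::cout << b_col.sizes() << " " << b_col.strides() << std::endl;  // [4, 3] and (1, 4)
}
// --- end sketch ---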
@@ -133,7 +133,7 @@ at::Tensor quantized_convolution(
   // supported in conv.
   mask_weight = weight_zero_points.numel() > 1 ? 1 : 0;
   if (groups > 1 && weight_zero_points.numel() > 1)
-    mask_weight = (1 << 0) | (1 << 1); // 2^0 (group) | 2^1 (output channel)
+    mask_weight = (2 ^ 0) | (2 ^ 1); // 2^0 (group) | 2^1 (output channel)
   dnnl::primitive_attr pattr;

   bool src_need_zp = (act_zero_point != 0);
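Both variants of the changed mask_weight line evaluate to 3, but they express different things: << builds the intended bitmask of bit 0 (per-group) and bit 1 (per-output-channel), whereas ^ is bitwise XOR in C++, so (2 ^ 0) | (2 ^ 1) only lands on the same value by coincidence. A quick standalone check:

// --- illustrative sketch, not part of the diff ---
#include <cassert>

int main() {
  int shift_mask = (1 << 0) | (1 << 1);  // bit 0 | bit 1 == 0b11 == 3
  int xor_expr   = (2 ^ 0) | (2 ^ 1);    // (2 XOR 0) == 2, (2 XOR 1) == 3, 2 | 3 == 3
  assert(shift_mask == 3 && xor_expr == 3);  // equal here, but only the shift form generalizes to other bits
  return 0;
}
// --- end sketch ---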
@@ -40,6 +40,8 @@ using namespace at::mps;

 namespace at::native::mps {

+void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());
+
 struct MPSScalar {
   id<MTLBuffer> getMTLBuffer() const {
     return __builtin_bit_cast(id<MTLBuffer>, buffer.get());
@@ -53,6 +53,21 @@
 @end

 namespace at::native::mps {
+
+void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) {
+  __block std::optional<std::exception_ptr> block_exception;
+  dispatch_sync(queue, ^() {
+    try {
+      block();
+    } catch (...) {
+      block_exception = std::current_exception();
+    }
+  });
+  if (block_exception) {
+    std::rethrow_exception(*block_exception);
+  }
+}
+
 /**
  * Computes distance from lowest to highest element offset in given tensor.
  */
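The helper added in this hunk runs a block synchronously on a dispatch queue while capturing any C++ exception thrown inside it and re-throwing it on the calling thread, rather than letting it terminate inside libdispatch. A hedged usage sketch (the queue label and the deliberately thrown runtime_error are assumptions for illustration):

// --- illustrative sketch, not part of the diff ---
#include <dispatch/dispatch.h>
#include <stdexcept>
// Assumes the at::native::mps::dispatch_sync_with_rethrow declaration shown above is visible.

void rethrow_example() {
  dispatch_queue_t queue = dispatch_queue_create("org.example.mps.sketch", DISPATCH_QUEUE_SERIAL);
  try {
    at::native::mps::dispatch_sync_with_rethrow(queue, ^{
      throw std::runtime_error("raised on the dispatch queue");
    });
  } catch (const std::runtime_error& e) {
    // The exception stored by the helper is re-thrown here, on the caller's thread.
  }
}
// --- end sketch ---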
Some files were not shown because too many files have changed in this diff.